feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

14
.gitignore vendored
View File

@@ -1,8 +1,10 @@
# Rust build artifacts
target/
# Cargo lock file
Cargo.lock
# Rust IDE files
# IDE files
.vscode/
.idea/
@@ -131,3 +133,13 @@ demo/
# Additional Rust specific
*.lock
# Generated model files
*.mpk
*.pt
*.bin
*.safetensors
*.ckpt
# User-specific data
user_data/

View File

@@ -8,7 +8,7 @@ clap = { version = "4.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
stablediffusion = { path = "../stable-diffusion-burn" }
stablediffusion = { path = "./crates/stable-diffusion-burn" }
burn = "0.14.0"
burn-autodiff = "0.14.0"
burn-ndarray = "0.14.0"

View File

@@ -0,0 +1,28 @@
[package]
name = "stablediffusion"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
wgpu-backend = ["burn-wgpu"]
[dependencies.burn-wgpu]
package = "burn-wgpu"
git = "https://github.com/burn-rs/burn.git"
optional = true
[dependencies]
burn = "0.14.0"
burn-ndarray = "0.14.0"
burn-tch = "0.14.0"
burn-autodiff = "0.14.0"
tch = "0.15.0"
serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0"
num-traits = "0.2.15"
rust_tokenizers = "8.1.0"
regex = "1.9.1"
image = "0.24.6"
cfg-if = "0.1"

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Gadersd
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,75 @@
# Stable-Diffusion-Burn
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
## How To Use
### Step 0: Install libtorch v2.4.1
### Step 1: Download the Model and Set Environment Variables
Start by downloading the SDv1-4 model provided on HuggingFace.
```bash
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
```
### Step 2: Run the Sample Binary
Invoke the sample binary provided in the Rust code. By default, torch is used. The WGPU backend is unstable for SD but may work well in the future as burn-wgpu is optimized.
```bash
# torch (at least 6 GB VRAM, possibly less)
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [cuda, mps, cpu]
# Cuda
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda
# Mps(Mac)
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps
# wgpu (UNSTABLE)
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
cargo run --release --features wgpu-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
```
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
![An image of an ancient mossy stone](img0.png)
### Optional: Extract and Convert a Fine-Tuned Model
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
```bash
# Step into the Python directory
cd python
# Download the model, this is just the base v1.4 model as an example
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# Install tinygrad
pip install -r requirements.txt
# Extract the weights
CPU=1 python3 dump.py sd-v1-4.ckpt
# Move the extracted weight folder out
mv params ..
# Step out of the Python directory
cd ..
# Convert the weights into a usable form
cargo run --release --bin convert params SDv1-4
```
The `convert` and `sample` binaries are implemented in Rust. `convert` runs on the CPU, whereas `sample` needs CUDA.
Remember, `convert` should be used if you're planning on using the fine-tuned version of the stable diffusion.
## License
This project is licensed under the MIT license.
We wish you a productive time using this project. Enjoy!

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,45 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Automatic differentiation backend for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-autodiff"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-autodiff"
documentation = "https://docs.rs/burn-autodiff"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std", "tracing"]
std = ["dep:parking_lot"]
export_tests = [] # check checkpointer is_empty in tests
tracing = [
"dep:tracing",
"burn-std/tracing",
"burn-backend/tracing",
]
[dependencies]
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
derive-new = { workspace = true }
spin = { workspace = true }
parking_lot = { workspace = true, optional = true }
log = { workspace = true }
hashbrown = { workspace = true }
num-traits = { workspace = true }
portable-atomic = { workspace = true }
tracing = { workspace = true, optional = true, features = ["default"] }
[package.metadata.docs.rs]
features = ["default"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,8 @@
# Burn Autodiff
> [Burn](https://github.com/tracel-ai/burn) autodiff backend
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-autodiff.svg)](https://crates.io/crates/burn-autodiff)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-autodiff/blob/master/README.md)
For now only first order reverse mode autodiff is supported.

View File

@@ -0,0 +1,138 @@
use crate::{
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
grads::Gradients,
tensor::AutodiffTensor,
};
use alloc::{format, string::String};
use burn_backend::{
backend::{AutodiffBackend, Backend, ExecutionError},
tensor::{BoolTensor, IntTensor, QuantizedTensor},
};
use core::marker::PhantomData;
/// Enable auto-differentiation on a backend.
///
/// This works as a backend decorator, extending the functionality of any backend with
/// backpropagation.
///
/// # Type parameters
///
/// - `B`: the backend being decorated.
/// - `C`: the checkpointing strategy; defaults to [`NoCheckpointing`].
///
/// The struct only holds [`PhantomData`] markers, so it carries no runtime data.
#[derive(Clone, Copy, Debug, Default)]
pub struct Autodiff<B, C = NoCheckpointing> {
// Marker tying the decorator to the wrapped backend type.
_b: PhantomData<B>,
// Marker tying the decorator to the checkpointing strategy type.
_checkpoint_strategy: PhantomData<C>,
}
// `Backend` implementation of the decorator: float tensors are wrapped in
// `AutodiffTensor`, while every other associated type and all device-level
// operations are forwarded to the inner backend `B`.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type Device = B::Device;
// Float tensors are the only primitives that get an autodiff wrapper.
type FloatTensorPrimitive = AutodiffTensor<B>;
type FloatElem = B::FloatElem;
// Int, bool and quantized primitives are the inner backend's own types.
type IntTensorPrimitive = B::IntTensorPrimitive;
type IntElem = B::IntElem;
type BoolTensorPrimitive = B::BoolTensorPrimitive;
type BoolElem = B::BoolElem;
type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
// Always `true` for this decorator, whatever the device.
fn ad_enabled(_device: &Self::Device) -> bool {
true
}
// Wraps the inner backend's name, e.g. `autodiff<inner-name>`.
fn name(device: &Self::Device) -> String {
format!("autodiff<{}>", B::name(device))
}
// Forwards seeding to the inner backend.
fn seed(device: &B::Device, seed: u64) {
B::seed(device, seed)
}
// Forwards synchronization to the inner backend and returns its result.
fn sync(device: &B::Device) -> Result<(), ExecutionError> {
B::sync(device)
}
// Forwards to the inner backend, passing `input` and `func` through untouched.
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
device: &Self::Device,
input: Input,
func: Func,
) -> Output {
B::memory_persistent_allocations(device, input, func)
}
// Forwards memory cleanup to the inner backend.
fn memory_cleanup(device: &Self::Device) {
B::memory_cleanup(device)
}
// Forwards staging of the given tensor data to the inner backend.
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
where
Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
{
B::staging(data, device);
}
// Dtype support is whatever the inner backend reports.
fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
B::supports_dtype(device, dtype)
}
// Dtype usage is whatever the inner backend reports.
fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
B::dtype_usage(device, dtype)
}
}
// `AutodiffBackend` implementation: gradient operations are delegated to
// `AutodiffTensor`, and conversions to/from the inner backend are identity
// functions for every tensor kind except float.
impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
type InnerBackend = B;
type Gradients = Gradients;
// Runs the backward pass starting from `tensor` and returns the gradients.
fn backward(tensor: AutodiffTensor<B>) -> Gradients {
tensor.backward()
}
// Looks up the gradient of `tensor` in `grads`, if any.
fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
tensor.grad(grads)
}
// Same as `grad`, but delegates to `grad_remove`, which takes `grads` mutably.
fn grad_remove(
tensor: &AutodiffTensor<B>,
grads: &mut Gradients,
) -> Option<B::FloatTensorPrimitive> {
tensor.grad_remove(grads)
}
// Unwraps the autodiff tensor and returns the inner backend primitive.
fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
tensor.primitive
}
// Wraps an inner backend primitive into a new autodiff tensor.
fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
AutodiffTensor::new(tensor)
}
// Replaces the gradient stored for `tensor` in `grads` with `grad`.
fn grad_replace(
tensor: &AutodiffTensor<B>,
grads: &mut Self::Gradients,
grad: B::FloatTensorPrimitive,
) {
tensor.grad_replace(grads, grad);
}
// The conversions below return their argument unchanged: int, bool and
// quantized primitives are the same types on both backends.
fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
tensor
}
fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
tensor
}
fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
tensor
}
fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
tensor
}
fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
tensor
}
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
tensor
}
}

View File

@@ -0,0 +1,82 @@
use super::{
retro_forward::RetroForwards,
state::{BackwardStates, State},
};
use crate::collections::HashMap;
use crate::graph::NodeId;
use alloc::{vec, vec::Vec};
#[derive(new, Debug)]
/// Links a [NodeId] to the [NodeId]s of its parents in the autodiff graph
pub(crate) struct NodeTree {
// Key: a node; value: the ids returned by `parents` for that node.
map: HashMap<NodeId, Vec<NodeId>>,
}
impl NodeTree {
    /// Returns an owned copy of the parent ids registered for `node_id`,
    /// or `None` when the node is not present in the tree.
    pub(crate) fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {
        self.map.get(node_id).map(Vec::clone)
    }
}
#[derive(new, Debug)]
/// Struct responsible of fetching the output for a node in the autodiff graph during a backward pass
pub struct Checkpointer {
// Current state (computed output or pending recomputation) of each node.
backward_states: BackwardStates,
// Recomputation functions for the nodes whose output was not saved.
retro_forwards: RetroForwards,
// Parent relationships, used to order recomputations.
node_tree: NodeTree,
}
impl Checkpointer {
    /// Returns the output of `node_id`, first running the retro-forward of every
    /// ancestor that still has to be recomputed, parents before children.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` (or one of the ancestors that must be visited) has no
    /// entry in the backward states.
    pub fn retrieve_node_output<T>(&mut self, node_id: NodeId) -> T
    where
        T: Clone + Send + 'static,
    {
        let execution_order = self.topological_sort(node_id);
        for ancestor in execution_order {
            self.retro_forwards
                .execute_retro_forward(ancestor, &mut self.backward_states);
        }

        self.backward_states.get_state::<T>(&node_id)
    }

    /// Orders `node_id` and the ancestors it depends on so that every parent
    /// appears before its children, which lets the caller mutate the states
    /// iteratively afterwards.
    ///
    /// A node whose state is already computed is returned alone: it is the
    /// stopping criterion of the recursion, nothing above it has to be visited.
    fn topological_sort(&self, node_id: NodeId) -> Vec<NodeId> {
        let state = match self.backward_states.get_state_ref(&node_id) {
            Some(state) => state,
            None => panic!("Node {node_id:?} is not in the backward_states. "),
        };

        match state {
            State::Computed { .. } => vec![node_id],
            State::Recompute { .. } => {
                let mut ordered = Vec::new();
                for parent in self.node_tree.parents(&node_id).unwrap() {
                    for ancestor in self.topological_sort(parent) {
                        // Shared ancestors must only be scheduled once.
                        if !ordered.contains(&ancestor) {
                            ordered.push(ancestor);
                        }
                    }
                }
                // The node itself always comes after all of its ancestors.
                ordered.push(node_id);
                ordered
            }
        }
    }

    /// Tells whether both the backward states and the retro-forwards have been
    /// fully drained. Useful for testing.
    pub fn is_empty(&self) -> bool {
        let states_drained = self.backward_states.is_empty();
        states_drained && self.retro_forwards.is_empty()
    }
}

View File

@@ -0,0 +1,304 @@
use crate::{
collections::HashMap,
graph::{ComputingProperty, NodeId},
tensor::AutodiffTensor,
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use burn_backend::Backend;
use core::any::Any;
use super::{
base::{Checkpointer, NodeTree},
retro_forward::{RetroForward, RetroForwards},
state::{BackwardStates, State},
};
#[derive(Debug)]
/// Determines if a node should checkpoint its computed output or its retro_forward for recomputation
/// The action is normally created by the child of the node, once the node is determined to be needed
pub enum CheckpointingAction {
/// The node's already computed output should be saved
Computed {
/// The node
node_id: NodeId,
/// The node's output, type-erased so that outputs of any type can be stored together
state_content: Box<dyn Any + Send>,
},
/// The node should recompute itself when asked
Recompute {
/// The node
node_id: NodeId,
/// How the node should recompute itself
retro_forward: Arc<dyn RetroForward>,
},
}
// TODO: Remove that when proper client server.
// NOTE(review): `Send` is asserted manually here. The `Recompute` variant holds an
// `Arc<dyn RetroForward>`, and `RetroForward` requires `Send` but not `Sync`, so the
// compiler does not implement `Send` for this enum automatically.
// SAFETY: presumably sound only as long as a shared `Arc<dyn RetroForward>` is never
// used from several threads at the same time — TODO confirm against the callers.
unsafe impl Send for CheckpointingAction {}
impl CheckpointingAction {
/// Utility function to access the id of the node of the checkpointing action
pub fn id(&self) -> NodeId {
match self {
CheckpointingAction::Computed {
node_id: node_ref,
state_content: _,
} => *node_ref,
CheckpointingAction::Recompute {
node_id: node_ref,
retro_forward: _,
} => *node_ref,
}
}
}
#[derive(new, Debug, Default)]
/// Accumulates checkpoints as checkpointing actions during the forward pass,
/// and builds a checkpointer right before the backward pass
pub struct CheckpointerBuilder {
// Actions registered with `ActionType::Explicit`.
explicit_actions: Vec<CheckpointingAction>,
// Actions registered with `ActionType::Backup`.
backup_actions: Vec<CheckpointingAction>,
}
/// Determines if a checkpoint should impact the n_required values (Explicit)
/// or if it should just keep the state in case it's required (Backup)
pub(crate) enum ActionType {
/// Explicit actions have been explicitly requested by some operation to retrieve their state
Explicit,
/// Backup actions are not always needed. They exist to save the output of an operation
/// whose child is memory bound, in case the state is indirectly needed when computing
/// the child's retro_forward. If no explicit action ever asks for the child's output, then
/// the backup output will go out of scope when the checkpointer is built.
Backup,
}
impl CheckpointerBuilder {
    /// Registers a checkpointing action for `tensor` in the list selected by
    /// `action_type`.
    ///
    /// Compute-bound and ambiguous nodes save a clone of their output; memory-bound
    /// nodes save the function that lets them recompute it.
    pub(crate) fn checkpoint<B: Backend>(
        &mut self,
        tensor: &AutodiffTensor<B>,
        action_type: ActionType,
    ) {
        let action = match &tensor.node.properties {
            ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
                CheckpointingAction::Computed {
                    node_id: tensor.node.id,
                    state_content: Box::new(tensor.primitive.clone()),
                }
            }
            ComputingProperty::MemoryBound { retro_forward } => CheckpointingAction::Recompute {
                node_id: tensor.node.id,
                retro_forward: retro_forward.clone(),
            },
        };

        match action_type {
            ActionType::Explicit => self.explicit_actions.push(action),
            ActionType::Backup => self.backup_actions.push(action),
        }
    }

    /// Appends all the actions of `other` to this builder, keeping explicit and
    /// backup actions in their respective lists.
    pub(crate) fn extend(&mut self, other: CheckpointerBuilder) {
        self.explicit_actions.extend(other.explicit_actions);
        self.backup_actions.extend(other.backup_actions);
    }

    /// Consumes the builder and creates the [Checkpointer] used during the
    /// backward pass.
    ///
    /// # Panics
    ///
    /// Panics if a node is found to be required but no action (explicit or backup)
    /// was ever registered for it.
    pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {
        let mut backward_states_map = HashMap::new();
        let mut retro_forwards_map = HashMap::new();

        // Find recursion stopping points
        let stop_nodes: Vec<NodeId> = self.find_stop_nodes();

        // We start by identifying how many times each node will be required.
        let n_required_map = self.build_n_required_map(&node_tree, stop_nodes);

        // Then we checkpoint the nodes with the corresponding n_required value
        self.insert_checkpoints(
            &mut backward_states_map,
            &mut retro_forwards_map,
            n_required_map,
        );

        Checkpointer::new(
            BackwardStates::new(backward_states_map),
            RetroForwards::new(retro_forwards_map),
            node_tree,
        )
    }

    /// Collects the ids of all nodes (explicit and backup) whose output is saved.
    /// Walking up the graph never needs to go above such a node.
    fn find_stop_nodes(&self) -> Vec<NodeId> {
        self.explicit_actions
            .iter()
            .chain(self.backup_actions.iter())
            .filter_map(|action| match action {
                CheckpointingAction::Computed { node_id, .. } => Some(*node_id),
                CheckpointingAction::Recompute { .. } => None,
            })
            .collect()
    }

    /// Counts, for every node, how many times its state will be requested.
    ///
    /// Only explicit actions contribute: a computed node counts once per action,
    /// while a recomputed node also propagates the requirement to its ancestors.
    fn build_n_required_map(
        &self,
        node_tree: &NodeTree,
        stop_nodes: Vec<NodeId>,
    ) -> HashMap<NodeId, usize> {
        let mut n_required_map = HashMap::<NodeId, usize>::default();

        for action in &self.explicit_actions {
            match action {
                CheckpointingAction::Computed { node_id, .. } => {
                    *n_required_map.entry(*node_id).or_insert(0) += 1;
                }
                CheckpointingAction::Recompute { node_id, .. } => {
                    Self::update_n_required_of_parents(
                        *node_id,
                        &mut n_required_map,
                        node_tree,
                        &stop_nodes,
                    );
                }
            }
        }

        n_required_map
    }

    /// Moves the action of every required node into the state maps, together with
    /// the number of times the node is required.
    fn insert_checkpoints(
        mut self,
        backward_states_map: &mut HashMap<NodeId, State>,
        retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
        n_required_map: HashMap<NodeId, usize>,
    ) {
        // We do not loop over checkpointing actions anymore because they can contain
        // duplicates or miss some that are in backup. We loop over the n_required_map
        // from which we use the ids to find them again in the checkpointing actions
        for (node_id, n_required) in n_required_map {
            // We find the checkpointing action for node_id. It's likely in explicit_actions
            // so we check there first, otherwise it will be in backup_actions.
            // Technically it can be there several times but can never be of both types,
            // so we can assume the first we find is fine
            let explicit_pos = self
                .explicit_actions
                .iter()
                .position(|action| action.id() == node_id);

            let action = match explicit_pos {
                Some(pos) => self.explicit_actions.remove(pos),
                None => {
                    let pos = self
                        .backup_actions
                        .iter()
                        .position(|action| action.id() == node_id)
                        .unwrap_or_else(|| {
                            panic!("Node {:?} is needed but never checkpointed", &node_id)
                        });
                    self.backup_actions.remove(pos)
                }
            };

            match action {
                CheckpointingAction::Computed { state_content, .. } => {
                    self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
                }
                CheckpointingAction::Recompute { retro_forward, .. } => self.checkpoint_lazy(
                    backward_states_map,
                    retro_forward_map,
                    node_id,
                    retro_forward,
                    n_required,
                ),
            };
        }
    }

    /// Increments the requirement count of `id`. The first time a node is seen,
    /// the requirement is also propagated to its parents, unless the node is a
    /// stop node.
    fn update_n_required_of_parents(
        id: NodeId,
        n_required_map: &mut HashMap<NodeId, usize>,
        node_tree: &NodeTree,
        stop_nodes: &[NodeId],
    ) {
        // Already visited: the ancestors were counted on the first visit, because
        // a recomputed node only runs its retro-forward once.
        if let Some(n_required) = n_required_map.get_mut(&id) {
            *n_required += 1;
            return;
        }

        n_required_map.insert(id, 1);

        // A stop node has its output saved, its parents are not needed.
        if stop_nodes.contains(&id) {
            return;
        }

        if let Some(parents) = node_tree.parents(&id) {
            for parent in parents {
                Self::update_n_required_of_parents(parent, n_required_map, node_tree, stop_nodes);
            }
        }
    }

    /// Registers an already computed output for `node_id`.
    fn checkpoint_compute(
        &self,
        backward_states_map: &mut HashMap<NodeId, State>,
        node_id: NodeId,
        state_content: Box<dyn Any + Send>,
        n_required: usize,
    ) {
        backward_states_map.insert(
            node_id,
            State::Computed {
                state_content,
                n_required,
            },
        );
    }

    /// Registers `node_id` as to be recomputed, along with the function able to
    /// recompute it.
    fn checkpoint_lazy(
        &self,
        backward_states_map: &mut HashMap<NodeId, State>,
        retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
        node_id: NodeId,
        retro_forward: Arc<dyn RetroForward>,
        n_required: usize,
    ) {
        retro_forward_map.insert(node_id, retro_forward);
        backward_states_map.insert(node_id, State::Recompute { n_required });
    }
}

View File

@@ -0,0 +1,9 @@
/// Checkpointer module
pub mod base;
pub(crate) mod builder;
/// RetroForward module
pub mod retro_forward;
/// BackwardStates module
pub mod state;
/// CheckpointStrategy module
pub mod strategy;

View File

@@ -0,0 +1,116 @@
use crate::collections::HashMap;
use crate::graph::NodeId;
use alloc::sync::Arc;
use core::fmt::Debug;
use super::state::{BackwardStates, State};
/// Definition of the forward function of a node, called during retropropagation only.
/// This is different from the normal forward function because it reads and writes from
/// the [BackwardStates] map instead of having a clear function signature.
pub trait RetroForward: Debug + Send + 'static {
/// Applies the forward pass for retropropagation.
///
/// `states` is the map the inputs are read from and the output is written to;
/// `out_node` is the id of the node being recomputed.
fn forward(&self, states: &mut BackwardStates, out_node: NodeId);
}
#[derive(new, Debug)]
/// Links [NodeId]s to their corresponding [RetroForward]
pub(crate) struct RetroForwards {
// An entry is removed from the map once its retro-forward has been executed.
map: HashMap<NodeId, Arc<dyn RetroForward>>,
}
impl RetroForwards {
    /// Runs the [RetroForward] registered for `node_id` when the node's [State]
    /// is [State::Recompute]; does nothing when the state is already computed.
    ///
    /// # Panics
    ///
    /// Panics if the node has no state, or if it must be recomputed but has no
    /// registered retro-forward.
    pub(crate) fn execute_retro_forward(
        &mut self,
        node_id: NodeId,
        backward_states: &mut BackwardStates,
    ) {
        let Some(state) = backward_states.get_state_ref(&node_id) else {
            panic!("Should find node {node_id:?}")
        };

        let needs_recompute = matches!(state, State::Recompute { .. });
        if needs_recompute {
            // A retro-forward is used only once: afterwards the state of the
            // node is computed, so the entry can be taken out of the map.
            let retro_forward = self.map.remove(&node_id).unwrap();
            retro_forward.forward(backward_states, node_id);
        }
    }

    /// Tells whether every registered retro-forward has been consumed.
    pub(crate) fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}
#[macro_export]
/// Creates a RetroForward struct for unary scalar operations
///
/// `$name` is the name of the generated (module-private) struct and `$ops` is a
/// callable taking the input tensor primitive and the scalar.
///
/// The generated `forward` reads the state of `lhs_id`, applies `$ops` to it and
/// to `rhs`, then saves the result under `out_node`.
///
/// The expansion refers to `Backend`, `NodeId`, `Scalar`, `PhantomData`,
/// `BackwardStates`, `RetroForward` and the `new` derive by their bare names, so
/// they must be in scope where the macro is invoked.
macro_rules! retro_unary_scalar {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug, Clone)]
struct $name<B: Backend> {
lhs_id: NodeId,
rhs: Scalar,
_backend: PhantomData<B>,
}
impl<B: Backend> RetroForward for $name<B> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
let out = $ops(lhs, self.rhs);
states.save(out_node, out)
}
}
};
}
#[macro_export]
/// Creates a RetroForward struct for unary operations
///
/// `$name` is the name of the generated (module-private) struct and `$ops` is a
/// callable taking the input tensor primitive.
///
/// The generated `forward` reads the state of `input_id`, applies `$ops` to it,
/// then saves the result under `out_node`.
///
/// The expansion refers to `Backend`, `NodeId`, `PhantomData`, `BackwardStates`,
/// `RetroForward` and the `new` derive by their bare names, so they must be in
/// scope where the macro is invoked.
macro_rules! retro_unary {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug, Clone)]
struct $name<B: Backend> {
input_id: NodeId,
_backend: PhantomData<B>,
}
impl<B: Backend> RetroForward for $name<B> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);
let out = $ops(input);
states.save(out_node, out)
}
}
};
}
#[macro_export]
/// Creates a RetroForward struct for binary operations
///
/// `$name` is the name of the generated (module-private) struct and `$ops` is a
/// callable taking the left and right tensor primitives, in that order.
///
/// The generated `forward` reads the states of `lhs_id` then `rhs_id`, applies
/// `$ops` to them, then saves the result under `out_node`.
///
/// The expansion refers to `Backend`, `NodeId`, `PhantomData`, `BackwardStates`,
/// `RetroForward` and the `new` derive by their bare names, so they must be in
/// scope where the macro is invoked.
macro_rules! retro_binary {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug, Clone)]
struct $name<B: Backend> {
lhs_id: NodeId,
rhs_id: NodeId,
_backend: PhantomData<B>,
}
impl<B: Backend> RetroForward for $name<B> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
let rhs = states.get_state::<B::FloatTensorPrimitive>(&self.rhs_id);
let out = $ops(lhs, rhs);
states.save(out_node, out)
}
}
};
}

View File

@@ -0,0 +1,144 @@
use core::any::Any;
use crate::collections::HashMap;
use crate::graph::NodeId;
use alloc::boxed::Box;
/// In order to accept arbitrary node outputs in the same hashmap, we need to upcast them to
/// [Any]. The concrete type is recovered by downcasting when the state is read back.
pub(crate) type StateContent = Box<dyn Any + Send>;
#[derive(Debug)]
/// The state contained at one node. Encapsulates the node output if precomputed,
/// or clearly asks that it needs to be recomputed from the parents.
/// Also keeps track of the number of times the state is required so it can be removed
/// from the map of states on its last use.
pub(crate) enum State {
/// The state was not checkpointed, will need to recompute it from the node's parents
// `n_required`: number of times the state is still required.
Recompute { n_required: usize },
/// The state was checkpointed or computed during retropropagation and can be directly accessed
Computed {
// The type-erased node output.
state_content: StateContent,
// Number of times the state is still required.
n_required: usize,
},
}
impl State {
    /// Borrows the node output, still type-erased.
    ///
    /// Must only be called on a [State::Computed]; reaching the
    /// [State::Recompute] branch is treated as unreachable and panics.
    pub(crate) fn to_state_content(&self) -> &StateContent {
        match self {
            State::Computed { state_content, .. } => state_content,
            State::Recompute { .. } => unreachable!(
                "Can't get state content of recompute state. A child has likely been accessed before its parents."
            ),
        }
    }

    /// Takes ownership of the node output, still type-erased.
    ///
    /// Must only be called on a [State::Computed]; reaching the
    /// [State::Recompute] branch is treated as unreachable and panics.
    pub(crate) fn into_state_content(self) -> StateContent {
        match self {
            State::Computed { state_content, .. } => state_content,
            State::Recompute { .. } => unreachable!(
                "Can't get state content of recompute state. A child has likely been accessed before its parents."
            ),
        }
    }

    /// Gives the number of times the state is still required, for both variants.
    pub(crate) fn n_required(&self) -> usize {
        match self {
            State::Recompute { n_required } | State::Computed { n_required, .. } => *n_required,
        }
    }
}
#[derive(new, Default, Debug)]
/// Links [NodeId]s to their current state
pub struct BackwardStates {
// A state is taken out of the map by `get_state` on its last required use.
map: HashMap<NodeId, State>,
}
impl BackwardStates {
    /// Gives ownership of the output stored for `node_id` and decrements the
    /// number of times this state is still required.
    ///
    /// On the last required use the state is removed from the map and its content
    /// is moved out; otherwise the content is cloned and the state is stored back
    /// with the decremented counter.
    ///
    /// # Panics
    ///
    /// Panics if the node has no state, if the state still has to be recomputed,
    /// or if the stored output is not of type `T`.
    pub fn get_state<T>(&mut self, node_id: &NodeId) -> T
    where
        T: Clone + Send + 'static,
    {
        // Take the state out of the map; it is put back below only if still needed.
        let state = self.map.remove(node_id).unwrap();
        let remaining = state.n_required() - 1;

        if remaining == 0 {
            // Last use: no clone needed, the content is moved out.
            let content = state.into_state_content();
            return *content.downcast::<T>().unwrap();
        }

        let State::Computed { state_content, .. } = state else {
            unreachable!()
        };
        let output = state_content.downcast_ref::<T>().unwrap().clone();
        self.insert_state(
            *node_id,
            State::Computed {
                state_content,
                n_required: remaining,
            },
        );
        output
    }

    /// Borrows the [State] of the given node, if any.
    /// Useful when only the [State] information is needed, not the underlying tensor.
    pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> {
        self.map.get(node_id)
    }

    /// Associates `state` to `node_id`, replacing any previous state.
    pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) {
        self.map.insert(node_id, state);
    }

    /// Stores `saved_output` as the computed state of `node_id`, keeping the
    /// number of times the node is required unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the node has no state yet.
    pub fn save<T>(&mut self, node_id: NodeId, saved_output: T)
    where
        T: Clone + Send + 'static,
    {
        let current = self.get_state_ref(&node_id).unwrap();
        let n_required = current.n_required();
        let computed = State::Computed {
            state_content: Box::new(saved_output),
            n_required,
        };
        self.insert_state(node_id, computed);
    }

    /// Tells whether no state is left in the map.
    pub(crate) fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

View File

@@ -0,0 +1,102 @@
use core::fmt::Debug;
use burn_backend::Backend;
use crate::{graph::ComputingProperty, tensor::AutodiffTensor};
use alloc::sync::Arc;
use super::{
builder::{ActionType, CheckpointerBuilder},
retro_forward::RetroForward,
};
/// Strategy for the amount of checkpointing to do during autodiff
pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'static {
/// May modify the compute property depending on the strategy
///
/// `retro_forward` is the function able to recompute the operation's output.
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty;
/// Checkpoints parents if necessary in the strategy
///
/// `parents` are the parent tensors of the operation and `builder` accumulates
/// the resulting checkpointing actions.
///
/// # Errors
///
/// Returns a [CheckpointingError] when the parents can't be checkpointed.
fn checkpoint_parents<'a, B2, A>(
parents: A,
builder: &mut CheckpointerBuilder,
) -> Result<(), CheckpointingError>
where
B2: Backend,
A: IntoIterator<Item = &'a AutodiffTensor<B2>>;
}
#[derive(Debug)]
/// Error that can happen when trying to checkpoint a tensor.
pub enum CheckpointingError {
/// When a parent is untracked, we can't easily checkpoint its state, since we don't know the
/// requirements in advance.
UntrackedParent,
}
#[derive(Clone, Copy, Debug, Default)]
/// All operations are considered compute bound, notwithstanding how they are marked
///
/// With this strategy no parent is ever checkpointed.
pub struct NoCheckpointing {}
impl CheckpointStrategy for NoCheckpointing {
/// An operation marked as memory bound is actually compute bound.
///
/// The given retro-forward is dropped: it is never needed with this strategy.
fn compute_property<R: RetroForward>(_retro_forward: R) -> ComputingProperty {
ComputingProperty::ComputeBound
}
/// An operation marked as memory bound is actually compute bound.
/// It's therefore useless to checkpoint the parents
///
/// Always returns `Ok(())` and leaves the builder untouched.
fn checkpoint_parents<'a, B2, A>(
_parents: A,
_builder: &mut CheckpointerBuilder,
) -> Result<(), CheckpointingError>
where
B2: Backend,
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
{
// Nothing to do here
Ok(())
}
}
#[derive(Clone, Copy, Debug, Default)]
/// Operation properties are as they are marked (compute or memory bound)
///
/// Memory bound operations keep their retro-forward so that their output can be
/// recomputed instead of being saved.
pub struct BalancedCheckpointing {}
impl CheckpointStrategy for BalancedCheckpointing {
    /// A memory bound operation stays memory bound: its retro-forward is kept,
    /// behind an [Arc], so that it can recompute itself later.
    fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty {
        let retro_forward = Arc::new(retro_forward);
        ComputingProperty::MemoryBound { retro_forward }
    }

    /// Registers a backup checkpoint for every tracked parent, since a memory
    /// bound operation may need its parents indirectly when asked to recompute
    /// itself.
    ///
    /// # Errors
    ///
    /// If at least one parent is untracked, the builder is reset to its default
    /// (empty) value and [CheckpointingError::UntrackedParent] is returned.
    fn checkpoint_parents<'a, B2, A>(
        parents: A,
        builder: &mut CheckpointerBuilder,
    ) -> Result<(), CheckpointingError>
    where
        B2: Backend,
        A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
    {
        let mut untracked_parent_found = false;

        for parent in parents {
            match parent.node.requirement {
                crate::graph::Requirement::None => untracked_parent_found = true,
                _ => builder.checkpoint(parent, ActionType::Backup),
            }
        }

        if untracked_parent_found {
            // Nothing accumulated so far can be trusted, start over.
            *builder = CheckpointerBuilder::default();
            return Err(CheckpointingError::UntrackedParent);
        }

        Ok(())
    }
}

View File

@@ -0,0 +1,85 @@
use burn_backend::{
Backend, TensorMetadata, TensorPrimitive,
tensor::{FloatTensor, TensorContainer},
};
use crate::{
NodeId,
graph::{NodeRef, Requirement},
tensor::AutodiffTensor,
};
/// Gradient identifier, used as the key type of the [Gradients] container.
pub type GradID = u64;
/// Gradients container used during the backward pass.
pub struct Gradients {
// Gradient tensors, keyed by the `value` of the id of the node they belong to.
container: TensorContainer<GradID>,
}
impl Gradients {
    /// Creates a new gradients container, seeded with the gradient of the root node.
    ///
    /// The root gradient is a tensor of ones created with the shape, device and
    /// dtype of `root_tensor`.
    pub fn new<B: Backend>(root_node: NodeRef, root_tensor: FloatTensor<B>) -> Self {
        let root_grad = B::float_ones(
            root_tensor.shape(),
            &B::float_device(&root_tensor),
            root_tensor.dtype().into(),
        );

        let mut this = Self {
            container: TensorContainer::new(),
        };
        this.register::<B>(root_node.id, root_grad);
        this
    }

    /// Consumes the gradients for a given tensor.
    ///
    /// Each tensor should be consumed exactly 1 time if its gradients are only required during
    /// the backward pass (they are then removed from the container), otherwise they may be
    /// consumed multiple times (they are left in the container).
    ///
    /// # Panics
    ///
    /// Panics if the node is untracked, or if no gradient was registered for it yet.
    pub fn consume<B: Backend>(&mut self, node: &NodeRef) -> FloatTensor<B> {
        let key = &node.id.value;

        let grad = match node.requirement {
            Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"),
            Requirement::Grad => self.container.get::<B>(key).map(|grad| grad.tensor()),
            Requirement::GradInBackward => {
                self.container.remove::<B>(key).map(|grad| grad.tensor())
            }
        };

        grad.expect("Can't consume the gradients before they are registered at least once.")
    }

    /// Removes a grad tensor from the container and returns it, if present.
    pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
        let removed = self.container.remove::<B>(&tensor.node.id.value)?;
        Some(removed.tensor())
    }

    /// Gets a grad tensor from the container, if present.
    pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
        let found = self.container.get::<B>(&tensor.node.id.value)?;
        Some(found.tensor())
    }

    /// Register a grad tensor in the container.
    ///
    /// If a gradient already exists for the node, both tensors are added together
    /// and the sum is saved.
    pub fn register<B: Backend>(&mut self, node_id: NodeId, value: FloatTensor<B>) {
        let accumulated = match self.container.remove::<B>(&node_id.value) {
            Some(previous) => B::float_add(value, previous.tensor()),
            None => value,
        };

        self.container
            .register::<B>(node_id.value, TensorPrimitive::Float(accumulated));
    }
}

View File

@@ -0,0 +1,17 @@
use super::NodeId;
use crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent};
use alloc::boxed::Box;
/// Backward step for reverse mode autodiff.
///
/// A step is consumed when it runs: [`Step::step`] takes `self` as a `Box<Self>`, so each step
/// can be executed at most once.
pub trait Step: Send + core::fmt::Debug {
/// Executes the step and consumes it.
///
/// `grads` is the gradients container of the backward pass and `checkpointer` gives access to
/// checkpointed state.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
/// Depth of the operation relative to the first node added to a graph.
fn depth(&self) -> usize;
/// The node associated to the step.
fn node(&self) -> NodeId;
/// The parents of the node associated to the step.
fn parents(&self) -> &[Parent];
}
/// A boxed, dynamically dispatched [step](Step).
pub type StepBoxed = Box<dyn Step>;

View File

@@ -0,0 +1,9 @@
mod base;
mod node;
mod requirement;
pub mod traversal;
pub use base::*;
pub use node::*;
pub use requirement::*;

View File

@@ -0,0 +1,87 @@
use alloc::{sync::Arc, vec::Vec};
#[cfg(target_has_atomic = "64")]
use core::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(target_has_atomic = "64"))]
use portable_atomic::{AtomicU64, Ordering};
use crate::checkpoint::retro_forward::RetroForward;
use crate::runtime::AutodiffClientImpl;
use super::Requirement;
/// How an operation is classified with respect to recomputation.
#[derive(Debug, Clone)]
pub enum ComputingProperty {
/// The operation has been tagged as compute bound.
ComputeBound,
/// The operation has been tagged as memory bound and carries its [`RetroForward`].
MemoryBound {
/// Shared handle to the retro forward registered for the operation.
retro_forward: Arc<dyn RetroForward>,
},
/// No classification has been decided.
Ambiguous, // Maybe autotune someday
}
/// This is safe only because we only call RetroForward on the autodiff server.
/// Therefore, the trait will never be used by multiple threads at the same time.
///
/// NOTE(review): this invariant is not enforced by the type system; it relies on every caller
/// respecting it.
///
/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the
/// Arc, which will make (dyn RetroForward) safely implement Send.
unsafe impl Send for ComputingProperty {}
/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.
///
/// The justification is the same single-threaded usage assumption stated on the `Send` impl.
unsafe impl Sync for ComputingProperty {}
/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
///
/// The constructor is generated by `#[derive(new)]` and takes the fields in declaration order.
#[derive(new, Debug)]
pub struct Node {
/// Parents of this node, identified by their node ids.
pub parents: Vec<Parent>,
/// Ordering value of the node; the step implementations in this crate report it as their depth.
pub order: usize,
/// Unique identifier of the node.
pub id: NodeId,
/// Whether, and for how long, gradients are needed for this node.
pub requirement: Requirement,
/// Compute/memory-bound classification of the operation that produced the node.
pub properties: ComputingProperty,
/// Autodiff client attached to the node.
// NOTE(review): presumably the client used to register and run backward steps — confirm in `runtime`.
pub client: AutodiffClientImpl,
}
/// Shared, reference-counted handle to a [`Node`].
pub type NodeRef = Arc<Node>;
/// Reference to a parent node, stored by id only.
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct Parent {
/// Identifier of the parent node.
pub id: NodeId,
}
impl Node {
    /// Returns the [node](Node) only if gradients are required.
    ///
    /// Yields `None` when the node's requirement is `Requirement::None`; otherwise returns a new
    /// handle to the same node.
    pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
        if self.requirement.is_none() {
            None
        } else {
            Some(Arc::clone(self))
        }
    }
}
/// Unique identifier generated for each node.
#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
pub struct NodeId {
    /// The integer representation of the id
    pub value: u64,
}

impl core::fmt::Display for NodeId {
    /// Formats the id as `NodeId(<value>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "NodeId({})", self.value)
    }
}

impl NodeId {
    /// Create a unique [node id](NodeId).
    ///
    /// Ids are drawn from a process-wide atomic counter that starts at zero.
    ///
    /// # Panics
    ///
    /// Panics when the counter has reached `u64::MAX`.
    pub fn new() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(0);

        let value = COUNTER.fetch_add(1, Ordering::Relaxed);
        assert!(value != u64::MAX, "NodeId overflowed");

        Self { value }
    }
}

impl Default for NodeId {
    /// Equivalent to [`NodeId::new`]: every default value is a fresh id.
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,38 @@
use super::NodeRef;
/// Requirement for each tensor in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Requirement {
    /// Operations that require gradients.
    Grad,
    /// Operations that require gradients only for backprop.
    GradInBackward,
    /// Operations that don't need gradients, therefore not to be included in the graph.
    None,
}

impl Requirement {
    /// Returns true if gradients are not required.
    pub fn is_none(&self) -> bool {
        *self == Self::None
    }

    /// Returns the right requirement from a list of nodes.
    ///
    /// The result is [`Requirement::None`] when `nodes` is empty or when every node is untracked,
    /// and [`Requirement::GradInBackward`] as soon as one node requires gradients.
    pub fn from_nodes(nodes: &[NodeRef]) -> Self {
        // Starting the fold from `None` treats the empty, single-node and multi-node cases
        // uniformly.
        nodes
            .iter()
            .fold(Self::None, |acc, node| node.requirement.infer(&acc))
    }

    /// Combines two requirements: `None` only if both are `None`, otherwise `GradInBackward`.
    fn infer(&self, other: &Self) -> Self {
        if self.is_none() && other.is_none() {
            Self::None
        } else {
            Self::GradInBackward
        }
    }
}

View File

@@ -0,0 +1,74 @@
use super::{Step, StepBoxed};
use crate::{
NodeId,
collections::{HashMap, HashSet},
graph::Parent,
};
use alloc::vec::Vec;
/// Breadth-first search algorithm.
///
/// Stateless marker type; the traversal itself lives in its `traverse` method.
pub struct BreadthFirstSearch;
/// An item that can be visited during a graph traversal.
pub trait TraversalItem {
/// Identifier of the node this item belongs to.
fn id(&self) -> NodeId;
/// Parents of the node this item belongs to.
fn parents(&self) -> &[Parent];
/// Ids of the parents, collected into a new vector.
fn parent_nodes(&self) -> Vec<NodeId> {
self.parents().iter().map(|p| p.id).collect()
}
}
impl BreadthFirstSearch {
    /// Traverse the graph of backward steps from a root node.
    ///
    /// `callback` is invoked first with the root, then once for every step reachable through
    /// parent links that is still present in `steps`. Visited steps are removed from `steps`
    /// and handed to `callback` by value. Ids without an entry in `steps` are skipped.
    ///
    /// Pending ids are taken from the end of the work list, so the visiting order is not
    /// strictly level by level.
    pub fn traverse<F, I>(
        &self,
        root_id: NodeId,
        root_step: I,
        steps: &mut HashMap<NodeId, I>,
        mut callback: F,
    ) where
        F: FnMut(NodeId, I),
        I: TraversalItem,
    {
        let mut visited = HashSet::new();
        visited.insert(root_id);

        let mut pending = root_step.parent_nodes();
        callback(root_id, root_step);

        while let Some(candidate) = pending.pop() {
            let Some(step) = steps.remove(&candidate) else {
                continue;
            };

            let node_id = step.id();
            let step_parents = step.parent_nodes();

            // `insert` returns false when the node was already visited.
            if !visited.insert(node_id) {
                continue;
            }

            pending.extend(
                step_parents
                    .into_iter()
                    .filter(|parent| !visited.contains(parent)),
            );

            callback(node_id, step);
        }
    }
}
// Lets boxed backward steps be traversed by forwarding to the `Step` trait methods.
impl TraversalItem for StepBoxed {
/// Forwards to [`Step::node`].
fn id(&self) -> NodeId {
Step::node(self.as_ref())
}
/// Forwards to [`Step::parents`].
fn parents(&self) -> &[Parent] {
Step::parents(self.as_ref())
}
}

View File

@@ -0,0 +1,43 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # Burn Autodiff
//!
//! This autodiff library is a part of the Burn project. It is a standalone crate
//! that can be used to perform automatic differentiation on tensors. It is
//! designed to be used with the Burn Tensor crate, but it can be used with any
//! tensor library that implements the `Backend` trait.
#[macro_use]
extern crate derive_new;
extern crate alloc;
/// Checkpoint module.
pub mod checkpoint;
/// Gradients module.
pub mod grads;
/// Operation module.
pub mod ops;
pub(crate) mod graph;
// Exported for backend extension
pub use graph::NodeId;
pub(crate) mod tensor;
pub(crate) mod utils;
mod backend;
pub(crate) mod runtime;
pub use backend::*;
/// A facade around `HashMap` and `HashSet`.
/// This avoids elaborate import wrangling having to happen in every module.
///
/// With the `std` feature the types come from `std::collections`; without it they come from
/// `hashbrown`.
mod collections {
#[cfg(not(feature = "std"))]
pub use hashbrown::{HashMap, HashSet};
#[cfg(feature = "std")]
pub use std::collections::{HashMap, HashSet};
}

View File

@@ -0,0 +1,167 @@
use core::marker::PhantomData;
use crate::{
Autodiff,
checkpoint::{
base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,
strategy::CheckpointStrategy,
},
grads::Gradients,
graph::NodeId,
ops::{Backward, Ops, OpsKind, unary},
retro_unary,
};
use burn_backend::{Backend, ops::ActivationOps, tensor::FloatTensor};
// Activation functions for the autodiff backend.
//
// Every activation follows the same recipe: the operation is declared memory bound, a retro
// forward built from the inner backend's activation is registered, and the parent is offered
// for checkpointing. When the operation is tracked, the input tensor is checkpointed and its
// node id becomes the backward state, so the backward step reads the input back from the
// checkpointer.
impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodiff<B, C> {
    /// GELU activation.
    ///
    /// Backward: reads the checkpointed input and applies `B::gelu_backward(input, grad)`.
    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        #[derive(Debug)]
        struct Gelu;

        retro_unary!(RetroGelu, B::gelu);

        impl<B: Backend> Backward<B, 1> for Gelu {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let input = checkpointer.retrieve_node_output(ops.state);

                unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                    B::gelu_backward(input, grad)
                });
            }
        }

        match Gelu
            .prepare::<C>([tensor.node.clone()])
            .memory_bound()
            .retro_forward(RetroGelu::<B>::new(tensor.node.id))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let state = prep.checkpoint(&tensor);
                // `tensor` is not used after this point, so the primitive is moved, not cloned.
                prep.finish(state, B::gelu(tensor.primitive))
            }
            OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
        }
    }

    /// ReLU activation.
    ///
    /// Backward: reads the checkpointed tensor and applies `B::relu_backward(state, grad)`.
    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        #[derive(Debug)]
        struct Relu;

        retro_unary!(RetroRelu, B::relu);

        impl<B: Backend> Backward<B, 1> for Relu {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let state = checkpointer.retrieve_node_output(ops.state);

                unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                    B::relu_backward(state, grad)
                });
            }
        }

        match Relu
            .prepare::<C>([tensor.node.clone()])
            .memory_bound()
            .retro_forward(RetroRelu::<B>::new(tensor.node.id))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let state = prep.checkpoint(&tensor);
                prep.finish(state, B::relu(tensor.primitive))
            }
            OpsKind::UnTracked(prep) => prep.finish(B::relu(tensor.primitive)),
        }
    }

    /// Sigmoid activation.
    ///
    /// Backward: reads the checkpointed input, recomputes the forward output with `B::sigmoid`,
    /// and applies `B::sigmoid_backward(output, grad)`.
    fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        #[derive(Debug)]
        struct Sigmoid;

        retro_unary!(RetroSigmoid, B::sigmoid);

        impl<B: Backend> Backward<B, 1> for Sigmoid {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let input = checkpointer.retrieve_node_output(ops.state);
                // Only the input is checkpointed; the output is recomputed here.
                let output = B::sigmoid(input);

                unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                    B::sigmoid_backward(output, grad)
                });
            }
        }

        match Sigmoid
            .prepare::<C>([tensor.node.clone()])
            .memory_bound()
            .retro_forward(RetroSigmoid::<B>::new(tensor.node.id))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let state = prep.checkpoint(&tensor);
                prep.finish(state, B::sigmoid(tensor.primitive))
            }
            OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
        }
    }

    /// Log-sigmoid activation.
    ///
    /// Backward: reads the checkpointed input and applies `B::log_sigmoid_backward(input, grad)`.
    fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        #[derive(Debug)]
        struct LogSigmoid;

        retro_unary!(RetroLogSigmoid, B::log_sigmoid);

        impl<B: Backend> Backward<B, 1> for LogSigmoid {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let input = checkpointer.retrieve_node_output(ops.state);

                unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                    B::log_sigmoid_backward(input, grad)
                });
            }
        }

        match LogSigmoid
            .prepare::<C>([tensor.node.clone()])
            .memory_bound()
            .retro_forward(RetroLogSigmoid::<B>::new(tensor.node.id))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let state = prep.checkpoint(&tensor);
                // `tensor` is not used after this point, so the primitive is moved, not cloned.
                prep.finish(state, B::log_sigmoid(tensor.primitive))
            }
            OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
        }
    }
}

View File

@@ -0,0 +1,88 @@
use super::{Ops, OpsPrep};
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},
grads::Gradients,
graph::{ComputingProperty, NodeRef, Requirement},
utils::duplicate,
};
use burn_backend::Backend;
/// Trait for all operations.
///
/// `N` is the number of parent nodes of the operation.
///
/// # Notes
///
/// Concrete types implementing this trait should not have any state.
/// If a state is necessary during the backward pass,
/// they should be declared with the associated type 'State'.
pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
where
Self: Sized + 'static,
B: Backend,
{
/// Associated type to compute the backward pass.
type State: Clone + Send + core::fmt::Debug + 'static;
/// The backward pass.
///
/// `ops` carries the parents, the node and the state saved at preparation time; `grads` is the
/// gradients container and `checkpointer` gives access to checkpointed state.
fn backward(
self,
ops: Ops<Self::State, N>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
);
/// Prepare the backward ops.
///
/// The requirement is derived from `nodes`. The preparation starts with an ambiguous compute
/// property and a default checkpointer builder.
fn prepare<C: CheckpointStrategy>(
self,
nodes: [NodeRef; N],
) -> OpsPrep<Self, B, Self::State, C, N> {
let requirement = Requirement::from_nodes(&nodes);
OpsPrep::new(
nodes,
requirement,
self,
ComputingProperty::Ambiguous, // If not specified we start with ambiguous
CheckpointerBuilder::default(),
)
}
}
/// Execute a binary operation during the backward step.
///
/// The gradient of `node` is consumed and duplicated for the two parents. For each parent that
/// is present, the matching closure turns the output gradient into that parent's gradient,
/// which is then registered under the parent's id.
pub fn binary<B, FLhs, FRhs>(
    parents: [Option<NodeRef>; 2],
    node: NodeRef,
    grads: &mut Gradients,
    func_lhs: FLhs,
    func_rhs: FRhs,
) where
    B: Backend,
    FLhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
    FRhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
{
    let grad_out = grads.consume::<B>(&node);
    let [grad_for_lhs, grad_for_rhs] = duplicate(&parents, Some(grad_out));
    let [lhs, rhs] = parents;

    if let Some(lhs) = lhs {
        grads.register::<B>(lhs.id, func_lhs(grad_for_lhs.unwrap()));
    }

    if let Some(rhs) = rhs {
        grads.register::<B>(rhs.id, func_rhs(grad_for_rhs.unwrap()));
    }
}
/// Execute a unary operation during the backward step.
///
/// The gradient of `node` is always consumed. When the parent is present, `func` maps it to the
/// parent's gradient, which is registered under the parent's id.
pub fn unary<B, F>(parents: [Option<NodeRef>; 1], node: NodeRef, grads: &mut Gradients, func: F)
where
    B: Backend,
    F: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
{
    let grad_out = grads.consume::<B>(&node);
    let [parent] = parents;

    if let Some(parent) = parent {
        grads.register::<B>(parent.id, func(grad_out));
    }
}

View File

@@ -0,0 +1,317 @@
use super::Backward;
use crate::{
checkpoint::{
base::Checkpointer,
builder::{ActionType, CheckpointerBuilder},
retro_forward::RetroForward,
strategy::CheckpointStrategy,
},
grads::Gradients,
graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step},
tensor::AutodiffTensor,
};
use alloc::boxed::Box;
use burn_backend::{Backend, TensorMetadata, tensor::FloatTensor};
use burn_std::Shape;
use core::marker::PhantomData;
/// Operation in preparation.
///
/// Each mode has its own set of functions to minimize cloning for unused backward states.
///
/// `Mode` is a marker type recording how far the preparation has progressed; it defaults to
/// [`Init`].
#[derive(new)]
pub struct OpsPrep<Backward, B, S, C, const N: usize, Mode = Init> {
// Parent nodes of the operation.
nodes: [NodeRef; N],
// Requirement of the operation being prepared.
requirement: Requirement,
// The backward implementation to run for this operation.
backward: Backward,
// Current compute/memory-bound classification.
compute_property: ComputingProperty,
// Collects the checkpointing actions requested while preparing.
checkpointer_builder: CheckpointerBuilder,
// Type-level markers only: checkpoint strategy, backend, backward state and mode.
checkpoint_strategy: PhantomData<C>,
phantom_backend: PhantomData<B>,
phantom_state: PhantomData<S>,
marker: PhantomData<Mode>,
}
/// Operation is initialized.
///
/// Default mode of [`OpsPrep`].
pub struct Init;
/// Operation has been tagged as memory bound.
pub struct MemoryBound;
/// Memory bound operation has received its RetroForward.
pub struct MemoryBoundRetroForward;
/// Operation's compute property is fixed.
pub struct ComputePropertyDone;
/// Tracked operation tag.
pub struct Tracked;
/// Untracked operation tag.
pub struct UnTracked;
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Init>
where
    B: Backend,
    BO: Backward<B, N, State = S>,
{
    /// Indicates that the operation is compute bound, meaning its computation
    /// is heavy and should not be recomputed
    pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone> {
        let Self {
            nodes,
            requirement,
            backward,
            checkpointer_builder,
            ..
        } = self;

        OpsPrep::new(
            nodes,
            requirement,
            backward,
            ComputingProperty::ComputeBound,
            checkpointer_builder,
        )
    }

    /// Indicates that the operation is memory bound, meaning its computation
    /// is light and can be recomputed
    ///
    /// Only the mode changes here; the compute property is carried over as is.
    pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, N, MemoryBound> {
        let Self {
            nodes,
            requirement,
            backward,
            compute_property,
            checkpointer_builder,
            ..
        } = self;

        OpsPrep::new(
            nodes,
            requirement,
            backward,
            compute_property,
            checkpointer_builder,
        )
    }
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBound>
where
    B: Backend,
    BO: Backward<B, N, State = S>,
    C: CheckpointStrategy,
{
    /// Registers the retro forward, if needed
    ///
    /// The checkpoint strategy `C` decides which compute property results from `retro_forward`.
    pub fn retro_forward<R: RetroForward>(
        self,
        retro_forward: R,
    ) -> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward> {
        let Self {
            nodes,
            requirement,
            backward,
            checkpointer_builder,
            ..
        } = self;

        OpsPrep::new(
            nodes,
            requirement,
            backward,
            C::compute_property(retro_forward),
            checkpointer_builder,
        )
    }
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward>
where
    B: Backend,
    BO: Backward<B, N, State = S>,
    C: CheckpointStrategy,
{
    /// Checkpoints the parents, if needed
    ///
    /// When the strategy fails to checkpoint the parents, the operation falls back to
    /// [`ComputingProperty::ComputeBound`]; otherwise the current compute property is kept.
    pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
    where
        B2: Backend,
        A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
    {
        let checkpointed =
            C::checkpoint_parents(parents, &mut self.checkpointer_builder).is_ok();

        let compute_property = if checkpointed {
            self.compute_property
        } else {
            ComputingProperty::ComputeBound
        };

        OpsPrep::new(
            self.nodes,
            self.requirement,
            self.backward,
            compute_property,
            self.checkpointer_builder,
        )
    }
}
// Shortcut for operations whose backward state is the unit type.
impl<BO, B, C, const N: usize> OpsPrep<BO, B, (), C, N, ComputePropertyDone>
where
B: Backend,
BO: Backward<B, N, State = ()>,
{
/// Prepare a stateless operation.
///
/// Resolves the tracked/untracked decision and finishes the operation right away, passing `()`
/// as state when it is tracked.
pub fn stateless(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
match self.stateful() {
OpsKind::Tracked(prep) => prep.finish((), output),
OpsKind::UnTracked(prep) => prep.finish(output),
}
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
where
    B: Backend,
    S: Clone + Send + core::fmt::Debug + 'static,
    BO: Backward<B, N, State = S>,
{
    /// Prepare an operation that requires a state during the backward pass.
    ///
    /// Returns the untracked variant when no gradient is required, and the tracked variant
    /// otherwise. Both variants carry the same data; only the mode marker differs.
    pub fn stateful(self) -> OpsKind<BO, B, S, C, N> {
        let Self {
            nodes,
            requirement,
            backward,
            compute_property,
            checkpointer_builder,
            ..
        } = self;

        if requirement.is_none() {
            OpsKind::UnTracked(OpsPrep::new(
                nodes,
                requirement,
                backward,
                compute_property,
                checkpointer_builder,
            ))
        } else {
            OpsKind::Tracked(OpsPrep::new(
                nodes,
                requirement,
                backward,
                compute_property,
                checkpointer_builder,
            ))
        }
    }
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>
where
B: Backend,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
///
/// The output tensor is built from the parent nodes, then a no-op step is registered for it
/// together with the checkpointer builder accumulated during preparation.
pub fn finish(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.requirement,
self.compute_property,
);
// Keep a handle only to the parents that require gradients.
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, output.node.clone(), ());
// We register the ops in the graph even if untracked, otherwise memory bound operations
// that have an untracked parent would not be able to retrieve it
output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)
}
}
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>
where
B: Backend,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of a tracked operation and returns the output tensor.
///
/// `state` is stored in the registered step and handed to the backward implementation when the
/// step runs.
pub fn finish(self, state: S, output: FloatTensor<B>) -> AutodiffTensor<B> {
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.requirement,
self.compute_property,
);
// Keep a handle only to the parents that require gradients.
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, output.node.clone(), state);
output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)
}
/// Checkpoints the tensor
///
/// Registers `tensor` on the checkpointer builder with [`ActionType::Explicit`] and returns the
/// id of its node, which can be kept as backward state.
pub fn checkpoint(&mut self, tensor: &AutodiffTensor<B>) -> NodeId {
self.checkpointer_builder
.checkpoint(tensor, ActionType::Explicit);
tensor.node.id
}
}
/// Enum used before finishing tracked and untracked operations.
///
/// Matching on it tells the caller which `finish` to call: the tracked one takes a state, the
/// untracked one does not.
pub enum OpsKind<BO, B, S, C, const N: usize> {
/// Tracked operation preparation.
Tracked(OpsPrep<BO, B, S, C, N, Tracked>),
/// Untracked operation preparation.
UnTracked(OpsPrep<BO, B, S, C, N, UnTracked>),
}
/// Operation containing its parent nodes, its own node and the backward step state.
#[derive(new, Debug)]
pub struct Ops<S, const N: usize> {
/// Parents nodes.
///
/// An entry is `None` when the corresponding parent does not require gradients.
pub parents: [Option<NodeRef>; N],
/// The node.
pub node: NodeRef,
/// The state.
pub state: S,
}
/// Operation implementing backward [step](Step) with type erasing.
///
/// Pairs the operation data with its backward implementation so both can be stored as a
/// `dyn Step`.
#[derive(new, Debug)]
struct OpsStep<B, T, SB, const N: usize>
where
B: Backend,
T: Backward<B, N, State = SB>,
SB: Clone + Send + core::fmt::Debug + 'static,
{
// Parents, node and saved state of the operation.
ops: Ops<SB, N>,
// Backward implementation invoked by `step`.
backward: T,
phantom: PhantomData<B>,
}
impl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>
where
B: Backend,
T: Backward<B, N, State = SB>,
SB: Clone + Send + core::fmt::Debug + 'static,
{
// Runs the backward implementation with the stored operation data.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
self.backward.backward(self.ops, grads, checkpointer);
}
fn node(&self) -> NodeId {
self.ops.node.id
}
// Returns the parents recorded on the node itself, not the optional handles in `ops.parents`.
fn parents(&self) -> &[Parent] {
&self.ops.node.parents
}
// The node's `order` is reported as the depth.
fn depth(&self) -> usize {
self.ops.node.order
}
}
/// Step registered for untracked operations.
///
/// It keeps the node and its parents in the graph but does nothing when executed.
#[derive(new, Debug)]
struct UntrackedOpsStep<const N: usize> {
// Operation data without any backward state.
ops: Ops<(), N>,
}
impl<const N: usize> Step for UntrackedOpsStep<N> {
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
// Nothing to do
}
fn node(&self) -> NodeId {
self.ops.node.id
}
// Returns the parents recorded on the node itself, not the optional handles in `ops.parents`.
fn parents(&self) -> &[Parent] {
&self.ops.node.parents
}
// The node's `order` is reported as the depth.
fn depth(&self) -> usize {
self.ops.node.order
}
}
/// Make sure the grad tensor has the given shape.
///
/// If broadcasting happened during the forward pass, the gradients will be summed along the
/// broadcasted dimension.
///
/// # Panics
///
/// Panics when a dimension of `grad` differs from `shape` and the corresponding dimension of
/// `shape` is not 1.
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
    // Snapshot of the incoming gradient shape; it is not refreshed after a reduction.
    let shape_grad = grad.shape();

    for dim in 0..shape_grad.num_dims() {
        if shape_grad[dim] == shape[dim] {
            continue;
        }

        assert!(
            shape[dim] == 1,
            "Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}",
            shape,
            shape_grad,
            "Expected the shape of the next grad to be 1."
        );

        grad = B::float_sum_dim(grad, dim);
    }

    grad
}

View File

@@ -0,0 +1,161 @@
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
use alloc::vec::Vec;
use burn_backend::{
Backend, ExecutionError, Scalar, TensorData,
ops::BoolTensorOps,
tensor::{BoolTensor, Device, IntTensor},
};
use burn_std::Shape;
// Boolean tensor operations for the autodiff backend.
//
// Every method forwards its arguments unchanged to the same-named operation of the inner
// backend `B`. The only method that does more is `bool_into_float`, which wraps the float
// result in an `AutodiffTensor`.
impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
// --- Creation, data transfer and device handling ---
fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B> {
B::bool_from_data(data, device)
}
async fn bool_into_data(tensor: BoolTensor<B>) -> Result<TensorData, ExecutionError> {
B::bool_into_data(tensor).await
}
fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B> {
B::bool_into_int(tensor)
}
fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B> {
B::bool_to_device(tensor, device)
}
fn bool_device(tensor: &BoolTensor<B>) -> Device<B> {
B::bool_device(tensor)
}
// --- Shape manipulation and slicing ---
fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
B::bool_reshape(tensor, shape)
}
fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_std::Slice]) -> BoolTensor<B> {
B::bool_slice(tensor, slices)
}
fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
B::bool_empty(shape, device)
}
fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
B::bool_zeros(shape, device)
}
fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
B::bool_ones(shape, device)
}
fn bool_slice_assign(
tensor: BoolTensor<Self>,
slices: &[burn_std::Slice],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_slice_assign(tensor, slices, value)
}
fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
B::bool_cat(tensors, dim)
}
// --- Element-wise logic ---
fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_equal(lhs, rhs)
}
fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B> {
B::bool_not(tensor)
}
fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_and(lhs, rhs)
}
fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_or(lhs, rhs)
}
fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_xor(lhs, rhs)
}
// Converts with the inner backend, then wraps the result with `AutodiffTensor::new`.
// NOTE(review): presumably the new tensor starts its own graph without parents — confirm in
// `AutodiffTensor::new`.
fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
AutodiffTensor::new(B::bool_into_float(tensor))
}
// --- Dimension reordering and expansion ---
fn bool_swap_dims(
tensor: <Autodiff<B> as Backend>::BoolTensorPrimitive,
dim1: usize,
dim2: usize,
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive {
B::bool_swap_dims(tensor, dim1, dim2)
}
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
B::bool_permute(tensor, axes)
}
fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B> {
B::bool_flip(tensor, axes)
}
async fn bool_argwhere(tensor: BoolTensor<B>) -> IntTensor<B> {
B::bool_argwhere(tensor).await
}
fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
B::bool_expand(tensor, shape)
}
fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
B::bool_repeat_dim(tensor, dim, times)
}
fn bool_unfold(
tensor: BoolTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> BoolTensor<Self> {
B::bool_unfold(tensor, dim, size, step)
}
// --- Masking, gathering and scattering ---
fn bool_mask_where(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
source: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_mask_where(tensor, mask, source)
}
fn bool_mask_fill(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> BoolTensor<Self> {
B::bool_mask_fill(tensor, mask, value)
}
fn bool_gather(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
) -> BoolTensor<Self> {
B::bool_gather(dim, tensor, indices)
}
fn bool_scatter_or(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_scatter_or(dim, tensor, indices, value)
}
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
B::bool_equal_elem(lhs, rhs)
}
}

View File

@@ -0,0 +1,406 @@
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
use alloc::vec::Vec;
use burn_backend::{
Backend, Distribution, ExecutionError, Scalar, TensorData,
ops::IntTensorOps,
tensor::{BoolTensor, Device, IntTensor},
};
use burn_std::{IntDType, Shape};
impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
B::int_from_data(data, device)
}
async fn int_into_data(tensor: IntTensor<B>) -> Result<TensorData, ExecutionError> {
B::int_into_data(tensor).await
}
fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {
B::int_to_device(tensor, device)
}
fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
B::int_device(tensor)
}
fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
B::int_reshape(tensor, shape)
}
fn int_slice(tensor: IntTensor<B>, slices: &[burn_std::Slice]) -> IntTensor<B> {
B::int_slice(tensor, slices)
}
fn int_empty(
shape: Shape,
device: &<Autodiff<B> as Backend>::Device,
dtype: IntDType,
) -> IntTensor<B> {
B::int_empty(shape, device, dtype)
}
fn int_slice_assign(
tensor: IntTensor<B>,
slices: &[burn_std::Slice],
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_slice_assign(tensor, slices, value)
}
fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
B::int_cat(tensors, dim)
}
fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_equal(lhs, rhs)
}
fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_equal_elem(lhs, rhs)
}
fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_add(lhs, rhs)
}
fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_add_scalar(lhs, rhs)
}
fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
B::int_clamp_min(tensor, min)
}
fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
B::int_clamp_max(tensor, max)
}
fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
B::int_clamp(tensor, min, max)
}
fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_sub(lhs, rhs)
}
fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_sub_scalar(lhs, rhs)
}
fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_mul(lhs, rhs)
}
fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_mul_scalar(lhs, rhs)
}
fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_div(lhs, rhs)
}
fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_div_scalar(lhs, rhs)
}
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_remainder(lhs, rhs)
}
fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
B::int_remainder_scalar(lhs, rhs)
}
fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_matmul(lhs, rhs)
}
fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
B::int_neg(tensor)
}
fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
B::int_zeros(shape, device, dtype)
}
fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
B::int_ones(shape, device, dtype)
}
fn int_full(
shape: Shape,
fill_value: Scalar,
device: &Device<Self>,
dtype: IntDType,
) -> IntTensor<B> {
B::int_full(shape, fill_value, device, dtype)
}
fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
B::int_sum(tensor)
}
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_sum_dim(tensor, dim)
}
fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
B::int_mean(tensor)
}
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_mean_dim(tensor, dim)
}
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cumsum(tensor, dim)
}
fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cumprod(tensor, dim)
}
fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cummin(tensor, dim)
}
fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cummax(tensor, dim)
}
fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
B::int_repeat_dim(tensor, dim, times)
}
fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_greater(lhs, rhs)
}
fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_greater_elem(lhs, rhs)
}
fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_greater_equal(lhs, rhs)
}
fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_greater_equal_elem(lhs, rhs)
}
fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_lower(lhs, rhs)
}
fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_lower_elem(lhs, rhs)
}
fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
B::int_lower_equal(lhs, rhs)
}
fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
B::int_lower_equal_elem(lhs, rhs)
}
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {
B::int_gather(dim, tensor, indices)
}
fn int_scatter_add(
dim: usize,
tensor: IntTensor<B>,
indices: IntTensor<B>,
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_scatter_add(dim, tensor, indices, value)
}
fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {
B::int_select(tensor, dim, indices)
}
fn int_select_add(
tensor: IntTensor<B>,
dim: usize,
indices: IntTensor<B>,
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_select_add(tensor, dim, indices, value)
}
fn int_mask_where(
tensor: IntTensor<B>,
mask: BoolTensor<B>,
value: IntTensor<B>,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
B::int_mask_where(tensor, mask, value)
}
fn int_mask_fill(
tensor: IntTensor<B>,
mask: BoolTensor<B>,
value: Scalar,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
B::int_mask_fill(tensor, mask, value)
}
fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_argmax(tensor, dim)
}
fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_argmin(tensor, dim)
}
fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
B::int_max(tensor)
}
fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
B::int_max_dim(tensor, dim)
}
fn int_max_dim_with_indices(
tensor: B::IntTensorPrimitive,
dim: usize,
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
B::int_max_dim_with_indices(tensor, dim)
}
fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
B::int_min(tensor)
}
/// Forwards to `B::int_min_dim` along `dim`.
fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
B::int_min_dim(tensor, dim)
}
/// Forwards to `B::int_min_dim_with_indices` and returns its pair of int tensors unchanged.
fn int_min_dim_with_indices(
tensor: B::IntTensorPrimitive,
dim: usize,
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
B::int_min_dim_with_indices(tensor, dim)
}
/// Forwards to `B::int_abs` on the inner backend.
fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
B::int_abs(tensor)
}
/// Converts with `B::int_into_float` and wraps the resulting float primitive in a
/// fresh tensor via `AutodiffTensor::new`, so the result has no parents in the graph.
fn int_into_float(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
AutodiffTensor::new(B::int_into_float(tensor))
}
/// Forwards to `B::int_swap_dims` with the same `dim1` and `dim2`.
fn int_swap_dims(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
dim1: usize,
dim2: usize,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
B::int_swap_dims(tensor, dim1, dim2)
}
/// Forwards to `B::int_random` with the same shape, distribution and device.
fn int_random(
shape: Shape,
distribution: Distribution,
device: &Device<Self>,
) -> IntTensor<Self> {
B::int_random(shape, distribution, device)
}
/// Forwards to `B::int_arange` with the same range and device.
fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
B::int_arange(range, device)
}
/// Forwards to `B::int_permute` with the same `axes`.
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
B::int_permute(tensor, axes)
}
/// Forwards to `B::int_flip` with the same `axes`.
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
B::int_flip(tensor, axes)
}
/// Forwards to `B::int_sign` on the inner backend.
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::int_sign(tensor)
}
/// Forwards to `B::int_prod` on the inner backend.
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::int_prod(tensor)
}
/// Forwards to `B::int_prod_dim` along `dim`.
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
B::int_prod_dim(tensor, dim)
}
/// Forwards to `B::int_expand` with the requested `shape`.
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
B::int_expand(tensor, shape)
}
/// Forwards to `B::int_sort` with the same `dim` and `descending` flag.
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
B::int_sort(tensor, dim, descending)
}
/// Forwards to `B::int_sort_with_indices` and returns its pair of int tensors unchanged.
fn int_sort_with_indices(
tensor: IntTensor<Self>,
dim: usize,
descending: bool,
) -> (IntTensor<Self>, IntTensor<Self>) {
B::int_sort_with_indices(tensor, dim, descending)
}
/// Forwards to `B::int_argsort` with the same `dim` and `descending` flag.
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
B::int_argsort(tensor, dim, descending)
}
/// Forwards to `B::bitwise_and` on the inner backend.
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_and(lhs, rhs)
}
/// Forwards to `B::bitwise_and_scalar` (tensor against a scalar) on the inner backend.
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_and_scalar(lhs, rhs)
}
/// Forwards to `B::bitwise_or` on the inner backend.
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_or(lhs, rhs)
}
/// Forwards to `B::bitwise_or_scalar` (tensor against a scalar) on the inner backend.
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_or_scalar(lhs, rhs)
}
/// Forwards to `B::bitwise_xor` on the inner backend.
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_xor(lhs, rhs)
}
/// Forwards to `B::bitwise_xor_scalar` (tensor against a scalar) on the inner backend.
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_xor_scalar(lhs, rhs)
}
/// Forwards to `B::bitwise_not` on the inner backend.
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_not(tensor)
}
/// Forwards to `B::bitwise_left_shift` on the inner backend.
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_left_shift(lhs, rhs)
}
/// Forwards to `B::bitwise_left_shift_scalar` (tensor against a scalar) on the inner backend.
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_left_shift_scalar(lhs, rhs)
}
/// Forwards to `B::bitwise_right_shift` on the inner backend.
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
B::bitwise_right_shift(lhs, rhs)
}
/// Forwards to `B::bitwise_right_shift_scalar` (tensor against a scalar) on the inner backend.
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
B::bitwise_right_shift_scalar(lhs, rhs)
}
/// Forwards to `B::int_cast` with the requested `dtype`.
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
B::int_cast(tensor, dtype)
}
/// Forwards to `B::int_unfold` with the same `dim`, `size` and `step`.
fn int_unfold(
tensor: IntTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> IntTensor<Self> {
B::int_unfold(tensor, dim, size, step)
}
}

View File

@@ -0,0 +1,27 @@
use super::{Backward, Ops, unary};
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use burn_backend::{Backend, TensorMetadata};
use burn_std::Shape;
/// Backward operation whose saved state is `(indices, shape, dim)`.
///
/// NOTE(review): the name suggests it is registered by the max/min-along-a-dimension
/// ops; the call sites are not in this file, so confirm there.
#[derive(Debug)]
pub(crate) struct MaxMinDim;
impl<B: Backend> Backward<B, 1> for MaxMinDim {
    /// Saved as `(indices, shape, dim)`: the index tensor to scatter with, the shape of
    /// the gradient to produce, and the dimension to scatter along.
    type State = (B::IntTensorPrimitive, Shape, usize);

    /// Routes the incoming gradient back to the parent.
    ///
    /// A zero tensor of the saved shape is created on the gradient's device with the
    /// gradient's dtype, and the gradient is scatter-added into it at the saved indices
    /// along the saved dimension.
    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        _checkpointer: &mut Checkpointer,
    ) {
        unary::<B, _>(ops.parents, ops.node, grads, |grad| {
            let (indices, shape, dim) = ops.state;
            let zeros = B::float_zeros(shape, &B::float_device(&grad), grad.dtype().into());

            B::float_scatter_add(dim, zeros, indices, grad)
        });
    }
}

View File

@@ -0,0 +1,15 @@
mod activation;
mod backward;
mod base;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;
pub(crate) mod maxmin;
pub(crate) mod sort;
pub use backward::*;
pub use base::*;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
use burn_backend::{
Backend, ExecutionError, TensorData,
ops::QTensorOps,
tensor::{
Device, FloatTensor, IntTensor, QuantizedTensor,
quantization::QuantizationParametersPrimitive,
},
};
use burn_std::{QuantScheme, Shape};
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
/// Quantized-tensor operations for the autodiff backend.
///
/// Only `q_device`, `q_reshape`, `q_into_data`, `q_argmax` and `q_argmin` are forwarded
/// to the inner backend `B`. Every other method is a stub that panics when called, via
/// `todo!()` or `unimplemented!()`.
impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
// Stub: panics via `todo!()`.
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
todo!()
}
// Stub: panics via `todo!()`.
fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
todo!() // required for QAT
}
// Stub: panics via `todo!()`.
fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
) -> QuantizedTensor<Self> {
todo!()
}
// Stub: panics via `todo!()`.
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
todo!()
}
// Forwarded to the inner backend.
fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
B::q_device(tensor)
}
// Stub: panics via `unimplemented!()`.
fn q_to_device(
_tensor: QuantizedTensor<Self>,
_device: &Device<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Forwarded to the inner backend.
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
B::q_reshape(tensor, shape)
}
// Forwarded to the inner backend; the inner future is awaited and its result returned as-is.
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
B::q_into_data(tensor).await
}
// Stub: panics via `unimplemented!()`.
fn q_swap_dims(
_tensor: QuantizedTensor<Self>,
_dim1: usize,
_dim2: usize,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics via `unimplemented!()`.
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics via `unimplemented!()`.
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics via `unimplemented!()`.
fn q_gather(
_dim: usize,
_tensor: QuantizedTensor<Self>,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics via `unimplemented!()`.
fn q_select(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics via `unimplemented!()`.
fn q_slice(
_tensor: QuantizedTensor<Self>,
_slices: &[burn_std::Slice],
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Forwarded to the inner backend.
fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
B::q_argmax(tensor, dim)
}
// Forwarded to the inner backend.
fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
B::q_argmin(tensor, dim)
}
// Stub: panics via `unimplemented!()`.
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
}

View File

@@ -0,0 +1,27 @@
use super::{Backward, Ops, unary};
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use burn_backend::{Backend, TensorMetadata};
use burn_std::Shape;
/// Backward operation whose saved state is `(indices, shape, dim)`.
///
/// NOTE(review): the name suggests it is registered by the sort-along-a-dimension
/// ops; the call sites are not in this file, so confirm there.
#[derive(Debug)]
pub(crate) struct SortDim;
impl<B: Backend> Backward<B, 1> for SortDim {
    /// Saved as `(indices, shape, dim)`: the index tensor to scatter with, the shape of
    /// the gradient to produce, and the dimension to scatter along.
    type State = (B::IntTensorPrimitive, Shape, usize);

    /// Routes the incoming gradient back to the parent.
    ///
    /// A zero tensor of the saved shape is created on the gradient's device with the
    /// gradient's dtype, and the gradient is scatter-added into it at the saved indices
    /// along the saved dimension.
    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        _checkpointer: &mut Checkpointer,
    ) {
        unary::<B, _>(ops.parents, ops.node, grads, |grad| {
            let (indices, shape, dim) = ops.state;
            let zeros = B::float_zeros(shape, &B::float_device(&grad), grad.dtype().into());

            B::float_scatter_add(dim, zeros, indices, grad)
        });
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,24 @@
use burn_backend::{
Backend, ExecutionError,
ops::{TransactionOps, TransactionPrimitive},
};
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
impl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C> {
    /// Executes a read transaction on the inner backend.
    ///
    /// Float tensors are unwrapped to their inner primitives first; the quantized, int
    /// and bool reads are handed over unchanged.
    async fn tr_execute(
        transaction: TransactionPrimitive<Self>,
    ) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
        // Strip the autodiff wrapper: the inner backend only knows its own primitives.
        let inner_floats = transaction
            .read_floats
            .into_iter()
            .map(|tensor| tensor.primitive)
            .collect();

        let inner: TransactionPrimitive<B> = TransactionPrimitive::new(
            inner_floats,
            transaction.read_qfloats,
            transaction.read_ints,
            transaction.read_bools,
        );

        B::tr_execute(inner).await
    }
}

View File

@@ -0,0 +1,18 @@
use crate::{
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::StepBoxed,
tensor::{AutodiffTensor, NodeRefCount},
};
use burn_backend::Backend;
/// Client used to communicate with the autodiff server.
pub trait AutodiffClient: Send + Clone {
/// Register a new step.
///
/// `node_id` is the reference-counted id of the node the step belongs to, and
/// `actions` is the checkpointer builder registered together with the step.
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
/// Call backpropagation from the given tensor.
///
/// Consumes the tensor and returns the resulting [`Gradients`].
fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
}
/// Client implementation in use.
pub type AutodiffClientImpl = super::graph::GraphMutexClient;

View File

@@ -0,0 +1,335 @@
use super::{AutodiffClient, server::AutodiffServer};
use crate::{
NodeId,
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::{Parent, StepBoxed},
runtime::server::NodeCleaner,
tensor::{AutodiffTensor, NodeRefCount},
};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_backend::Backend;
use hashbrown::{HashMap, HashSet};
#[cfg(feature = "std")]
use parking_lot::{Mutex, MutexGuard};
#[cfg(not(feature = "std"))]
use spin::{Mutex, MutexGuard};
/// A client for managing multiple graphs using mutex-based synchronization.
///
/// The biggest benefit of using this client implementation is that each graph can modify its own
/// data without blocking other graphs, which is essential for multi-device training.
///
/// # Notes
///
/// The [AutodiffServer] fully supports multiple graphs that share nodes; however, those types of
/// graphs will be stored under a single mutex-protected graph by the client, limiting
/// parallelisation.
#[derive(Clone, new, Debug)]
pub struct GraphMutexClient;
/// Manages a collection of graphs, mapping [node ids](NodeId) to their respective graph.
///
/// The `GraphLocator` is responsible for selecting and merging graphs based on their IDs and parent
/// dependencies, ensuring proper synchronization and server allocation.
///
/// # Notes
///
/// Multiple node ids can point to the same graph, where the autodiff graph is stored.
#[derive(Default)]
pub struct GraphLocator {
/// Maps every known node id (origins included) to the graph that currently holds it.
graphs: HashMap<NodeId, Arc<Graph>>,
/// We keep a mapping of each original node id (graph id) => all nodes that point to that graph.
/// This is to ensure that when merging graphs, we correctly move all previous graphs to
/// the new merged one.
keys: HashMap<NodeId, HashSet<NodeId>>,
}
/// Represents a single computation graph with a mutex-protected server.
///
/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
/// first created.
pub(crate) struct Graph {
/// Id of the node this graph was created for; used as the graph's identity when merging.
origin: NodeId,
/// The graph's server, behind its own mutex so each graph is locked independently.
state: Mutex<GraphState>,
}
/// Mutable contents of a [`Graph`]; only ever accessed through the graph's mutex.
#[derive(Default)]
struct GraphState {
/// Server holding the steps registered on this graph.
server: AutodiffServer,
}
impl core::fmt::Debug for Graph {
    /// Prints only the `origin` id; the mutex-protected state is left out.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("Graph");
        builder.field("origin", &self.origin);
        builder.finish()
    }
}
// Global graph locator shared by every `GraphMutexClient`. It starts as `None` and is
// created on the first call to `GraphMutexClient::graph`.
static STATE: Mutex<Option<GraphLocator>> = Mutex::new(None);
impl GraphMutexClient {
    /// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
    ///
    /// The global locator is created on first use. The global lock is held for the whole
    /// lookup and released when this function returns.
    ///
    /// # Parameters
    /// - `node`: The id of the node whose graph is requested.
    /// - `parents`: A slice of parent nodes that the node depends on.
    ///
    /// # Returns
    /// An `Arc<Graph>` representing the selected, merged or newly created graph.
    fn graph(node: NodeId, parents: &[Parent]) -> Arc<Graph> {
        let mut state = STATE.lock();
        let locator = state.get_or_insert_with(GraphLocator::default);

        locator.select(node, parents)
    }
}
impl AutodiffClient for GraphMutexClient {
// Registers `step` on the graph selected for the node and its parents.
fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
// Copy the id out first: the ref count itself is moved into the server below.
let node_id = *node_id_ref;
// The global lock is taken and released inside `graph`; only the graph's own lock
// is held while registering.
let graph = GraphMutexClient::graph(node_id, step.parents());
let mut state = graph.state.lock();
state.server.register(node_id_ref, step, actions);
}
// Runs the backward pass from `root` on the graph that holds it.
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
let node_id = root.node.id;
// No parents are passed, so this returns the graph already mapped to the node, or a
// new one when none exists.
let graph = GraphMutexClient::graph(root.node.id, &[]);
let grads = Gradients::new::<B>(root.node, root.primitive);
let grads = {
let mut state = graph.state.lock();
state.server.backward::<GraphCleaner>(grads, node_id)
}; // lock released
// Runs only after the graph lock above has been dropped; it takes the global lock and
// each graph lock itself.
GraphCleaner::cleanup_orphaned_entries();
grads
}
}
/// [`NodeCleaner`] that removes node entries from the global [`GraphLocator`].
///
/// It holds the global `STATE` lock for as long as it is alive.
struct GraphCleaner<'a> {
/// Guard over the global locator, taken in `init`.
guard: MutexGuard<'a, Option<GraphLocator>>,
}
impl<'a> GraphCleaner<'a> {
/// Removes locator entries for nodes that their graph's server reports as unused roots.
///
/// The global lock is never held while a graph lock is taken: the graph map is cloned
/// under the global lock, each graph is then inspected under its own lock, and the
/// global lock is re-taken only to remove the collected entries.
fn cleanup_orphaned_entries() {
let graphs = {
// Get the available graphs and release the lock
match STATE.lock().as_ref() {
Some(state) => state.graphs.clone(),
None => return,
}
};
let mut should_remove = Vec::new();
for graph in graphs.values() {
{
let mut guard = graph.state.lock();
// Double safety: in case it was marked as no longer useful, but other
// nodes are still relevant, we only check which nodes can safely be removed.
if !guard.server.maybe_useful() {
guard
.server
.free_unused_roots(|node| should_remove.push(*node));
}
}
}
// Skip re-locking the global state when there is nothing to remove.
if !should_remove.is_empty() {
let mut state = STATE.lock();
if let Some(state) = state.as_mut() {
for node in should_remove {
state.remove_entry(&node);
}
}
}
}
}
impl<'a> NodeCleaner for GraphCleaner<'a> {
    /// Takes the global `STATE` lock and keeps it for the lifetime of the cleaner.
    fn init() -> Self {
        Self {
            guard: STATE.lock(),
        }
    }

    /// Removes `node` from the locator; does nothing when no locator exists yet.
    fn clean(&mut self, node: &NodeId) {
        let Some(locator) = self.guard.as_mut() else {
            return;
        };

        locator.remove_entry(node);
    }
}
impl GraphLocator {
    /// Selects a single graph for the given [NodeId], considering parent dependencies.
    ///
    /// If multiple graphs are found, they are merged into a single one.
    ///
    /// # Parameters
    /// - `node`: The node ID of the graph to select.
    /// - `parents`: A slice of parent nodes that the graph depends on.
    ///
    /// # Returns
    ///
    /// An `Arc<Graph>` representing the selected or merged graph.
    pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
        match self.analyse(node, parents) {
            GraphAnalysis::NoCollision(graph) => {
                // The node joins an existing graph created for another node: record the
                // mapping so later lookups and merges find it.
                if graph.origin != node {
                    self.graphs.insert(node, graph.clone());
                    self.register_key(graph.origin, node);
                }
                graph
            }
            GraphAnalysis::Collisions(graphs) => self.merge(node, graphs),
        }
    }

    /// Analyses the graph for a given node and its parents, returning the associated `GraphAnalysis`.
    fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalysis {
        // If no parents, there is no collision, therefore a single graph is ok.
        if parents.is_empty() {
            let graph = match self.graphs.get(&node) {
                Some(existing) => existing.clone(),
                None => self.new_graph(node),
            };
            return GraphAnalysis::NoCollision(graph);
        }

        // We collect all graphs of parents and of the current node based on their origin node id,
        // so several nodes living in the same graph count only once.
        let mut graphs = HashMap::<NodeId, Arc<Graph>>::new();
        if let Some(graph) = self.graphs.get(&node) {
            graphs.insert(graph.origin, graph.clone());
        }
        for parent in parents {
            if let Some(graph) = self.graphs.get(&parent.id) {
                graphs.insert(graph.origin, graph.clone());
            }
        }

        match graphs.len() {
            // Neither the node nor any parent is known yet. The node cannot be in
            // `self.graphs` here, otherwise it would have been collected above.
            0 => GraphAnalysis::NoCollision(self.new_graph(node)),
            1 => GraphAnalysis::NoCollision(
                graphs
                    .into_values()
                    .next()
                    .expect("Exactly one graph was collected"),
            ),
            _ => GraphAnalysis::Collisions(graphs),
        }
    }

    /// Merges multiple graphs associated with a node into a single graph.
    ///
    /// The first graph yielded by the map becomes the main graph; all others are folded into it.
    fn merge(&mut self, node: NodeId, mut graphs: HashMap<NodeId, Arc<Graph>>) -> Arc<Graph> {
        let mut graphs = graphs.drain().map(|(_, graph)| graph);
        let main = graphs.next().expect("At least one graph");
        // Must happen before `merge_two`, which expects the main graph's key set to exist.
        self.register_key(main.origin, node);

        let mut state = main.state.lock();
        for graph in graphs {
            self.merge_two(&mut state, &main, graph);
        }
        self.graphs.insert(main.origin, main.clone());
        self.graphs.insert(node, main.clone());

        // The guard borrows from `main`, so it has to go before `main` is returned.
        core::mem::drop(state);
        main
    }

    /// Registers a key for a given origin node.
    ///
    /// An entry for `origin` is always created; `key` is only added when it differs from
    /// `origin`.
    fn register_key(&mut self, origin: NodeId, key: NodeId) {
        let keys = self.keys.entry(origin).or_default();
        if origin != key {
            // Register this node to point to the origin graph
            keys.insert(key);
        }
    }

    /// Merges two graphs by combining their states and updating graph mappings.
    fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>, merged: Arc<Graph>) {
        // The merged graph stays locked until the mappings below have been updated.
        let mut locked = merged.state.lock();
        let state_old = core::mem::take(&mut *locked);
        main_state.server.extend(state_old.server);

        // Re-map merged origin to the main graph
        self.graphs.insert(merged.origin, main.clone());

        // Move all keys (node IDs) from the merged graph to the main graph
        if let Some(locator_keys) = self.keys.remove(&merged.origin) {
            for key in locator_keys.iter() {
                self.graphs.insert(*key, main.clone());
            }
            let locator_keys_main = self
                .keys
                .get_mut(&main.origin)
                .expect("Should be init before the merge.");
            locator_keys_main.extend(locator_keys);
        }
    }

    /// Creates a new graph for a given node.
    fn new_graph(&mut self, origin: NodeId) -> Arc<Graph> {
        let graph = Arc::new(Graph {
            origin,
            state: Mutex::new(GraphState::default()),
        });
        self.graphs.insert(origin, graph.clone());
        self.keys.insert(origin, HashSet::new());
        graph
    }

    /// Removes the mapping for `node`, and drops the key set of its graph's origin once
    /// that set is empty.
    fn remove_entry(&mut self, node: &NodeId) {
        let Some(graph) = self.graphs.remove(node) else {
            return;
        };

        if let Some(entry) = self.keys.get_mut(&graph.origin) {
            entry.remove(node);
            if entry.is_empty() {
                self.keys.remove(&graph.origin);
            }
        }
    }
}
/// Represents the analysis result of graph operations for a given node and its parents.
#[derive(Debug)]
enum GraphAnalysis {
/// No collision detected, contains the graph associated with the node.
NoCollision(Arc<Graph>),
/// Collision detected, contains the distinct graphs involved, keyed by each graph's
/// origin node id.
Collisions(HashMap<NodeId, Arc<Graph>>),
}

View File

@@ -0,0 +1,294 @@
use crate::{
NodeId,
collections::{HashMap, HashSet},
graph::Parent,
tensor::NodeRefCount,
};
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
use core::mem;
/// Tracks which graph nodes are still needed so that the others can be freed.
#[derive(Default, Debug)]
pub struct GraphMemoryManagement {
/// Every registered node, keyed by its reference-counted id, with the ids of its parents.
nodes: HashMap<NodeRefCount, Vec<NodeId>>,
/// Registered nodes that have not been used as a parent by a later registration.
leaves: HashSet<NodeId>,
/// Scratch statuses computed by `free_unavailable_nodes` and cleared at the end of it.
statuses: HashMap<NodeId, NodeMemoryStatus>,
}
/// Status assigned to a node while deciding what can be freed.
#[derive(Debug, Clone, PartialEq)]
enum NodeMemoryStatus {
/// The node is still referenced outside the graph, or is an ancestor of such a node.
Useful,
/// The node is missing from the node map, or one of its ancestors is.
Unavailable,
/// Nothing has been established yet; treated as not useful when deciding what to delete.
Unknown,
}
impl GraphMemoryManagement {
    /// Absorbs all nodes, leaves and statuses of `other` into `self`.
    pub fn extend(&mut self, other: Self) {
        self.nodes.extend(other.nodes);
        self.leaves.extend(other.leaves);
        self.statuses.extend(other.statuses);
    }

    /// Register a new node with its parent.
    ///
    /// The parents stop being leaves and the new node becomes one.
    pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
        let node_id = *node.as_ref();

        for parent in parents.iter() {
            self.leaves.remove(&parent.id);
        }

        self.leaves.insert(node_id);
        self.nodes
            .insert(node, parents.iter().map(|p| p.id).collect());
    }

    /// Free the node from the state.
    ///
    /// Nothing happens while the node's id is still referenced elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if the node is not in the node map.
    pub fn consume_node(&mut self, node_id: NodeId) {
        if !self.is_referenced(node_id) {
            self.leaves.remove(&node_id);
            self.nodes.remove(&node_id);
        }
    }

    /// Free all nodes whose backward call has become impossible
    ///
    /// This function goes into three steps, which must happen for all leaves
    /// before going into the next step. Then it deletes what can be safely deleted
    ///
    /// `on_free_graph` is called once for every entry of the deletion list.
    pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
        // Snapshot of the leaves: the steps below need `&mut self`.
        let leaves = self.leaves.clone();
        let mut new_leaves = HashSet::new();
        let mut deletables = Vec::new();

        // When consuming nodes with a backward pass, some other backward passes become
        // unavailable because some of their parents have been consumed. They are
        // identified here.
        for &leaf in &leaves {
            self.unavailable_propagation(leaf);
        }

        // Among the available nodes that remain, some may be useless if no
        // available node with a tensor reference exist in their descendance.
        // But some may seem useless from some leaf but be useful from another one,
        // hence the need to iterate on all leaves.
        self.useful_propagation(leaves.clone());

        // New leaves are the roots of a useful backward sub-tree.
        // Deletables are everything not marked as useful.
        for leaf in leaves {
            self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
        }

        // Replace leaves by the new ones and delete everything not useful anymore
        mem::swap(&mut self.leaves, &mut new_leaves);

        self.clear_unused_roots(&mut deletables);

        self.statuses.clear();
        for node_to_delete in deletables {
            self.nodes.remove(&node_to_delete);
            on_free_graph(&node_to_delete)
        }
    }

    /// Removes every node reported by `clear_unused_roots`, calling `on_free_graph` for each.
    pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
        let mut deletables = Vec::new();
        self.clear_unused_roots(&mut deletables);

        for node_id in deletables {
            self.nodes.remove(&node_id);
            on_free_graph(&node_id);
        }
    }

    /// Collects into `to_delete` the nodes that are not marked useful, whose id is held
    /// only by the node map, and none of whose parents is still in the node map.
    ///
    /// Despite the name, nothing is removed here; callers do the removal.
    fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
        for (id, parents) in self.nodes.iter() {
            let is_useful = matches!(
                self.statuses.get(id.as_ref()),
                Some(NodeMemoryStatus::Useful)
            );

            // Check if parents are either empty or absent from self.nodes
            let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));

            if !is_useful && Arc::strong_count(id) == 1 && parents_absent {
                to_delete.push(*id.as_ref())
            }
        }
    }

    /// Computes, records and returns the status of `node_id`, recursing through its parents.
    fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {
        // If already visited
        if let Some(status) = self.statuses.get(&node_id) {
            return status.clone();
        }

        match self.nodes.get(&node_id).cloned() {
            // If node exists and any of its parents is unavailable, it is unavailable as well
            // If node exists but the parents vec is empty, it is a tensor that never had parents;
            // the status remains unknown
            Some(parents) => {
                let mut node_status = NodeMemoryStatus::Unknown;
                for parent in parents {
                    let parent_status = self.unavailable_propagation(parent);
                    if let NodeMemoryStatus::Unavailable = parent_status {
                        node_status = NodeMemoryStatus::Unavailable;
                    }
                }
                self.statuses.insert(node_id, node_status.clone());
                node_status
            }
            // If node does not exist, it was
            // deleted, so this and all its descendants are unavailable
            None => {
                self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
                NodeMemoryStatus::Unavailable
            }
        }
    }

    /// Marks as useful every still-referenced node of unknown status reachable from
    /// `leaves`, together with all of its ancestors.
    fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
        // Accumulate visited nodes
        let mut explored = HashSet::new();
        let mut tagged_useful = HashSet::new();

        // Queue of nodes to visit
        let mut to_tag_useful = PopNodeSet::default();
        let mut to_explore = PopNodeSet::new(leaves);

        // Utility function to iterate over a node's parents
        let parents = |node_id| {
            self.nodes
                .get(&node_id)
                .cloned()
                .unwrap_or_default()
                .into_iter()
        };

        loop {
            // Pop a node id, greedily looking at tag_useful ones first
            let (node_id, status) = match to_tag_useful.pop() {
                Some(node_id) => (node_id, NodeMemoryStatus::Useful),
                None => match to_explore.pop() {
                    Some(node_id) => {
                        let node_status = self
                            .statuses
                            .get(&node_id)
                            .expect("All nodes should have received a status during unavailable_propagation")
                            .to_owned();

                        if let NodeMemoryStatus::Unknown = node_status {
                            if self.is_referenced(node_id) {
                                (node_id, NodeMemoryStatus::Useful)
                            } else {
                                (node_id, NodeMemoryStatus::Unknown)
                            }
                        } else {
                            (node_id, node_status)
                        }
                    }
                    None => {
                        // There are no nodes in the queues anymore
                        break;
                    }
                },
            };

            match status {
                NodeMemoryStatus::Useful => {
                    tagged_useful.insert(node_id);
                    for parent in parents(node_id) {
                        // The node can be explored, as long as it's not already tagged useful
                        if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
                            to_tag_useful.insert(parent);
                        }
                    }
                }
                _ => {
                    explored.insert(node_id);
                    for parent in parents(node_id) {
                        if !(explored.contains(&parent) || to_explore.contains(&parent)) {
                            to_explore.insert(parent);
                        }
                    }
                }
            }

            self.statuses.insert(node_id, status);
        }
    }

    /// Walks up from `leaf_id`: the first useful node on each path becomes a new leaf,
    /// and every non-useful node met on the way is appended to `to_delete`.
    fn identify_leaves_and_deletables(
        &self,
        leaf_id: NodeId,
        new_leaves: &mut HashSet<NodeId>,
        to_delete: &mut Vec<NodeId>,
    ) {
        let mut visited = HashSet::new();
        let mut to_visit = vec![leaf_id];

        while let Some(node_id) = to_visit.pop() {
            visited.insert(node_id);

            match self
                .statuses
                .get(&node_id)
                .expect("Node should have status")
            {
                NodeMemoryStatus::Useful => {
                    new_leaves.insert(node_id);
                }
                _ => {
                    to_delete.push(node_id);

                    // Parents are read in place; a node missing from the map has none.
                    if let Some(parents) = self.nodes.get(&node_id) {
                        to_visit.extend(
                            parents
                                .iter()
                                .copied()
                                .filter(|parent| !visited.contains(parent)),
                        );
                    }
                }
            };
        }
    }

    /// Returns whether the node's id is held by anything besides the node map.
    ///
    /// # Panics
    ///
    /// Panics if the node is not in the node map.
    fn is_referenced(&self, node_id: NodeId) -> bool {
        match self.nodes.get_key_value(&node_id) {
            Some((key, _value)) => Arc::strong_count(key) > 1,
            None => panic!("Node should be in the nodes map"),
        }
    }

    /// Returns whether at least one node id is held by anything besides the node map.
    pub(crate) fn maybe_useful(&self) -> bool {
        self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
    }
}
/// Wrapper over hash set for fast popping of any node
///
/// `pop` removes and returns whichever element the set's iterator yields first, so the
/// order is unspecified.
#[derive(new, Default)]
struct PopNodeSet {
/// The node ids still waiting to be popped.
hash_set: HashSet<NodeId>,
}
impl PopNodeSet {
    /// Removes and returns an arbitrary node id, or `None` when the set is empty.
    #[inline(always)]
    fn pop(&mut self) -> Option<NodeId> {
        // Copy the id out first so the shared borrow ends before the removal.
        let next = self.hash_set.iter().next().copied()?;
        self.hash_set.take(&next)
    }

    /// Returns whether `node_id` is currently in the set.
    #[inline(always)]
    fn contains(&self, node_id: &NodeId) -> bool {
        self.hash_set.contains(node_id)
    }

    /// Adds `node_id` to the set; inserting an id twice has no further effect.
    #[inline(always)]
    fn insert(&mut self, node_id: NodeId) {
        self.hash_set.insert(node_id);
    }
}

View File

@@ -0,0 +1,6 @@
mod client;
mod memory_management;
mod server;
pub mod graph;
pub use client::*;

View File

@@ -0,0 +1,143 @@
use super::memory_management::GraphMemoryManagement;
use crate::{
NodeId,
checkpoint::{
base::{Checkpointer, NodeTree},
builder::CheckpointerBuilder,
},
collections::HashMap,
grads::Gradients,
graph::{StepBoxed, traversal::BreadthFirstSearch},
tensor::NodeRefCount,
};
use alloc::vec::Vec;
/// Stores the registered steps of a graph and runs backward passes over them.
#[derive(Default)]
pub struct AutodiffServer {
/// Step registered for each node id.
steps: HashMap<NodeId, StepBoxed>,
/// Checkpointer builder registered together with each step.
actions_builder: HashMap<NodeId, CheckpointerBuilder>,
/// Bookkeeping used to decide which nodes can be freed.
memory_management: GraphMemoryManagement,
}
/// Defines how nodes are cleaned.
pub trait NodeCleaner {
/// Initialize a new cleaner.
fn init() -> Self;
/// Cleans a single [node](NodeId).
fn clean(&mut self, node: &NodeId);
}
impl AutodiffServer {
    /// Absorbs all steps, checkpointer builders and memory bookkeeping of `other`.
    pub fn extend(&mut self, other: AutodiffServer) {
        self.steps.extend(other.steps);
        self.actions_builder.extend(other.actions_builder);
        self.memory_management.extend(other.memory_management);
    }

    /// Registers `step` and its checkpointer builder under the node id held by `rc`.
    pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
        let parents = step.parents();
        let node_id = *rc.as_ref();

        self.memory_management.register(rc, parents);

        self.steps.insert(node_id, step);
        self.actions_builder.insert(node_id, actions);
    }

    /// Runs the backward pass starting at `node_id` and returns the updated gradients.
    ///
    /// Afterwards, every node freed by the memory management and every node consumed by
    /// the pass is handed to a cleaner of type `NC`.
    ///
    /// # Panics
    ///
    /// Panics if no step is registered for `node_id`.
    pub fn backward<NC: NodeCleaner>(&mut self, grads: Gradients, node_id: NodeId) -> Gradients {
        let step = self.steps.remove(&node_id).expect(
            "Node should have a step registered, did you forget to call \
             `Tensor::register_grad` on the tensor where you need gradients?",
        );
        // `register` always stores a builder next to the step, so this cannot be missing
        // once the step lookup above has succeeded.
        let builder = self
            .actions_builder
            .remove(&node_id)
            .expect("A checkpointer builder should be registered together with every step");

        let mut consumed = Vec::new();
        let (tape, checkpointer) = self.build_tape(node_id, step, builder, &mut consumed);

        let gradients = Self::execute_steps(tape, grads, checkpointer);

        // Cleanup
        let mut cleaner = NC::init();
        self.memory_management
            .free_unavailable_nodes(|node_id: &NodeId| {
                self.steps.remove(node_id);
                self.actions_builder.remove(node_id);
                cleaner.clean(node_id);
            });

        for node_id in consumed {
            cleaner.clean(&node_id)
        }

        gradients
    }

    /// Frees the unused roots reported by the memory management, dropping their step and
    /// builder and calling `on_free_graph` for each.
    pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
        self.memory_management.free_unused_roots(|node_id| {
            self.steps.remove(node_id);
            self.actions_builder.remove(node_id);
            on_free_graph(node_id);
        });
    }

    /// Collects the steps reachable from `node` into a tape indexed by `depth - 1`, and
    /// builds the checkpointer from the builders of the visited nodes.
    ///
    /// Every visited node id is appended to `consumed`.
    fn build_tape(
        &mut self,
        node: NodeId,
        node_step: StepBoxed,
        mut builder: CheckpointerBuilder,
        consumed: &mut Vec<NodeId>,
    ) -> (Vec<Vec<StepBoxed>>, Checkpointer) {
        let mut tape = (0..node_step.depth())
            .map(|_| Vec::with_capacity(1))
            .collect::<Vec<_>>();

        let mut tree = HashMap::default();

        BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {
            self.memory_management.consume_node(id);
            // Clean up consumed node
            consumed.push(id);

            // Depth-0 steps have no slot on the tape.
            let depth = step.depth();
            if depth == 0 {
                return;
            }

            if let Some(steps) = tape.get_mut(depth - 1) {
                let parents = step.parents().iter().map(|p| p.id).filter(|s| *s != id);
                tree.insert(id, parents.collect());
                steps.push(step);
            }

            if let Some(node_builder) = self.actions_builder.remove(&id) {
                builder.extend(node_builder);
            }
        });

        let checkpointer = builder.build(NodeTree::new(tree));

        (tape, checkpointer)
    }

    /// Executes the tape from the deepest level to the shallowest and returns the gradients.
    fn execute_steps(
        tape: Vec<Vec<StepBoxed>>,
        mut grads: Gradients,
        mut checkpointer: Checkpointer,
    ) -> Gradients {
        tape.into_iter().rev().for_each(|steps| {
            steps
                .into_iter()
                .for_each(|step| step.step(&mut grads, &mut checkpointer))
        });

        // For checkpointing tests
        #[cfg(feature = "export_tests")]
        assert!(checkpointer.is_empty());

        grads
    }

    /// Returns whether the memory management still sees an externally referenced node.
    pub(crate) fn maybe_useful(&self) -> bool {
        self.memory_management.maybe_useful()
    }
}

View File

@@ -0,0 +1,189 @@
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
graph::{ComputingProperty, Node, NodeId, NodeRef, Parent, Requirement, Step},
runtime::{AutodiffClient, AutodiffClientImpl},
};
use alloc::{boxed::Box, sync::Arc, vec};
use burn_backend::{Backend, TensorMetadata};
/// A float tensor of the inner backend together with its node in the autodiff graph.
#[derive(Debug, Clone)]
pub struct AutodiffTensor<B: Backend> {
/// The inner backend's float tensor.
pub primitive: B::FloatTensorPrimitive,
/// The graph node describing this tensor (id, parents, order, requirement, client).
pub node: NodeRef,
/// Reference-counted copy of the node id; handed to the client when a step is registered.
pub rc: NodeRefCount,
}
// Metadata is read straight from the wrapped primitive.
impl<B: Backend> TensorMetadata for AutodiffTensor<B> {
fn dtype(&self) -> burn_std::DType {
self.primitive.dtype()
}
fn shape(&self) -> burn_std::Shape {
self.primitive.shape()
}
fn rank(&self) -> usize {
self.primitive.rank()
}
}
/// Reference-counted node id.
pub type NodeRefCount = Arc<NodeId>;
/// Step registered for a leaf tensor that requires gradients; it does nothing when executed.
#[derive(new, Debug)]
pub(crate) struct RootStep {
/// The node of the leaf tensor.
node: NodeRef,
}
// All accessors read from the wrapped node; executing the step is a no-op.
impl Step for RootStep {
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
// Nothing to do
}
fn node(&self) -> NodeId {
self.node.id
}
fn parents(&self) -> &[Parent] {
&self.node.parents
}
// The depth of the step is the node's `order`.
fn depth(&self) -> usize {
self.node.order
}
}
impl<B: Backend> AutodiffTensor<B> {
/// Create a new leaf tensor.
pub fn new(primitive: B::FloatTensorPrimitive) -> Self {
let id = NodeId::new();
let node: NodeRef = Node::new(
vec![],
0,
id,
Requirement::None,
ComputingProperty::Ambiguous,
AutodiffClientImpl::new(),
)
.into();
Self {
rc: Arc::new(node.id),
primitive,
node: node.clone(),
}
}
pub fn is_tracked(&self) -> bool {
!self.node.requirement.is_none()
}
/// Mark the tensor as requiring gradients.
///
/// # Panics
///
/// It panics if the tensor is not a leaf.
pub fn require_grad(mut self) -> Self {
match self.node.requirement {
Requirement::Grad => self,
Requirement::GradInBackward => {
panic!("Can't convert a non leaf tensor into a tracked tensor")
}
Requirement::None => {
self.node = Node::new(
vec![],
0,
self.node.id,
Requirement::Grad,
self.node.properties.clone(),
self.node.client.clone(),
)
.into();
let step = RootStep::new(self.node.clone());
self.register_step(step, CheckpointerBuilder::default())
}
}
}
/// Create a tensor from parent infos.
pub fn from_parents(
primitive: B::FloatTensorPrimitive,
parent_nodes: &[NodeRef],
requirement: Requirement,
computing_properties: ComputingProperty,
) -> Self {
let order = parent_nodes
.iter()
.map(|node| node.order)
.reduce(usize::max)
.unwrap_or(0)
+ 1;
let client = parent_nodes
.first()
.map(|node| node.client.clone())
.unwrap_or_else(AutodiffClientImpl::new);
let node: NodeRef = Node::new(
parent_nodes
.iter()
.filter_map(|node| node.clone_if_require_grad())
.map(|node| Parent::new(node.id))
.collect(),
order,
NodeId::new(),
requirement,
computing_properties,
client,
)
.into();
Self {
rc: Arc::new(node.id),
primitive,
node,
}
}
/// Register a step into a graph for that tensor.
///
/// # Warning
///
/// This should be called only once per tensor.
pub fn register_step<S: Step + 'static>(
self,
step_that_created_the_tensor: S,
actions: CheckpointerBuilder,
) -> Self {
self.node.client.register(
self.rc.clone(),
Box::new(step_that_created_the_tensor),
actions,
);
self
}
pub fn into_primitive(self) -> B::FloatTensorPrimitive {
self.primitive
}
/// Run the backward pass from this tensor.
///
/// Delegates to `AutodiffClient::backward` on a clone of this tensor's
/// client, consuming the tensor and returning the resulting `Gradients`.
pub fn backward(self) -> Gradients {
let client = self.node.client.clone();
AutodiffClient::backward::<B>(&client, self)
}
/// Look up this tensor's gradient in `grads`.
///
/// Returns whatever `grads.get` yields for this tensor.
pub fn grad(&self, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
grads.get::<B>(self)
}
/// Remove this tensor's gradient from `grads`.
///
/// Returns whatever `grads.remove` yields for this tensor.
pub fn grad_remove(&self, grads: &mut Gradients) -> Option<B::FloatTensorPrimitive> {
grads.remove::<B>(self)
}
/// Replace this tensor's gradient in `grads` with `grad`.
///
/// Any existing entry for this tensor is removed (and discarded) first, then
/// `grad` is registered under this tensor's node id.
pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPrimitive) {
grads.remove::<B>(self);
grads.register::<B>(self.node.id, grad);
}
}

View File

@@ -0,0 +1,25 @@
use alloc::vec::Vec;
use crate::graph::NodeRef;
/// Produce one copy of `obj` for every node slot that is present.
///
/// Slot `i` of the result is a clone of `obj` when `nodes[i]` is `Some`, and
/// `None` otherwise.
///
/// # Notes
///
/// This is useful since you don't have to keep N cloned references alive even
/// if just 1 node will be updated.
///
/// If the object is a tensor and only one reference exists, it can be updated
/// inplace.
pub fn duplicate<T: Clone + core::fmt::Debug, const N: usize>(
    nodes: &[Option<NodeRef>; N],
    obj: Option<T>,
) -> [Option<T>; N] {
    let mut copies = Vec::with_capacity(N);
    for node in nodes {
        copies.push(if node.is_some() { obj.clone() } else { None });
    }

    // Exactly N entries were pushed, so the conversion to `[_; N]` cannot fail.
    copies.try_into().unwrap()
}

View File

@@ -0,0 +1,10 @@
[alias]
test-cpu = "test --release --no-default-features --features cpu,std"
test-cuda = "test --release --no-default-features --features cuda,std"
test-ndarray = "test --release --no-default-features --features ndarray,std"
test-rocm = "test --release --no-default-features --features rocm,std"
test-router = "test --release --no-default-features --features router,std"
test-tch = "test --release --no-default-features --features tch,std"
test-wgpu = "test --release --no-default-features --features wgpu,std"
test-vulkan = "test --release --no-default-features --features vulkan,std"
test-metal = "test --release --no-default-features --features metal,std"

View File

@@ -0,0 +1,120 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Tensor tests for Burn backends"
documentation = "https://docs.rs/burn-backend-tests"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-backend-tests"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-backend-tests"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"burn-tensor/default",
"burn-autodiff/default",
# Backends (default not enabled for CubeCL backends as it activates fusion)
"burn-cpu?/default",
"burn-ndarray?/default",
"burn-tch?/default",
# Default
"ndarray",
"std",
]
std = [
"burn-tensor/std",
"burn-autodiff/std",
# Backends
"burn-cpu?/std",
"burn-ndarray?/std",
"burn-wgpu?/std",
"burn-router?/std",
"burn-cuda?/std",
"burn-rocm?/std",
]
tracing = [
"cubecl?/tracing",
"burn-tensor/tracing",
"burn-autodiff/tracing",
# Backends
"burn-cpu?/tracing",
"burn-ndarray?/tracing",
"burn-wgpu?/tracing",
"burn-router?/tracing",
"burn-cuda?/tracing",
"burn-rocm?/tracing",
]
# Backends
cuda = ["burn-cuda", "quantization", "cube"]
rocm = ["burn-rocm", "quantization", "cube"]
ndarray = ["burn-ndarray", "quantization"]
tch = ["burn-tch"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]
metal = ["wgpu", "burn-wgpu/metal"]
wgpu = ["burn-wgpu", "quantization", "cube"]
cpu = ["burn-cpu", "cube"]
router = ["burn-router", "ndarray", "burn-wgpu"]
autotune = [
"burn-wgpu?/autotune",
"burn-cuda?/autotune",
"burn-rocm?/autotune",
"burn-cpu?/autotune",
]
autotune-checks = [
"burn-wgpu?/autotune-checks",
"burn-cuda?/autotune-checks",
"burn-rocm?/autotune-checks",
"burn-cpu?/autotune-checks",
]
# CubeCL backends
cube = [
"cubecl",
"cubek",
"autotune",
"burn-fusion",
"burn-cubecl",
"burn-ndarray",
]
# Test configs
quantization = []
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2" }
# Backends
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", default-features = false, features = [
"export_tests",
] }
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false, features = [
"export_tests",
] }
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
# To wrap `Fusion<CubeBackend>`
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true, features = [
"fusion",
] }
num-traits = { workspace = true }
serial_test = { workspace = true }
cubecl = { workspace = true, optional = true }
cubek = { workspace = true, features = ["random"], optional = true }

View File

@@ -0,0 +1,111 @@
# Burn Backend Tests
This crate provides a comprehensive suite of tests for Burn backends, covering:
- Tensor operations: [tests/tensor/](./tests/tensor/)
- Autodiff: [tests/autodiff/](./tests/autodiff/)
- (Optional) CubeCL kernels correctness: [tests/cubecl/](./tests/cubecl/)
## Running Tests
The `TestBackend` is selected via feature flags. Use the provided shorthand commands for
convenience:
```sh
# Cpu
cargo test-cpu
# Cuda
cargo test-cuda
# Rocm
cargo test-rocm
# Wgpu / WebGpu
cargo test-wgpu
# Vulkan
cargo test-vulkan
# Metal
cargo test-metal
# Router
cargo test-router
# NdArray
cargo test-ndarray
# LibTorch
cargo test-tch
```
By default, `cargo test` fails fast across integration test binaries. When one integration test
binary fails, Cargo does not run the remaining test binaries. If you want to run all test binaries
regardless of failures, pass `--no-fail-fast`, for example:
```sh
cargo test-cuda --no-fail-fast
```
## Structure
- `tests/tensor.rs`: Tensor tests
- `tests/autodiff.rs`: Autodiff tests
- `tests/fusion.rs`: Fusion backend tests wrapping tensor and autodiff tests
- `tests/cubecl.rs`: CubeCL kernel tests
Each test module assumes exactly one `FloatElemType`, `IntElemType`, and `TestBackend` in scope.
### Common Modules
- `common/backend.rs`: Backend type definitions
- `common/tensor.rs`: Reusable tensor test suite, split across float, int and bool tensor kinds
- `common/autodiff.rs`: Reusable autodiff test suite, with and without checkpointing
### Test Reusability
This crate uses a pattern of parameterized test modules to run the same tests with different
configurations (backends, dtypes, etc.):
1. **Type aliases define the configuration**: Each test scope declares `FloatElemType`,
`IntElemType`, and `TestBackend`
1. **`#[path = "..."]` references shared modules**: Points to test files outside the normal module
hierarchy, e.g. `"common/tensor.rs"`
1. **`include!()` imports test code**: Test modules are included multiple times with different type
configurations
1. **`use super::*;`** propagates types down the module tree: Each level re-exports parent types so
deeply nested tests have access to the configured types
For example, `common/tensor.rs` can be included with `FloatElemType = f32` for base tests, then
included again with `FloatElemType = f16` for half-precision tests, running the same test suite
twice with different dtypes.
## Adding New Tests
Add test modules under `tests/tensor/`, `tests/autodiff/`, or `tests/cubecl/` as appropriate. They will
automatically run for all required configurations.
For tensor tests, make sure to add the test to each relevant tensor kind:
- `tensor/bool`: boolean tensor tests
- `tensor/float`: float tensor tests
- `tensor/int`: integer tensor tests
**Guidelines:**
Import types with `use super::*;` at the top of each module and use the types defined in
`common/backend.rs`:
```rust
/// Collection of types used across tests
pub use burn_autodiff::Autodiff;
pub use burn_tensor::Tensor;
pub type TestBackend = ...;
pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;
pub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;
pub type IntElem = burn_tensor::ops::IntElem<TestBackend>;
pub type TestAutodiffBackend = Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;
```
Tests will automatically run with default dtypes and any variants (f16, bf16, etc.) based on the
backend configuration.

View File

@@ -0,0 +1,22 @@
extern crate alloc;
#[cfg(feature = "std")]
pub use burn_tensor_testgen::might_panic;
/// Generate a test module with custom floating element types.
///
/// Arguments, in order:
/// - `$modname`: name of the generated module.
/// - `$float`: type exposed inside the module as `FloatElemType`.
/// - `$module`: string literal forwarded to `include!` for the test code.
/// - `[$feat, ...]`: feature names; the module is only compiled under
///   `cfg(test)` when at least one of them is enabled.
#[macro_export]
macro_rules! test_float_elem_variant {
($modname:ident, $float:ty, $module:literal, [$($feat:literal),* $(,)?]) => {
#[cfg(all(test, any($(feature = $feat),*)))]
mod $modname {
pub type FloatElemType = $float;
// The integer element type is taken unchanged from the parent scope.
#[allow(unused)]
pub use super::IntElemType;
// Backend definitions and the test code are both included into the same
// inner module, beneath the aliases declared above.
mod ty {
include!("backend.rs");
include!($module);
}
}
};
}

View File

@@ -0,0 +1,20 @@
//! Burn autodiff tests.
#![allow(
clippy::single_range_in_vec_init,
clippy::duplicate_mod,
reason = "false positive"
)]
extern crate alloc;
// Element types for this test scope.
// NOTE(review): presumably picked up by the shared test modules below via `use super::*` — confirm.
pub type FloatElemType = f32;
#[allow(unused)]
pub type IntElemType = i32;
// Shared backend type definitions, re-exported for the test modules.
#[path = "common/backend.rs"]
mod backend;
pub use backend::*;
// Shared autodiff test suite.
#[allow(clippy::module_inception)]
#[path = "common/autodiff.rs"]
mod autodiff;

View File

@@ -0,0 +1,58 @@
use super::*;
use burn_tensor::{TensorData, Tolerance, cast::ToElement};
/// Gradients of `(lhs · |rhs|) · rhs` with respect to both operands.
#[test]
fn should_diff_abs() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[0.0, -1.0], [3.0, 4.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[6.0, 7.0], [9.0, -10.0]]), &device)
        .require_grad();

    let inner = lhs.clone().matmul(rhs.clone().abs());
    let out = inner.matmul(rhs.clone());
    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[71.0, 107.0], [71.0, 107.0]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[84.0, 42.0], [90.0, 54.0]]),
        Tolerance::default(),
    );
}
/// Gradients of `lhs · |rhs|` must match the expected values and contain no NaN.
#[test]
fn should_diff_abs_no_nans() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[6.0, 7.0], [9.0, -10.0]]),
        &device,
    )
    .require_grad();
    // The operand passed through `abs` contains a zero entry.
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[0.0, -1.0], [3.0, 4.0]]), &device)
        .require_grad();

    let out = lhs.clone().matmul(rhs.clone().abs());
    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[1.0, 7.0], [1.0, 7.0]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[0.0, -15.0], [-3.0, -3.0]]),
        Tolerance::default(),
    );

    let has_nan = rhs_grad.contains_nan().into_scalar().to_bool();
    assert!(!has_nan);
}

View File

@@ -0,0 +1,50 @@
use super::*;
use burn_tensor::module::adaptive_avg_pool1d;
use burn_tensor::{Shape, Tolerance};
/// Adaptive 1D average pooling: input of length 5 pooled to 3 outputs, 2 channels.
#[test]
fn test_avg_pool1d_simple() {
    let case = AdaptiveAvgPool1dTestCase {
        batch_size: 1,
        channels: 2,
        length: 5,
        output_size: 3,
    };
    let expected_grad = TestTensor::from_floats(
        [[
            [0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
            [0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
        ]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Parameters of one adaptive 1D average pooling gradient test.
///
/// The input tensor has shape `[batch_size, channels, length]`.
struct AdaptiveAvgPool1dTestCase {
batch_size: usize,
channels: usize,
length: usize,
// Output size passed to `adaptive_avg_pool1d`.
output_size: usize,
}
impl AdaptiveAvgPool1dTestCase {
    /// Check the input gradient of `adaptive_avg_pool1d` against `x_grad`.
    ///
    /// The input is an `arange` tensor reshaped to
    /// `[batch_size, channels, length]`.
    fn assert_output(self, x_grad: TestTensor<3>) {
        let device = Default::default();
        let input_shape = Shape::new([self.batch_size, self.channels, self.length]);
        let element_count = input_shape.num_elements() as i64;
        let input_data = TestTensorInt::arange(0..element_count, &device)
            .reshape::<3, _>(input_shape)
            .into_data();
        let input = TestAutodiffTensor::from_data(input_data, &device).require_grad();

        let pooled = adaptive_avg_pool1d(input.clone(), self.output_size);
        let grads = pooled.backward();
        let actual_grad = input.grad(&grads).unwrap();

        let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
        x_grad
            .to_data()
            .assert_approx_eq::<FloatElem>(&actual_grad.into_data(), tolerance);
    }
}

View File

@@ -0,0 +1,96 @@
use super::*;
use burn_tensor::module::adaptive_avg_pool2d;
use burn_tensor::{Shape, Tolerance};
/// Adaptive 2D average pooling: 5x3 input pooled to 3x2, 2 channels.
#[test]
fn test_avg_pool2d_simple() {
    let case = AdaptiveAvgPool2dTestCase {
        batch_size: 1,
        channels: 2,
        height: 5,
        width: 3,
        output_size_1: 3,
        output_size_2: 2,
    };
    let expected_grad = TestTensor::from_floats(
        [[
            [
                [0.2500, 0.5000, 0.2500],
                [0.41667, 0.83333, 0.41667],
                [0.16667, 0.33333, 0.16667],
                [0.41667, 0.83333, 0.41667],
                [0.2500, 0.5000, 0.2500],
            ],
            [
                [0.2500, 0.5000, 0.2500],
                [0.41667, 0.83333, 0.41667],
                [0.16667, 0.33333, 0.16667],
                [0.41667, 0.83333, 0.41667],
                [0.2500, 0.5000, 0.2500],
            ],
        ]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Adaptive 2D average pooling down to a single 1x1 output: every one of the
/// 4x8 inputs is expected to receive the same gradient.
#[test]
fn test_avg_pool2d_output_1() {
    let case = AdaptiveAvgPool2dTestCase {
        batch_size: 1,
        channels: 1,
        height: 4,
        width: 8,
        output_size_1: 1,
        output_size_2: 1,
    };
    let expected_grad = TestTensor::from_floats(
        [[[
            [
                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
            ],
            [
                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
            ],
            [
                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
            ],
            [
                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
            ],
        ]]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Parameters of one adaptive 2D average pooling gradient test.
///
/// The input tensor has shape `[batch_size, channels, height, width]`.
struct AdaptiveAvgPool2dTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
// The two output sizes passed, in this order, to `adaptive_avg_pool2d`.
output_size_1: usize,
output_size_2: usize,
}
impl AdaptiveAvgPool2dTestCase {
    /// Check the input gradient of `adaptive_avg_pool2d` against `x_grad`.
    ///
    /// The input is an `arange` tensor reshaped to
    /// `[batch_size, channels, height, width]`.
    fn assert_output(self, x_grad: TestTensor<4>) {
        let device = Default::default();
        let input_shape = Shape::new([self.batch_size, self.channels, self.height, self.width]);
        let element_count = input_shape.num_elements() as i64;
        let input_data = TestTensorInt::arange(0..element_count, &device)
            .reshape::<4, _>(input_shape)
            .into_data();
        let input = TestAutodiffTensor::from_data(input_data, &device).require_grad();

        let output_size = [self.output_size_1, self.output_size_2];
        let pooled = adaptive_avg_pool2d(input.clone(), output_size);
        let grads = pooled.backward();
        let actual_grad = input.grad(&grads).unwrap();

        let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
        x_grad
            .to_data()
            .assert_approx_eq::<FloatElem>(&actual_grad.into_data(), tolerance);
    }
}

View File

@@ -0,0 +1,74 @@
use super::*;
use burn_tensor::TensorData;
/// Element-wise addition: both operands get a gradient of ones, and the
/// forward result is still readable after the backward pass.
#[test]
fn should_diff_add() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad();
    let rhs = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad();

    let sum = lhs.clone() + rhs.clone();
    let grads = sum.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let ones = TensorData::from([1.0, 1.0]);
    lhs_grad.to_data().assert_eq(&ones, false);
    rhs_grad.to_data().assert_eq(&ones, false);
    sum.to_data()
        .assert_eq(&TensorData::from([6.0, 6.0]), false);
}
/// Adding a scalar leaves the input gradient at one and shifts the output.
#[test]
fn should_diff_add_scalar() {
    let input =
        TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 10.0]), &Default::default())
            .require_grad();

    let shifted = input.clone().add_scalar(5.0);
    let grads = shifted.backward();
    let input_grad = input.grad(&grads).unwrap();

    input_grad
        .to_data()
        .assert_eq(&TensorData::from([1.0, 1.0]), false);
    shifted
        .into_data()
        .assert_eq(&TensorData::from([7.0, 15.0]), false);
}
/// A chain of additions where `a` is used three times and `b` twice, so the
/// expected gradients are 3 and 2 respectively.
#[test]
fn test_add_complex_1() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let c = TestAutodiffTensor::from_data(TensorData::from([[2.0, 2.0], [2.0, 2.0]]), &device)
        .require_grad();

    let a_plus_b = a.clone().add(b.clone());
    let accumulated = a_plus_b
        .add(c)
        .add_scalar(5.0)
        .add(a.clone())
        .add(b.clone());
    let out = a.clone().add(accumulated);

    let grads = out.backward();
    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    a_grad
        .to_data()
        .assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false);
    b_grad
        .to_data()
        .assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
}

View File

@@ -0,0 +1,138 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Gradients of `lhs * mean(lhs · rhs)`.
#[test]
fn should_diff_mean() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let out = lhs.clone().mul(product.mean().unsqueeze());
    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[3.5, 9.5], [3.5, 9.5]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[-0.75, -0.75], [3.0, 3.0]]),
        Tolerance::default(),
    );
}
/// Gradients of `lhs * sum(lhs · rhs)`.
#[test]
fn should_diff_sum_1() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let out = lhs.clone().mul(product.sum().unsqueeze());
    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[14.0, 38.0], [14.0, 38.0]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[-3.0, -3.0], [12.0, 12.0]]),
        Tolerance::default(),
    );
}
/// Gradients of `sum(sum_dim(p, 1) * p)` where `p = lhs · rhs`.
#[test]
fn should_diff_sum_2() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[6.0, 7.0], [9.0, 10.0]]), &device)
        .require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let row_sums = product.clone().sum_dim(1);
    let weighted = row_sums.mul(product);
    let grads = weighted.sum().backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[690.0, 690.0], [958.0, 958.0]]),
        Tolerance::default(),
    );
}
/// Gradients of `lhs * mean_dim(lhs · rhs, 1)`.
#[test]
fn should_diff_mean_dim() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let out = lhs.clone().mul(product.mean_dim(1).unsqueeze());
    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[4.0, 36.0], [3.0, -17.0]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[9.0, 9.0], [35.5, 35.5]]),
        Tolerance::default(),
    );
}
/// Gradients of `lhs * sum_dim(lhs · rhs, 1)`.
#[test]
fn should_diff_sum_dim() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let out = lhs.clone().mul(product.sum_dim(1).unsqueeze());
    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    lhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[8.0, 72.0], [6.0, -34.0]]),
        Tolerance::default(),
    );
    rhs_grad.to_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[18.0, 18.0], [71.0, 71.0]]),
        Tolerance::default(),
    );
}

View File

@@ -0,0 +1,102 @@
use super::*;
use burn_tensor::module::avg_pool1d;
use burn_tensor::{Shape, Tolerance};
/// 1D average pooling, kernel 3, stride 1, no padding.
#[test]
fn test_avg_pool1d_simple() {
    let case = AvgPool1dTestCase {
        batch_size: 1,
        channels: 1,
        kernel_size: 3,
        padding: 0,
        stride: 1,
        length: 6,
        count_include_pad: true,
    };
    let expected_grad = TestTensor::from_floats(
        [[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// 1D average pooling, kernel 3, stride 2, padding 1, padding counted in the average.
#[test]
fn test_avg_pool1d_complex() {
    let case = AvgPool1dTestCase {
        batch_size: 1,
        channels: 2,
        kernel_size: 3,
        padding: 1,
        stride: 2,
        length: 6,
        count_include_pad: true,
    };
    let expected_grad = TestTensor::from_floats(
        [[
            [0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
            [0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
        ]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Same configuration as the "complex" case, but with `count_include_pad` disabled.
#[test]
fn test_avg_pool1d_complex_dont_count_pad() {
    let case = AvgPool1dTestCase {
        batch_size: 1,
        channels: 2,
        kernel_size: 3,
        padding: 1,
        stride: 2,
        length: 6,
        count_include_pad: false,
    };
    let expected_grad = TestTensor::from_floats(
        [[
            [0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
            [0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
        ]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Parameters of one 1D average pooling gradient test.
///
/// The input tensor has shape `[batch_size, channels, length]`; the remaining
/// fields are forwarded to `avg_pool1d`.
struct AvgPool1dTestCase {
batch_size: usize,
channels: usize,
kernel_size: usize,
padding: usize,
stride: usize,
length: usize,
count_include_pad: bool,
}
impl AvgPool1dTestCase {
    /// Check the input gradient of `avg_pool1d` against `x_grad`.
    ///
    /// The input is an `arange` tensor reshaped to
    /// `[batch_size, channels, length]`.
    fn assert_output(self, x_grad: TestTensor<3>) {
        let device = Default::default();
        let input_shape = Shape::new([self.batch_size, self.channels, self.length]);
        let element_count = input_shape.num_elements() as i64;
        let input_data = TestTensorInt::arange(0..element_count, &device)
            .reshape::<3, _>(input_shape)
            .into_data();
        let input = TestAutodiffTensor::from_data(input_data, &device).require_grad();

        let pooled = avg_pool1d(
            input.clone(),
            self.kernel_size,
            self.stride,
            self.padding,
            self.count_include_pad,
            false,
        );
        let grads = pooled.backward();
        let actual_grad = input.grad(&grads).unwrap();

        x_grad.to_data().assert_approx_eq::<FloatElem>(
            &actual_grad.into_data(),
            Tolerance::default().set_half_precision_relative(1e-3),
        );
    }
}

View File

@@ -0,0 +1,129 @@
use super::*;
use burn_tensor::module::avg_pool2d;
use burn_tensor::{Shape, Tolerance};
/// 2D average pooling, 3x3 kernel, stride 1, no padding, on a 6x6 input.
#[test]
fn test_avg_pool2d_simple() {
    let case = AvgPool2dTestCase {
        batch_size: 1,
        channels: 1,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 1,
        stride_2: 1,
        height: 6,
        width: 6,
        count_include_pad: true,
    };
    let expected_grad = TestTensor::from_floats(
        [[[
            [0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
            [0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
            [0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
            [0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
            [0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
            [0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
        ]]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// 2D average pooling with a 3x4 kernel, padding (1, 2) and stride (1, 2),
/// padding counted in the average.
#[test]
fn test_avg_pool2d_complex() {
    let case = AvgPool2dTestCase {
        batch_size: 1,
        channels: 1,
        kernel_size_1: 3,
        kernel_size_2: 4,
        padding_1: 1,
        padding_2: 2,
        stride_1: 1,
        stride_2: 2,
        height: 4,
        width: 6,
        count_include_pad: true,
    };
    let expected_grad = TestTensor::from_floats(
        [[[
            [0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
        ]]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Same configuration as the "complex" case, but with `count_include_pad` disabled.
#[test]
fn test_avg_pool2d_complex_dont_include_pad() {
    let case = AvgPool2dTestCase {
        batch_size: 1,
        channels: 1,
        kernel_size_1: 3,
        kernel_size_2: 4,
        padding_1: 1,
        padding_2: 2,
        stride_1: 1,
        stride_2: 2,
        height: 4,
        width: 6,
        count_include_pad: false,
    };
    let expected_grad = TestTensor::from_floats(
        [[[
            [0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
            [0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
            [0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
            [0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
        ]]],
        &Default::default(),
    );

    case.assert_output(expected_grad);
}
/// Parameters of one 2D average pooling gradient test.
///
/// The input tensor has shape `[batch_size, channels, height, width]`. The
/// `_1`/`_2` pairs are passed to `avg_pool2d` as two-element arrays, `_1`
/// first.
struct AvgPool2dTestCase {
batch_size: usize,
channels: usize,
kernel_size_1: usize,
kernel_size_2: usize,
padding_1: usize,
padding_2: usize,
stride_1: usize,
stride_2: usize,
height: usize,
width: usize,
count_include_pad: bool,
}
impl AvgPool2dTestCase {
    /// Check the input gradient of `avg_pool2d` against `x_grad`.
    ///
    /// The input is an `arange` tensor reshaped to
    /// `[batch_size, channels, height, width]`.
    fn assert_output(self, x_grad: TestTensor<4>) {
        let device = Default::default();
        let input_shape = Shape::new([self.batch_size, self.channels, self.height, self.width]);
        let element_count = input_shape.num_elements() as i64;
        let input_data = TestTensorInt::arange(0..element_count, &device)
            .reshape::<4, _>(input_shape)
            .into_data();
        let input = TestAutodiffTensor::from_data(input_data, &device).require_grad();

        let kernel_size = [self.kernel_size_1, self.kernel_size_2];
        let stride = [self.stride_1, self.stride_2];
        let padding = [self.padding_1, self.padding_2];
        let pooled = avg_pool2d(
            input.clone(),
            kernel_size,
            stride,
            padding,
            self.count_include_pad,
            false,
        );
        let grads = pooled.backward();
        let actual_grad = input.grad(&grads).unwrap();

        let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
        x_grad
            .to_data()
            .assert_approx_eq::<FloatElem>(&actual_grad.into_data(), tolerance);
    }
}

View File

@@ -0,0 +1,24 @@
use super::*;
use burn_tensor::{Int, Tensor, TensorData, module::embedding};
/// Gradient of the embedding weights when the looked-up rows are multiplied by `x`.
#[test]
fn test_embedding_backward() {
    let device = Default::default();
    let weights = Tensor::<TestAutodiffBackend, 2>::from_data(
        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
        &device,
    )
    .require_grad();
    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
        TensorData::from([[0, 1], [1, 1]]),
        &device,
    );
    let x = Tensor::<TestAutodiffBackend, 3>::from_data(
        TensorData::from([
            [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
            [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
        ]),
        &device,
    )
    .require_grad();

    let embedded = embedding(weights.clone(), indices);
    let output = embedded.matmul(x);
    let grads = output.backward();

    let weights_grad = weights.grad(&grads).unwrap();
    weights_grad
        .to_data()
        .assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false);
}

View File

@@ -0,0 +1,27 @@
use super::*;
use burn_tensor::{DType, Distribution, Tensor};
/// Gradients must still reach both inputs when part of the graph is computed
/// after a cast to `F32` and then cast back to the original dtype.
#[test]
fn test_full_precision() {
    let device = Default::default();
    let lhs = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
        .require_grad();
    let rhs = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
        .require_grad();
    let original_dtype = lhs.dtype();

    let lhs_f32 = lhs.clone().cast(DType::F32);
    let rhs_f32 = rhs.clone().cast(DType::F32);
    let product = lhs_f32.matmul(rhs_f32).cast(original_dtype);
    let out = product * lhs.clone() / rhs.clone();

    let grads = out.backward();

    let lhs_grad = lhs.grad(&grads);
    let rhs_grad = rhs.grad(&grads);
    assert!(lhs_grad.is_some());
    assert!(rhs_grad.is_some());
}

View File

@@ -0,0 +1,56 @@
use super::*;
/// Backward pass of a broadcast multiplication.
#[test]
fn mul_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs * rhs);
}
/// Backward pass of a broadcast division.
#[test]
fn div_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs / rhs);
}
/// Backward pass of a broadcast subtraction.
#[test]
fn sub_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs - rhs);
}
/// Backward pass of a broadcast addition.
#[test]
fn add_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs + rhs);
}
/// Backward pass of a broadcast matrix multiplication.
#[test]
fn matmul_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs.matmul(rhs));
}
/// Backward pass of a broadcast `mask_where`, with the mask derived from the
/// second operand.
#[test]
fn mask_where_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| {
        let mask = rhs.clone().equal_elem(4);
        lhs.mask_where(mask, rhs)
    });
}
/// Run `func` on a `[1, 5, 5]` slice of a `[16, 5, 5]` tensor and a
/// `[4, 5, 5]` tensor, then check that each gradient has the shape of the
/// tensor it belongs to.
fn test_ops_broadcast_backward<F>(func: F)
where
    F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
{
    let device = Default::default();
    let sliced_source = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();
    let other = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();

    // Slice isn't a broadcastable operation, so its backward pass fails if the
    // backward pass of the preceding operation, which supports broadcasting in
    // the forward pass, doesn't also support it.
    let sliced = sliced_source.clone().slice([0..1]);
    let output = func(sliced, other.clone());

    // Will panic if broadcast isn't supported!
    let grads = output.backward();

    let sliced_source_grad = sliced_source.grad(&grads).unwrap();
    let other_grad = other.grad(&grads).unwrap();

    assert_eq!(sliced_source_grad.shape(), sliced_source.shape());
    assert_eq!(other_grad.shape(), other.shape());
}

View File

@@ -0,0 +1,28 @@
// Skip on metal - F64 not supported
#![cfg(all(feature = "std", not(feature = "metal")))]
use super::*;
use burn_backend_tests::might_panic;
use burn_tensor::{DType, Tensor, TensorData};
/// Casting to `F64` before summing must not cut the gradient path to the input.
#[might_panic(reason = "Unsupported precision for fusion")]
#[test]
fn cast_keeps_gradient_flow() {
    let device = Default::default();
    let input = Tensor::<TestAutodiffBackend, 2>::from_data(
        TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
        &device,
    )
    .require_grad();

    let total = input.clone().cast(DType::F64).sum();
    let grads = total.backward();

    let input_grad = input.grad(&grads).unwrap();
    input_grad
        .to_data()
        .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);
}

View File

@@ -0,0 +1,110 @@
use super::*;
use burn_tensor::Tolerance;
/// Gradients obtained through slice + `cat` must match those of the plain matmul.
#[test]
fn should_diff_cat() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
    let rhs =
        TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();

    // Reference gradients from a direct matmul.
    let reference_out = lhs.clone().matmul(rhs.clone());
    let reference_grads = reference_out.backward();
    let lhs_reference = lhs.grad(&reference_grads).unwrap();
    let rhs_reference = rhs.grad(&reference_grads).unwrap();

    // Rebuild both operands by slicing out each row and concatenating them back.
    let mut lhs_rows = Vec::new();
    let mut rhs_rows = Vec::new();
    for row in 0..2 {
        lhs_rows.push(lhs.clone().slice([row..row + 1]));
        rhs_rows.push(rhs.clone().slice([row..row + 1]));
    }
    let lhs_cat = TestAutodiffTensor::cat(lhs_rows.clone(), 0);
    let rhs_cat = TestAutodiffTensor::cat(rhs_rows.clone(), 0);

    let cat_out = lhs_cat.clone().matmul(rhs_cat.clone());
    let cat_grads = cat_out.backward();

    let lhs_cat_row_0 = lhs.grad(&cat_grads).unwrap().slice([0..1]);
    let lhs_cat_row_1 = lhs.grad(&cat_grads).unwrap().slice([1..2]);
    let rhs_cat_row_0 = rhs.grad(&cat_grads).unwrap().slice([0..1]);
    let rhs_cat_row_1 = rhs.grad(&cat_grads).unwrap().slice([1..2]);

    lhs_reference
        .clone()
        .slice([0..1])
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_cat_row_0.to_data(), Tolerance::default());
    lhs_reference
        .slice([1..2])
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_cat_row_1.to_data(), Tolerance::default());
    rhs_reference
        .clone()
        .slice([0..1])
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_cat_row_0.to_data(), Tolerance::default());
    rhs_reference
        .slice([1..2])
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_cat_row_1.to_data(), Tolerance::default());
}
/// Concatenating tensors of different sizes along dim 0 yields gradients that
/// keep the shape of each input.
#[test]
fn should_diff_cat_more_than_1_dim() {
    let device = Default::default();
    let two_rows =
        TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
    let three_rows =
        TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device)
            .require_grad();

    // A [2, 2] tensor concatenated with a [3, 2] tensor along dim 0 gives [5, 2].
    let joined = TestAutodiffTensor::cat(vec![two_rows.clone(), three_rows.clone()], 0);
    assert_eq!(joined.dims(), [5, 2]);

    let grads = joined.backward();
    let two_rows_grad = two_rows.grad(&grads).unwrap();
    let three_rows_grad = three_rows.grad(&grads).unwrap();

    assert_eq!(two_rows.dims(), two_rows_grad.dims());
    assert_eq!(three_rows.dims(), three_rows_grad.dims());
}
/// Concatenates tracked and untracked tensors of widths 1, 2 and 3 along dim 1 and checks that
/// each tracked input receives exactly its own columns of the gradient.
#[test]
fn should_slice_grads_correctly_when_some_inputs_not_tracked() {
    let device = Default::default();

    // Widths 1 (tracked), 2 (not tracked) and 3 (tracked).
    let tracked_head = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad();
    let untracked_mid = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device);
    let tracked_tail =
        TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad();

    let parts = vec![
        tracked_head.clone(),
        untracked_mid.clone(),
        tracked_tail.clone(),
    ];
    let joined = TestAutodiffTensor::cat(parts, 1);

    // A distinct weight per column makes a misplaced gradient slice visible.
    let column_weights =
        TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device);
    let weighted = joined * column_weights;
    let loss = weighted.sum();
    let grads = loss.backward();

    let head_grad = tracked_head.grad(&grads).unwrap();
    let tail_grad = tracked_tail.grad(&grads).unwrap();

    let expected_head = burn_tensor::TensorData::from([[1.0]]);
    head_grad.to_data().assert_eq(&expected_head, false);

    let expected_tail = burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]);
    tail_grad.to_data().assert_eq(&expected_tail, false);
}

View File

@@ -0,0 +1,21 @@
use super::*;
use burn_tensor::TensorData;
/// Checks that the gradient reported for `ceil` is zero for every input element.
#[test]
fn should_diff_ceil() {
    let device = Default::default();
    let input_data = TensorData::from([
        [-1.9751, 0.0714, 0.0643, 0.2406],
        [-1.3172, 0.1252, -0.1119, -0.0127],
    ]);
    let input = TestAutodiffTensor::<2>::from_data(input_data, &device).require_grad();

    let rounded = input.clone().ceil();
    let grads = rounded.backward();
    let input_grad = input.grad(&grads).unwrap();

    let all_zero = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
    input_grad.to_data().assert_eq(&all_zero, false);
}

View File

@@ -0,0 +1,215 @@
use super::*;
use burn_tensor::{Bool, Tensor, TensorData};
/// Builds a graph from five tracked 2x2 inputs that mixes every helper kind defined in this
/// file (eager/lazy, memory/compute bound, plus a scalar variant) and runs backward on it.
///
/// There is no explicit assertion: the test passes when `assert_checkpoint` does not panic.
#[test]
fn test_autodiff_checkpoint_complicated_computation() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
// All five leaves are tracked.
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
// tensor_3, tensor_7 and tensor_8 are each consumed by two different operations.
let tensor_5 = compute_bound_eager(tensor_0, tensor_1);
let tensor_6 = compute_bound_lazy(tensor_2, tensor_3.clone());
let tensor_7 = memory_bound_eager(tensor_3, tensor_4);
let tensor_8 = compute_bound_lazy(tensor_6, tensor_7.clone());
let tensor_9 = memory_bound_eager_scalar(tensor_7, 11.);
let tensor_10 = memory_bound_lazy(tensor_5, tensor_8.clone());
let tensor_11 = memory_bound_lazy(tensor_8, tensor_9);
let tensor_12 = compute_bound_lazy(tensor_10, tensor_11);
assert_checkpoint(tensor_12);
}
/// Runs backward through a graph in which one of the two leaves is not tracked
/// (`tensor_1` never calls `require_grad`).
///
/// Passes when `assert_checkpoint` does not panic.
#[test]
fn test_autodiff_checkpoint_with_missing_requirement() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); // does not require_grad
let tensor_2 = memory_bound_eager(tensor_0, tensor_1);
// tensor_2 feeds three consumers: tensor_3, tensor_4 and tensor_7.
let tensor_3 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
let tensor_4 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
let tensor_5 = compute_bound_lazy(tensor_3, tensor_4);
let tensor_6 = compute_bound_eager_scalar(tensor_5.clone(), 11.);
let tensor_7 = memory_bound_eager(tensor_5, tensor_2);
let tensor_8 = memory_bound_eager(tensor_6, tensor_7);
assert_checkpoint(tensor_8);
}
/// Runs backward through a graph built from a single tracked leaf that is reused many times,
/// often as both operands of the same operation.
///
/// Passes when `assert_checkpoint` does not panic.
#[test]
fn test_autodiff_checkpoint_with_many_duplicates() {
let data_0 = TensorData::from([[4.0, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
// Each helper kind is first applied to (tensor_0, tensor_0).
let tensor_1 = memory_bound_eager(tensor_0.clone(), tensor_0.clone());
let tensor_2 = compute_bound_eager(tensor_0.clone(), tensor_0.clone());
let tensor_3 = memory_bound_lazy(tensor_0.clone(), tensor_0.clone());
let tensor_4 = compute_bound_lazy(tensor_0.clone(), tensor_0.clone());
let tensor_5 = memory_bound_eager(tensor_1.clone(), tensor_0.clone());
let tensor_6 = memory_bound_eager(tensor_0.clone(), tensor_5.clone());
let tensor_7 = compute_bound_lazy(tensor_3.clone(), tensor_5.clone());
let tensor_8 = compute_bound_eager(tensor_4.clone(), tensor_2.clone());
let tensor_9 = memory_bound_lazy(tensor_6, tensor_7);
let tensor_10 = memory_bound_eager(tensor_0, tensor_9);
let tensor_11 = memory_bound_eager_scalar(tensor_10, 9.);
let tensor_12 = compute_bound_lazy(tensor_8, tensor_11);
assert_checkpoint(tensor_12);
}
/// Runs backward through four consecutive `memory_bound_eager` operations followed by one
/// `memory_bound_lazy` operation.
///
/// `tensor_1` is not tracked and is used both at the start and at the end of the chain.
/// Passes when `assert_checkpoint` does not panic.
#[test]
fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
// Not tracked: no require_grad.
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = memory_bound_eager(tensor_0, tensor_1.clone());
let tensor_6 = memory_bound_eager(tensor_5, tensor_2);
let tensor_7 = memory_bound_eager(tensor_6, tensor_3);
let tensor_8 = memory_bound_eager(tensor_7, tensor_4);
let tensor_9 = memory_bound_lazy(tensor_8, tensor_1);
assert_checkpoint(tensor_9)
}
/// Runs backward through two sub-graphs joined by a final `compute_bound_lazy`: one built only
/// from untracked leaves (`tensor_0..=tensor_2`) and one built only from tracked leaves
/// (`tensor_3..=tensor_5`).
///
/// Passes when `assert_checkpoint` does not panic.
#[test]
fn test_autodiff_checkpoint_half_sub_graph_not_tracked() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let data_5 = TensorData::from([[0.5, 7.0], [7.0, 7.0]]);
let device = Default::default();
// Untracked leaves.
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device);
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device);
// Tracked leaves.
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = TestAutodiffTensor::from_data(data_5, &device).require_grad();
// Sub-graph built from the untracked leaves.
let tensor_6 = memory_bound_lazy(tensor_0, tensor_1);
let tensor_7 = compute_bound_eager(tensor_6, tensor_2);
// Sub-graph built from the tracked leaves.
let tensor_8 = memory_bound_eager(tensor_3, tensor_4);
let tensor_9 = compute_bound_lazy(tensor_8, tensor_5);
let tensor_10 = compute_bound_lazy(tensor_7, tensor_9);
assert_checkpoint(tensor_10);
}
/// Runs backward through a larger graph (17 operations over five leaves, one of them
/// untracked) in which several intermediate tensors are consumed more than once.
///
/// Passes when `assert_checkpoint` does not panic.
#[test]
fn test_autodiff_checkpoint_very_complex() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
// Not tracked: no require_grad.
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = memory_bound_eager_scalar(tensor_0, 8.);
let tensor_6 = memory_bound_lazy(tensor_5.clone(), tensor_1.clone());
// tensor_6 is both operands of the same matmul.
let tensor_7 = compute_bound_lazy(tensor_6.clone(), tensor_6);
let tensor_8 = memory_bound_lazy(tensor_1.clone(), tensor_5.clone());
let tensor_9 = memory_bound_eager_scalar(tensor_7.clone(), 7.);
let tensor_10 = compute_bound_eager(tensor_5, tensor_8);
let tensor_11 = memory_bound_eager(tensor_2.clone(), tensor_9);
// tensor_2 is both operands of the same multiplication.
let tensor_12 = memory_bound_lazy(tensor_2.clone(), tensor_2);
let tensor_13 = compute_bound_eager(tensor_10.clone(), tensor_11);
let tensor_14 = compute_bound_eager_scalar(tensor_3, 8.);
let tensor_15 = compute_bound_lazy(tensor_4, tensor_12);
let tensor_16 = memory_bound_lazy(tensor_10, tensor_7);
let tensor_17 = compute_bound_lazy(tensor_13, tensor_1);
let tensor_18 = memory_bound_eager(tensor_15, tensor_16);
let tensor_19 = compute_bound_eager(tensor_14, tensor_17);
let tensor_20 = memory_bound_lazy(tensor_18, tensor_19);
let tensor_21 = memory_bound_eager_scalar(tensor_20, 8.);
assert_checkpoint(tensor_21)
}
/// Runs the backward pass on `tensor`; the resulting gradients are discarded.
///
/// No assertion is made here. The calling test still fails through a panic in two situations:
/// - a tensor is needed more often than its recorded `n_required`, so a lookup fails;
/// - a tensor is needed less often than `n_required`, so the backward states map (and likewise
///   the retro_forwards) is left non-empty and trips the assertion inside the backward code.
fn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {
    let _ = tensor.backward();
}
/// Element-wise addition of two tensors.
///
/// Used by the checkpoint tests as an operation that does not save its state and does not
/// need its parents.
fn memory_bound_eager<const D: usize>(
    tensor_a: TestAutodiffTensor<D>,
    tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
    tensor_a + tensor_b
}
fn memory_bound_eager_scalar<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
b: f32,
) -> TestAutodiffTensor<D> {
tensor_a.add_scalar(b)
}
/// Applies `mask_where` to `tensor_a` with `tensor_b` as the replacement source.
///
/// Used by the checkpoint tests as an operation that saves its own state and does not need
/// its parents. The mask comes from `Tensor::empty`, so its contents are whatever the backend
/// provides; the tests only run backward and never inspect the values.
fn compute_bound_eager<const D: usize>(
    tensor_a: TestAutodiffTensor<D>,
    tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
    // Read shape and device before `tensor_a` is moved into `mask_where`.
    let shape = tensor_a.shape();
    let device = tensor_a.device();
    let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(shape, &device);
    tensor_a.mask_where(mask, tensor_b)
}
/// Applies `mask_fill` to `tensor_a` with the scalar `b`; scalar counterpart of
/// `compute_bound_eager`.
///
/// The mask comes from `Tensor::empty`, so its contents are whatever the backend provides.
fn compute_bound_eager_scalar<const D: usize>(
    tensor_a: TestAutodiffTensor<D>,
    b: f32,
) -> TestAutodiffTensor<D> {
    // Read shape and device before `tensor_a` is moved into `mask_fill`.
    let shape = tensor_a.shape();
    let device = tensor_a.device();
    let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(shape, &device);
    tensor_a.mask_fill(mask, b)
}
/// Element-wise multiplication of two tensors.
///
/// Used by the checkpoint tests as an operation that does not save its state and needs its
/// parents.
fn memory_bound_lazy<const D: usize>(
    tensor_a: TestAutodiffTensor<D>,
    tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
    tensor_a * tensor_b
}
// Saves its own state and needs its parents
fn compute_bound_lazy<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.matmul(tensor_b)
}

View File

@@ -0,0 +1,81 @@
use super::*;
use burn_tensor::TensorData;
/// Differentiates `((a matmul b) matmul a) * b` and compares both gradients with fixed values.
#[test]
fn should_diff_full_complex_1() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let output = a
        .clone()
        .matmul(b.clone())
        .matmul(a.clone())
        .mul(b.clone());
    let grads = output.backward();

    let grad_a = a.grad(&grads).unwrap();
    let grad_b = b.grad(&grads).unwrap();

    let expected_a = TensorData::from([[593., 463.0], [487.0, 539.0]]);
    grad_a.to_data().assert_eq(&expected_a, false);

    let expected_b = TensorData::from([[734.0, 294.0], [1414.0, 242.0]]);
    grad_b.to_data().assert_eq(&expected_b, false);
}
/// Differentiates `((a matmul b) matmul a) + 17 + b` and compares both gradients with fixed
/// values.
#[test]
fn should_diff_full_complex_2() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = a.clone().matmul(b.clone()).matmul(a.clone());
    let shifted = product.add_scalar(17.0);
    let output = shifted.add(b.clone());
    let grads = output.backward();

    let grad_a = a.grad(&grads).unwrap();
    let grad_b = b.grad(&grads).unwrap();

    let expected_a = TensorData::from([[166.0, 110.0], [212.0, 156.0]]);
    grad_a.to_data().assert_eq(&expected_a, false);

    let expected_b = TensorData::from([[113.0, 141.0], [33.0, 41.0]]);
    grad_b.to_data().assert_eq(&expected_b, false);
}
/// Differentiates `(p - b) + p` with `p = (a matmul b) matmul a`, so `p` is consumed twice,
/// and compares both gradients with fixed values.
#[test]
fn should_diff_full_complex_3() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = a.clone().matmul(b.clone()).matmul(a.clone());
    let difference = product.clone().sub(b.clone());
    let output = difference.add(product);
    let grads = output.backward();

    let grad_a = a.grad(&grads).unwrap();
    let grad_b = b.grad(&grads).unwrap();

    let expected_a = TensorData::from([[332.0, 220.0], [424.0, 312.0]]);
    grad_a.to_data().assert_eq(&expected_a, false);

    let expected_b = TensorData::from([[223.0, 279.0], [63.0, 79.0]]);
    grad_b.to_data().assert_eq(&expected_b, false);
}

View File

@@ -0,0 +1,277 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};
/// Conv1d gradient check for the baseline configuration: kernel 3, padding 1, stride 1,
/// dilation 1, one group, 2 input and 2 output channels on a [2, 2, 4] input.
#[test]
fn test_conv1d_basic() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [[14., 24., 24., 18.], [26., 42., 42., 30.]];
    let weight_per_channel = [[30., 44., 36.], [54., 76., 60.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([8., 8.], &device),
    };
    case.assert_grads(expected);
}
/// Conv1d gradient check with differing channel counts: 2 input channels, 3 output channels.
#[test]
fn test_conv1d_different_channels() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 3,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [[39., 63., 63., 45.], [57., 90., 90., 63.]];
    let weight_per_channel = [[30., 44., 36.], [54., 76., 60.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats(
            [weight_per_channel, weight_per_channel, weight_per_channel],
            &device,
        ),
        bias: TestTensor::from_floats([8., 8., 8.], &device),
    };
    case.assert_grads(expected);
}
/// Conv1d gradient check with padding 2 (kernel 3, stride 1, dilation 1).
#[test]
fn test_conv1d_with_padding() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 2,
        stride: 1,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [[24., 24., 24., 24.], [42., 42., 42., 42.]];
    let weight_per_channel = [[44., 44., 44.], [76., 76., 76.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([12., 12.], &device),
    };
    case.assert_grads(expected);
}
/// Conv1d gradient check with stride 2 (kernel 3, padding 1, dilation 1).
#[test]
fn test_conv1d_with_stride() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 2,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [[8., 16., 8., 10.], [14., 28., 14., 16.]];
    let weight_per_channel = [[10., 20., 24.], [18., 36., 40.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([4., 4.], &device),
    };
    case.assert_grads(expected);
}
/// Conv1d gradient check with dilation 2 (kernel 3, padding 1, stride 1).
#[test]
fn test_conv1d_dilation() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 2,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [[6., 8., 8., 10.], [12., 14., 14., 16.]];
    let weight_per_channel = [[8., 22., 14.], [16., 38., 22.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([4., 4.], &device),
    };
    case.assert_grads(expected);
}
/// Conv1d gradient check with 2 groups over 2 input and 2 output channels, so each weight
/// slice spans a single input channel.
#[test]
fn test_conv1d_groups() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 2,
        length: 4,
    };
    let device = Default::default();

    // The expected input gradient is the same for each batch entry.
    let x_per_batch = [[1., 3., 3., 3.], [7., 12., 12., 9.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device),
        bias: TestTensor::from_floats([8., 8.], &device),
    };
    case.assert_grads(expected);
}
/// Parameters of a single conv1d gradient test, consumed by `assert_grads`.
///
/// The input `x` is built with shape `[batch_size, channels_in, length]` and the weight with
/// shape `[channels_out, channels_in / groups, kernel_size]`.
struct Conv1dTestCase {
// First dimension of the input shape.
batch_size: usize,
// Second dimension of the input shape; divided by `groups` for the weight shape.
channels_in: usize,
// First dimension of the weight shape and length of the bias.
channels_out: usize,
// Last dimension of the weight shape.
kernel_size: usize,
// Passed to `ConvOptions::new` as the single-element padding array.
padding: usize,
// Passed to `ConvOptions::new` as the single-element stride array.
stride: usize,
// Passed to `ConvOptions::new` as the single-element dilation array.
dilation: usize,
// Passed to `ConvOptions::new` as the group count.
groups: usize,
// Last dimension of the input shape.
length: usize,
}
/// Expected gradients for one conv1d test case, compared against the actual gradients in
/// `Conv1dTestCase::assert_grads`.
struct Grads {
// Expected gradient of the input `x` (rank 3).
x: TestTensor<3>,
// Expected gradient of the convolution weight (rank 3).
weight: TestTensor<3>,
// Expected gradient of the bias (rank 1).
bias: TestTensor<1>,
}
impl Conv1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<3, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,962 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions};
/// Conv2d gradient check for the baseline configuration: 3x3 kernel, padding 1, stride 1,
/// dilation 1, one group, batch size 2, 2 input and 2 output channels, height and width 4.
#[test]
fn test_conv2d_basic() {
    let case = Conv2dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [
        [
            [88., 138., 138., 96.],
            [150., 234., 234., 162.],
            [150., 234., 234., 162.],
            [112., 174., 174., 120.],
        ],
        [
            [160., 246., 246., 168.],
            [258., 396., 396., 270.],
            [258., 396., 396., 270.],
            [184., 282., 282., 192.],
        ],
    ];
    let weight_per_channel = [
        [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
        [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([32., 32.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with differing channel counts: 2 input channels, 3 output channels.
#[test]
fn test_conv2d_different_channels() {
    let case = Conv2dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 3,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // The expected gradient is the same for each batch entry and for each output channel.
    let x_per_batch = [
        [
            [240., 369., 369., 252.],
            [387., 594., 594., 405.],
            [387., 594., 594., 405.],
            [276., 423., 423., 288.],
        ],
        [
            [348., 531., 531., 360.],
            [549., 837., 837., 567.],
            [549., 837., 837., 567.],
            [384., 585., 585., 396.],
        ],
    ];
    let weight_per_channel = [
        [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
        [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats(
            [weight_per_channel, weight_per_channel, weight_per_channel],
            &device,
        ),
        bias: TestTensor::from_floats([32., 32., 32.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with a non-square kernel: `kernel_size_1` 3 and `kernel_size_2` 4.
#[test]
fn test_conv2d_different_kernel_size() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 4,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [
            [27., 45., 54., 39.],
            [52., 84., 96., 68.],
            [51., 81., 90., 63.],
        ],
        [
            [123., 189., 198., 135.],
            [180., 276., 288., 196.],
            [147., 225., 234., 159.],
        ],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [116., 180., 192., 132.],
                    [198., 306., 324., 222.],
                    [198., 306., 324., 222.],
                    [148., 228., 240., 164.],
                ],
                [
                    [212., 324., 336., 228.],
                    [342., 522., 540., 366.],
                    [342., 522., 540., 366.],
                    [244., 372., 384., 260.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([12., 12.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with asymmetric padding: `padding_1` 1 and `padding_2` 2.
#[test]
fn test_conv2d_different_padding() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 2,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
        [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [138., 138., 138., 138.],
                    [234., 234., 234., 234.],
                    [234., 234., 234., 234.],
                    [174., 174., 174., 174.],
                ],
                [
                    [246., 246., 246., 246.],
                    [396., 396., 396., 396.],
                    [396., 396., 396., 396.],
                    [282., 282., 282., 282.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([24., 24.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check on a non-square input: height 4, width 5.
#[test]
fn test_conv2d_different_width() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 5,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
        [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [88., 138., 138., 138., 96.],
                    [150., 234., 234., 234., 162.],
                    [150., 234., 234., 234., 162.],
                    [112., 174., 174., 174., 120.],
                ],
                [
                    [160., 246., 246., 246., 168.],
                    [258., 396., 396., 396., 270.],
                    [258., 396., 396., 396., 270.],
                    [184., 282., 282., 282., 192.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([20., 20.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with stride 2 in both directions on a 6x6 input.
#[test]
fn test_conv2d_stride_2() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 6,
        width: 6,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
        [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [26., 52., 26., 52., 26., 28.],
                    [52., 104., 52., 104., 52., 56.],
                    [26., 52., 26., 52., 26., 28.],
                    [52., 104., 52., 104., 52., 56.],
                    [26., 52., 26., 52., 26., 28.],
                    [32., 64., 32., 64., 32., 34.],
                ],
                [
                    [44., 88., 44., 88., 44., 46.],
                    [88., 176., 88., 176., 88., 92.],
                    [44., 88., 44., 88., 44., 46.],
                    [88., 176., 88., 176., 88., 92.],
                    [44., 88., 44., 88., 44., 46.],
                    [50., 100., 50., 100., 50., 52.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([9., 9.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with asymmetric stride (`stride_1` 3, `stride_2` 1) on an 8x8 input.
#[test]
fn test_conv2d_different_stride() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 3,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 8,
        width: 8,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
        [
            [1330., 1528., 1344.],
            [1911., 2196., 1932.],
            [2079., 2388., 2100.],
        ],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [50., 78., 78., 78., 78., 78., 78., 54.],
                    [62., 96., 96., 96., 96., 96., 96., 66.],
                    [38., 60., 60., 60., 60., 60., 60., 42.],
                    [50., 78., 78., 78., 78., 78., 78., 54.],
                    [62., 96., 96., 96., 96., 96., 96., 66.],
                    [38., 60., 60., 60., 60., 60., 60., 42.],
                    [50., 78., 78., 78., 78., 78., 78., 54.],
                    [62., 96., 96., 96., 96., 96., 96., 66.],
                ],
                [
                    [86., 132., 132., 132., 132., 132., 132., 90.],
                    [98., 150., 150., 150., 150., 150., 150., 102.],
                    [74., 114., 114., 114., 114., 114., 114., 78.],
                    [86., 132., 132., 132., 132., 132., 132., 90.],
                    [98., 150., 150., 150., 150., 150., 150., 102.],
                    [74., 114., 114., 114., 114., 114., 114., 78.],
                    [86., 132., 132., 132., 132., 132., 132., 90.],
                    [98., 150., 150., 150., 150., 150., 150., 102.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([24., 24.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with dilation 2 in both directions on a 6x6 input.
#[test]
fn test_conv2d_dilation_2() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 2,
        dilation_2: 2,
        groups: 1,
        height: 6,
        width: 6,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
        [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [18., 38., 38., 42., 42., 22.],
                    [42., 88., 88., 96., 96., 50.],
                    [42., 88., 88., 96., 96., 50.],
                    [54., 112., 112., 120., 120., 62.],
                    [54., 112., 112., 120., 120., 62.],
                    [30., 62., 62., 66., 66., 34.],
                ],
                [
                    [36., 74., 74., 78., 78., 40.],
                    [78., 160., 160., 168., 168., 86.],
                    [78., 160., 160., 168., 168., 86.],
                    [90., 184., 184., 192., 192., 98.],
                    [90., 184., 184., 192., 192., 98.],
                    [48., 98., 98., 102., 102., 52.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([16., 16.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with asymmetric dilation: `dilation_1` 2 and `dilation_2` 3.
#[test]
fn test_conv2d_different_dilation() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 2,
        dilation_2: 3,
        groups: 1,
        height: 6,
        width: 6,
    };
    let device = Default::default();

    // The expected weight gradient is the same for each output channel.
    let weight_per_channel = [
        [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
        [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [18., 0., 20., 20., 0., 22.],
                    [42., 0., 46., 46., 0., 50.],
                    [42., 0., 46., 46., 0., 50.],
                    [54., 0., 58., 58., 0., 62.],
                    [54., 0., 58., 58., 0., 62.],
                    [30., 0., 32., 32., 0., 34.],
                ],
                [
                    [36., 0., 38., 38., 0., 40.],
                    [78., 0., 82., 82., 0., 86.],
                    [78., 0., 82., 82., 0., 86.],
                    [90., 0., 94., 94., 0., 98.],
                    [90., 0., 94., 94., 0., 98.],
                    [48., 0., 50., 50., 0., 52.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats([weight_per_channel, weight_per_channel], &device),
        bias: TestTensor::from_floats([8., 8.], &device),
    };
    case.assert_grads(expected);
}
/// Conv2d gradient check with 2 groups over 2 input and 2 output channels, no padding, on a
/// 5x5 input.
#[test]
fn test_conv2d_groups() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 2,
        height: 5,
        width: 5,
    };
    let device = Default::default();

    let x_expected = TestTensor::from_floats(
        [[
            [
                [0., 1., 3., 3., 2.],
                [3., 8., 15., 12., 7.],
                [9., 21., 36., 27., 15.],
                [9., 20., 33., 24., 13.],
                [6., 13., 21., 15., 8.],
            ],
            [
                [9., 19., 30., 21., 11.],
                [21., 44., 69., 48., 25.],
                [36., 75., 117., 81., 42.],
                [27., 56., 87., 60., 31.],
                [15., 31., 48., 33., 17.],
            ],
        ]],
        &device,
    );
    let weight_expected = TestTensor::from_floats(
        [
            [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],
            [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],
        ],
        &device,
    );
    let bias_expected = TestTensor::from_floats([9., 9.], &device);

    case.assert_grads(Grads {
        x: x_expected,
        weight: weight_expected,
        bias: bias_expected,
    });
}
/// Conv2d gradient check with 4 groups over 4 input and 4 output channels and stride 2 in both
/// directions.
#[test]
fn test_conv2d_groups_stride_2() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 4,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 4,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    let x_expected = TestTensor::from_floats(
        [[
            [
                [4., 8., 4., 5.],
                [8., 16., 8., 10.],
                [4., 8., 4., 5.],
                [7., 14., 7., 8.],
            ],
            [
                [13., 26., 13., 14.],
                [26., 52., 26., 28.],
                [13., 26., 13., 14.],
                [16., 32., 16., 17.],
            ],
            [
                [22., 44., 22., 23.],
                [44., 88., 44., 46.],
                [22., 44., 22., 23.],
                [25., 50., 25., 26.],
            ],
            [
                [31., 62., 31., 32.],
                [62., 124., 62., 64.],
                [31., 62., 31., 32.],
                [34., 68., 34., 35.],
            ],
        ]],
        &device,
    );
    let weight_expected = TestTensor::from_floats(
        [
            [[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
            [[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
            [[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
            [[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
        ],
        &device,
    );
    let bias_expected = TestTensor::from_floats([4., 4., 4., 4.], &device);

    case.assert_grads(Grads {
        x: x_expected,
        weight: weight_expected,
        bias: bias_expected,
    });
}
/// Conv2d gradient check with 3 groups mapping 3 input channels to 6 output channels, no
/// padding, on a 4x4 input.
#[test]
fn test_conv2d_groups_different_channels() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 3,
        channels_out: 6,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 3,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // Consecutive pairs of output channels share the same expected weight gradient.
    let weight_pair_0 = [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]];
    let weight_pair_1 = [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]];
    let weight_pair_2 = [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [9., 20., 24., 13.],
                    [24., 52., 60., 32.],
                    [36., 76., 84., 44.],
                    [21., 44., 48., 25.],
                ],
                [
                    [45., 92., 96., 49.],
                    [96., 196., 204., 104.],
                    [108., 220., 228., 116.],
                    [57., 116., 120., 61.],
                ],
                [
                    [81., 164., 168., 85.],
                    [168., 340., 348., 176.],
                    [180., 364., 372., 188.],
                    [93., 188., 192., 97.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats(
            [
                weight_pair_0,
                weight_pair_0,
                weight_pair_1,
                weight_pair_1,
                weight_pair_2,
                weight_pair_2,
            ],
            &device,
        ),
        bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device),
    };
    case.assert_grads(expected);
}
/// Gradient check for a 2D convolution where every option differs per axis:
/// 2 -> 3 channels, 2x3 kernel, padding (1, 2), stride (1, 2),
/// dilation (2, 3), a single group and a 4x5 input.
#[test]
fn test_conv2d_complex() {
    let device = Default::default();

    // All three output channels share the same expected weight gradient.
    let weight_per_out_channel = [
        [[15., 42., 27.], [30., 72., 42.]],
        [[75., 162., 87.], [90., 192., 102.]],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [36., 39., 0., 39., 42.],
                    [81., 87., 0., 87., 93.],
                    [81., 87., 0., 87., 93.],
                    [45., 48., 0., 48., 51.],
                ],
                [
                    [54., 57., 0., 57., 60.],
                    [117., 123., 0., 123., 129.],
                    [117., 123., 0., 123., 129.],
                    [63., 66., 0., 66., 69.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats(
            [
                weight_per_out_channel,
                weight_per_out_channel,
                weight_per_out_channel,
            ],
            &device,
        ),
        bias: TestTensor::from_floats([8., 8., 8.], &device),
    };

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 3,
        kernel_size_1: 2,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 2,
        stride_1: 1,
        stride_2: 2,
        dilation_1: 2,
        dilation_2: 3,
        groups: 1,
        height: 4,
        width: 5,
    };
    case.assert_grads(expected);
}
/// Gradient check for a grouped 2D convolution without padding: 4 -> 2
/// channels, `groups = 2`, 3x3 kernel and stride 2 on a 4x4 input.
#[test]
fn test_conv2d_groups_stride_2_no_pad() {
    let device = Default::default();

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [0., 1., 2., 0.],
                    [3., 4., 5., 0.],
                    [6., 7., 8., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [9., 10., 11., 0.],
                    [12., 13., 14., 0.],
                    [15., 16., 17., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [18., 19., 20., 0.],
                    [21., 22., 23., 0.],
                    [24., 25., 26., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [27., 28., 29., 0.],
                    [30., 31., 32., 0.],
                    [33., 34., 35., 0.],
                    [0., 0., 0., 0.],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats(
            [
                [
                    [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
                    [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
                ],
                [
                    [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
                    [[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]],
                ],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([1., 1.], &device),
    };

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 2,
        height: 4,
        width: 4,
    };
    case.assert_grads(expected);
}
/// Parameters of one conv2d gradient test; consumed by `assert_grads`.
struct Conv2dTestCase {
/// First (batch) dimension of the input tensor.
batch_size: usize,
/// Second dimension of the input tensor; divided by `groups` for the weight shape.
channels_in: usize,
/// First dimension of the weight tensor and length of the bias.
channels_out: usize,
/// Third dimension of the weight tensor.
kernel_size_1: usize,
/// Fourth dimension of the weight tensor.
kernel_size_2: usize,
/// Padding passed to `ConvOptions::new` for the first spatial axis.
padding_1: usize,
/// Padding passed to `ConvOptions::new` for the second spatial axis.
padding_2: usize,
/// Stride passed to `ConvOptions::new` for the first spatial axis.
stride_1: usize,
/// Stride passed to `ConvOptions::new` for the second spatial axis.
stride_2: usize,
/// Dilation passed to `ConvOptions::new` for the first spatial axis.
dilation_1: usize,
/// Dilation passed to `ConvOptions::new` for the second spatial axis.
dilation_2: usize,
/// Group count passed to `ConvOptions::new`.
groups: usize,
/// Third dimension of the input tensor.
height: usize,
/// Fourth dimension of the input tensor.
width: usize,
}
/// Expected gradients that `Conv2dTestCase::assert_grads` compares against.
struct Grads {
/// Expected gradient of the rank-4 input tensor.
x: TestTensor<4>,
/// Expected gradient of the rank-4 weight tensor.
weight: TestTensor<4>,
/// Expected gradient of the rank-1 bias tensor.
bias: TestTensor<1>,
}
impl Conv2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::rel_abs(0.01, 0.01);
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,690 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
/// Gradient check for a plain 3D convolution: batch of 2, 2 -> 2 channels,
/// 3x3x3 kernel, padding 1, unit stride and dilation, one group, 4x4x4 volumes.
#[test]
fn test_conv3d_basic() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient, so a single
    // sample (channel, depth, height, width) is written out and used twice.
    let x_sample = [
        [
            [
                [536., 816., 816., 552.],
                [840., 1278., 1278., 864.],
                [840., 1278., 1278., 864.],
                [584., 888., 888., 600.],
            ],
            [
                [912., 1386., 1386., 936.],
                [1422., 2160., 2160., 1458.],
                [1422., 2160., 2160., 1458.],
                [984., 1494., 1494., 1008.],
            ],
            [
                [912., 1386., 1386., 936.],
                [1422., 2160., 2160., 1458.],
                [1422., 2160., 2160., 1458.],
                [984., 1494., 1494., 1008.],
            ],
            [
                [680., 1032., 1032., 696.],
                [1056., 1602., 1602., 1080.],
                [1056., 1602., 1602., 1080.],
                [728., 1104., 1104., 744.],
            ],
        ],
        [
            [
                [968., 1464., 1464., 984.],
                [1488., 2250., 2250., 1512.],
                [1488., 2250., 2250., 1512.],
                [1016., 1536., 1536., 1032.],
            ],
            [
                [1560., 2358., 2358., 1584.],
                [2394., 3618., 3618., 2430.],
                [2394., 3618., 3618., 2430.],
                [1632., 2466., 2466., 1656.],
            ],
            [
                [1560., 2358., 2358., 1584.],
                [2394., 3618., 3618., 2430.],
                [2394., 3618., 3618., 2430.],
                [1632., 2466., 2466., 1656.],
            ],
            [
                [1112., 1680., 1680., 1128.],
                [1704., 2574., 2574., 1728.],
                [1704., 2574., 2574., 1728.],
                [1160., 1752., 1752., 1176.],
            ],
        ],
    ];

    // Likewise, both output channels share one expected weight gradient.
    let weight_per_out_channel = [
        [
            [
                [4590., 6156., 4644.],
                [6264., 8400., 6336.],
                [4806., 6444., 4860.],
            ],
            [
                [6696., 8976., 6768.],
                [9120., 12224., 9216.],
                [6984., 9360., 7056.],
            ],
            [
                [5454., 7308., 5508.],
                [7416., 9936., 7488.],
                [5670., 7596., 5724.],
            ],
        ],
        [
            [
                [8046., 10764., 8100.],
                [10872., 14544., 10944.],
                [8262., 11052., 8316.],
            ],
            [
                [11304., 15120., 11376.],
                [15264., 20416., 15360.],
                [11592., 15504., 11664.],
            ],
            [
                [8910., 11916., 8964.],
                [12024., 16080., 12096.],
                [9126., 12204., 9180.],
            ],
        ],
    ];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [weight_per_out_channel, weight_per_out_channel],
            &device,
        ),
        bias: TestTensor::from_floats([128., 128.], &device),
    };

    let case = Conv3dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        kernel_size_3: 3,
        padding_1: 1,
        padding_2: 1,
        padding_3: 1,
        stride_1: 1,
        stride_2: 1,
        stride_3: 1,
        dilation_1: 1,
        dilation_2: 1,
        dilation_3: 1,
        groups: 1,
        depth: 4,
        height: 4,
        width: 4,
    };
    case.assert_grads(expected);
}
/// Gradient check for a 3D convolution where every option differs per axis:
/// 2 -> 3 channels, 2x3x4 kernel, padding (1, 2, 3), stride (1, 2, 3),
/// dilation (2, 3, 4), a single group and a 5x6x7 input volume.
#[test]
fn test_conv3d_complex() {
    let device = Default::default();

    // Within each input channel the three inner depth slices of the expected
    // input gradient are identical, so each is written once.
    let channel_0_inner = [
        [0., 330., 0., 0., 0., 336., 0.],
        [0., 354., 0., 0., 0., 360., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 354., 0., 0., 0., 360., 0.],
        [0., 378., 0., 0., 0., 384., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
    ];
    let channel_1_inner = [
        [0., 474., 0., 0., 0., 480., 0.],
        [0., 498., 0., 0., 0., 504., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 498., 0., 0., 0., 504., 0.],
        [0., 522., 0., 0., 0., 528., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
    ];

    // All three output channels share the same expected weight gradient.
    let weight_per_out_channel = [
        [
            [
                [0., 256., 272., 0.],
                [0., 624., 656., 0.],
                [0., 368., 384., 0.],
            ],
            [
                [0., 424., 440., 0.],
                [0., 960., 992., 0.],
                [0., 536., 552., 0.],
            ],
        ],
        [
            [
                [0., 1096., 1112., 0.],
                [0., 2304., 2336., 0.],
                [0., 1208., 1224., 0.],
            ],
            [
                [0., 1264., 1280., 0.],
                [0., 2640., 2672., 0.],
                [0., 1376., 1392., 0.],
            ],
        ],
    ];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [
                        [0., 147., 0., 0., 0., 150., 0.],
                        [0., 159., 0., 0., 0., 162., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                        [0., 159., 0., 0., 0., 162., 0.],
                        [0., 171., 0., 0., 0., 174., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                    ],
                    channel_0_inner,
                    channel_0_inner,
                    channel_0_inner,
                    [
                        [0., 183., 0., 0., 0., 186., 0.],
                        [0., 195., 0., 0., 0., 198., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                        [0., 195., 0., 0., 0., 198., 0.],
                        [0., 207., 0., 0., 0., 210., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                    ],
                ],
                [
                    [
                        [0., 219., 0., 0., 0., 222., 0.],
                        [0., 231., 0., 0., 0., 234., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                        [0., 231., 0., 0., 0., 234., 0.],
                        [0., 243., 0., 0., 0., 246., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                    ],
                    channel_1_inner,
                    channel_1_inner,
                    channel_1_inner,
                    [
                        [0., 255., 0., 0., 0., 258., 0.],
                        [0., 267., 0., 0., 0., 270., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                        [0., 267., 0., 0., 0., 270., 0.],
                        [0., 279., 0., 0., 0., 282., 0.],
                        [0., 0., 0., 0., 0., 0., 0.],
                    ],
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats(
            [
                weight_per_out_channel,
                weight_per_out_channel,
                weight_per_out_channel,
            ],
            &device,
        ),
        bias: TestTensor::from_floats([10., 10., 10.], &device),
    };

    let case = Conv3dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 3,
        kernel_size_1: 2,
        kernel_size_2: 3,
        kernel_size_3: 4,
        padding_1: 1,
        padding_2: 2,
        padding_3: 3,
        stride_1: 1,
        stride_2: 2,
        stride_3: 3,
        dilation_1: 2,
        dilation_2: 3,
        dilation_3: 4,
        groups: 1,
        depth: 5,
        height: 6,
        width: 7,
    };
    case.assert_grads(expected);
}
/// Gradient check for a grouped 3D convolution without padding: 4 -> 2
/// channels, `groups = 2`, 3x3x3 kernel and stride 2 on a 4x4x4 volume.
#[test]
fn test_conv3d_groups_stride_2_no_pad() {
    let device = Default::default();

    // The last depth slice of every channel has an all-zero expected gradient.
    let zero_slice = [[0., 0., 0., 0.]; 4];

    let expected = Grads {
        x: TestTensor::from_floats(
            [[
                [
                    [
                        [0., 1., 2., 0.],
                        [3., 4., 5., 0.],
                        [6., 7., 8., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [9., 10., 11., 0.],
                        [12., 13., 14., 0.],
                        [15., 16., 17., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [18., 19., 20., 0.],
                        [21., 22., 23., 0.],
                        [24., 25., 26., 0.],
                        [0., 0., 0., 0.],
                    ],
                    zero_slice,
                ],
                [
                    [
                        [27., 28., 29., 0.],
                        [30., 31., 32., 0.],
                        [33., 34., 35., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [36., 37., 38., 0.],
                        [39., 40., 41., 0.],
                        [42., 43., 44., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [45., 46., 47., 0.],
                        [48., 49., 50., 0.],
                        [51., 52., 53., 0.],
                        [0., 0., 0., 0.],
                    ],
                    zero_slice,
                ],
                [
                    [
                        [54., 55., 56., 0.],
                        [57., 58., 59., 0.],
                        [60., 61., 62., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [63., 64., 65., 0.],
                        [66., 67., 68., 0.],
                        [69., 70., 71., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [72., 73., 74., 0.],
                        [75., 76., 77., 0.],
                        [78., 79., 80., 0.],
                        [0., 0., 0., 0.],
                    ],
                    zero_slice,
                ],
                [
                    [
                        [81., 82., 83., 0.],
                        [84., 85., 86., 0.],
                        [87., 88., 89., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [90., 91., 92., 0.],
                        [93., 94., 95., 0.],
                        [96., 97., 98., 0.],
                        [0., 0., 0., 0.],
                    ],
                    [
                        [99., 100., 101., 0.],
                        [102., 103., 104., 0.],
                        [105., 106., 107., 0.],
                        [0., 0., 0., 0.],
                    ],
                    zero_slice,
                ],
            ]],
            &device,
        ),
        weight: TestTensor::from_floats(
            [
                [
                    [
                        [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
                        [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
                        [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
                    ],
                    [
                        [[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
                        [[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
                        [[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
                    ],
                ],
                [
                    [
                        [[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
                        [[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
                        [[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
                    ],
                    [
                        [[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
                        [[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
                        [[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
                    ],
                ],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([1., 1.], &device),
    };

    let case = Conv3dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        kernel_size_3: 3,
        padding_1: 0,
        padding_2: 0,
        padding_3: 0,
        stride_1: 2,
        stride_2: 2,
        stride_3: 2,
        dilation_1: 1,
        dilation_2: 1,
        dilation_3: 1,
        groups: 2,
        depth: 4,
        height: 4,
        width: 4,
    };
    case.assert_grads(expected);
}
/// Parameters of one conv3d gradient test; consumed by `assert_grads`.
struct Conv3dTestCase {
/// First (batch) dimension of the input tensor.
batch_size: usize,
/// Second dimension of the input tensor; divided by `groups` for the weight shape.
channels_in: usize,
/// First dimension of the weight tensor and length of the bias.
channels_out: usize,
/// Third dimension of the weight tensor.
kernel_size_1: usize,
/// Fourth dimension of the weight tensor.
kernel_size_2: usize,
/// Fifth dimension of the weight tensor.
kernel_size_3: usize,
/// Padding passed to `ConvOptions::new` for the first spatial axis.
padding_1: usize,
/// Padding passed to `ConvOptions::new` for the second spatial axis.
padding_2: usize,
/// Padding passed to `ConvOptions::new` for the third spatial axis.
padding_3: usize,
/// Stride passed to `ConvOptions::new` for the first spatial axis.
stride_1: usize,
/// Stride passed to `ConvOptions::new` for the second spatial axis.
stride_2: usize,
/// Stride passed to `ConvOptions::new` for the third spatial axis.
stride_3: usize,
/// Dilation passed to `ConvOptions::new` for the first spatial axis.
dilation_1: usize,
/// Dilation passed to `ConvOptions::new` for the second spatial axis.
dilation_2: usize,
/// Dilation passed to `ConvOptions::new` for the third spatial axis.
dilation_3: usize,
/// Group count passed to `ConvOptions::new`.
groups: usize,
/// Third dimension of the input tensor.
depth: usize,
/// Fourth dimension of the input tensor.
height: usize,
/// Fifth dimension of the input tensor.
width: usize,
}
/// Expected gradients that `Conv3dTestCase::assert_grads` compares against.
struct Grads {
/// Expected gradient of the rank-5 input tensor.
x: TestTensor<5>,
/// Expected gradient of the rank-5 weight tensor.
weight: TestTensor<5>,
/// Expected gradient of the rank-1 bias tensor.
bias: TestTensor<1>,
}
impl Conv3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels_in,
self.depth,
self.height,
self.width,
]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
self.kernel_size_3,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2, self.stride_3],
[self.padding_1, self.padding_2, self.padding_3],
[self.dilation_1, self.dilation_2, self.dilation_3],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,292 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions};
/// Gradient check for a basic transposed 1D convolution: batch of 2,
/// `channels: [2, 2]`, kernel 3, no padding, unit stride and dilation,
/// one group, input length 4.
#[test]
fn test_conv_transpose1d_basic() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient.
    let x_sample = [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]],
                [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([12., 12.], &device),
    };

    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 0,
        stride: 1,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    case.assert_grads(expected);
}
/// Gradient check for a transposed 1D convolution with `padding: 2`; the
/// remaining options match the basic case (kernel 3, stride 1, dilation 1,
/// one group, input length 4).
#[test]
fn test_conv_transpose1d_padding() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient.
    let x_sample = [[7., 12., 8., 3.], [19., 36., 32., 15.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [[26., 22., 18.], [26., 22., 18.]],
                [[42., 38., 34.], [42., 38., 34.]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([4., 4.], &device),
    };

    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 2,
        padding_out: 0,
        stride: 1,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    case.assert_grads(expected);
}
/// Gradient check for a transposed 1D convolution with `stride: 2`; the
/// remaining options match the basic case (kernel 3, no padding, dilation 1,
/// one group, input length 4).
#[test]
fn test_conv_transpose1d_stride() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient.
    let x_sample = [[15., 15., 15., 15.], [51., 51., 51., 51.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [[44., 44., 44.], [44., 44., 44.]],
                [[76., 76., 76.], [76., 76., 76.]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([18., 18.], &device),
    };

    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 0,
        stride: 2,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    case.assert_grads(expected);
}
/// Gradient check for a transposed 1D convolution combining `stride: 2` with
/// `padding_out: 1` (kernel 3, no input padding, dilation 1, one group,
/// input length 4).
#[test]
fn test_conv_transpose1d_stride_padding_out() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient.
    let x_sample = [[15., 15., 15., 15.], [51., 51., 51., 51.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [[44., 44., 44.], [44., 44., 44.]],
                [[76., 76., 76.], [76., 76., 76.]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([20., 20.], &device),
    };

    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 1,
        stride: 2,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    case.assert_grads(expected);
}
/// Gradient check for a transposed 1D convolution with `dilation: 2`; the
/// remaining options match the basic case (kernel 3, no padding, stride 1,
/// one group, input length 4).
#[test]
fn test_conv_transpose1d_dilation() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient.
    let x_sample = [[15., 15., 15., 15.], [51., 51., 51., 51.]];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [[44., 44., 44.], [44., 44., 44.]],
                [[76., 76., 76.], [76., 76., 76.]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([16., 16.], &device),
    };

    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 0,
        stride: 1,
        dilation: 2,
        groups: 1,
        size: 4,
    };
    case.assert_grads(expected);
}
/// Gradient check for a transposed 1D convolution with every option set at
/// once: `channels: [2, 4]`, `groups: 2`, kernel 3, padding 1, output
/// padding 1, stride 2, dilation 2 and input length 8.
#[test]
fn test_conv_transpose1d_complex() {
    let device = Default::default();

    // Both batch entries have the same expected input gradient.
    let x_sample = [
        [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
        [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
    ];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]],
                [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0], &device),
    };

    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 4],
        kernel_size: 3,
        padding: 1,
        padding_out: 1,
        stride: 2,
        dilation: 2,
        groups: 2,
        size: 8,
    };
    case.assert_grads(expected);
}
/// Parameters of one conv_transpose1d gradient test; consumed by `assert_grads`.
struct ConvTranspose1dTestCase {
/// First (batch) dimension of the input tensor.
batch_size: usize,
/// `channels[0]` is the second input dimension and first weight dimension;
/// `channels[1]` is the bias length and, divided by `groups`, the second weight dimension.
channels: [usize; 2],
/// Last dimension of the weight tensor.
kernel_size: usize,
/// Second argument passed to `ConvTransposeOptions::new`.
padding: usize,
/// Third argument passed to `ConvTransposeOptions::new`.
padding_out: usize,
/// First argument passed to `ConvTransposeOptions::new`.
stride: usize,
/// Fourth argument passed to `ConvTransposeOptions::new`.
dilation: usize,
/// Group count passed to `ConvTransposeOptions::new`.
groups: usize,
/// Last dimension of the input tensor.
size: usize,
}
/// Expected gradients that `ConvTranspose1dTestCase::assert_grads` compares against.
struct Grads {
/// Expected gradient of the rank-3 input tensor.
x: TestTensor<3>,
/// Expected gradient of the rank-3 weight tensor.
weight: TestTensor<3>,
/// Expected gradient of the rank-1 bias tensor.
bias: TestTensor<1>,
}
impl ConvTranspose1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<3, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv_transpose1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
[self.stride],
[self.padding],
[self.padding_out],
[self.dilation],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), Tolerance::default());
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,706 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions};
/// Gradient check for the baseline `ConvTranspose2dTestCase`: batch of 2,
/// `channels: [2, 2]`, 3x3 kernel, zero padding and output padding, unit
/// stride and dilation, one group, `size: [4, 4]`.
#[test]
fn test_conv_transpose2d_basic() {
    let device = Default::default();

    // Every expected gradient plane is constant, and both batch entries are
    // identical, so the values are expressed with array-repeat expressions.
    let x_sample = [[[153.; 4]; 4], [[477.; 4]; 4]];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [[[[752.; 3]; 3]; 2], [[[1264.; 3]; 3]; 2]],
            &device,
        ),
        bias: TestTensor::from_floats([72., 72.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a single-channel `ConvTranspose2dTestCase` with
/// `padding: [1, 2]`; stride and dilation are 1 and there is no output padding.
#[test]
fn test_conv_transpose2d_padding() {
    let device = Default::default();

    let expected = Grads {
        x: TestTensor::from_floats(
            [[[
                [13., 24., 20., 9.],
                [15., 27., 21., 9.],
                [15., 27., 21., 9.],
                [7., 12., 8., 3.],
            ]]],
            &device,
        ),
        weight: TestTensor::from_floats(
            [[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]],
            &device,
        ),
        bias: TestTensor::from_floats([8.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [1, 2],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a single-channel `ConvTranspose2dTestCase` with
/// `stride: [2, 3]`; padding and output padding are 0 and dilation is 1.
#[test]
fn test_conv_transpose2d_stride() {
    let device = Default::default();

    // All expected gradient planes are constant.
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4]]], &device),
        weight: TestTensor::from_floats([[[[120.; 3]; 3]]], &device),
        bias: TestTensor::from_floats([108.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [2, 3],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a single-channel `ConvTranspose2dTestCase` combining
/// `stride: [2, 3]` with `padding_out: [1, 2]`; input padding is 0 and
/// dilation is 1.
#[test]
fn test_conv_transpose2d_stride_padding_out() {
    let device = Default::default();

    // All expected gradient planes are constant.
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4]]], &device),
        weight: TestTensor::from_floats([[[[120.; 3]; 3]]], &device),
        bias: TestTensor::from_floats([140.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a single-channel `ConvTranspose2dTestCase` with
/// `dilation: [2, 3]`; padding and output padding are 0 and stride is 1.
#[test]
fn test_conv_transpose2d_dilation() {
    let device = Default::default();

    // All expected gradient planes are constant.
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4]]], &device),
        weight: TestTensor::from_floats([[[[120.; 3]; 3]]], &device),
        bias: TestTensor::from_floats([80.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [2, 3],
        groups: 1,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a `ConvTranspose2dTestCase` with `channels: [2, 3]`;
/// 3x3 kernel, zero padding, unit stride and dilation, one group.
#[test]
fn test_conv_transpose2d_channels() {
    let device = Default::default();

    // All expected gradient planes are constant, so they are expressed with
    // array-repeat expressions.
    let expected = Grads {
        x: TestTensor::from_floats([[[[351.; 4]; 4], [[1080.; 4]; 4]]], &device),
        weight: TestTensor::from_floats(
            [[[[120.; 3]; 3]; 3], [[[376.; 3]; 3]; 3]],
            &device,
        ),
        bias: TestTensor::from_floats([36., 36., 36.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [2, 3],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a single-channel `ConvTranspose2dTestCase` with a
/// non-square `kernel_size: [3, 5]` and `size: [6, 6]`; zero padding, unit
/// stride and dilation.
#[test]
fn test_conv_transpose2d_kernel_size() {
    let device = Default::default();

    // All expected gradient planes are constant.
    let expected = Grads {
        x: TestTensor::from_floats([[[[105.; 6]; 6]]], &device),
        weight: TestTensor::from_floats([[[[630.; 5]; 3]]], &device),
        bias: TestTensor::from_floats([80.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 5],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [6, 6],
    };
    case.assert_grads(expected);
}
/// Gradient check for a `ConvTranspose2dTestCase` with `channels: [2, 2]` and
/// `groups: 2`; 3x3 kernel, zero padding, unit stride and dilation.
#[test]
fn test_conv_transpose2d_groups() {
    let device = Default::default();

    // All expected gradient planes are constant.
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4], [[117.; 4]; 4]]], &device),
        weight: TestTensor::from_floats(
            [[[[120.; 3]; 3]], [[[376.; 3]; 3]]],
            &device,
        ),
        bias: TestTensor::from_floats([36., 36.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [2, 2],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 2,
        size: [4, 4],
    };
    case.assert_grads(expected);
}
/// Gradient check for a `ConvTranspose2dTestCase` with every option except
/// grouping set at once: batch of 2, `channels: [2, 3]`, 3x5 kernel,
/// padding and output padding `[1, 2]`, stride and dilation `[2, 3]`,
/// `size: [6, 8]`.
#[test]
fn test_conv_transpose2d_complex_no_groups() {
    let device = Default::default();

    // Expected input gradient: in each channel only the first row differs from
    // the remaining five, and both batch entries are identical.
    let x_channel_0_first = [600., 735., 735., 735., 735., 735., 735., 735.];
    let x_channel_0_rest = [810., 990., 990., 990., 990., 990., 990., 990.];
    let x_channel_1_first = [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.];
    let x_channel_1_rest = [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.];
    let x_sample = [
        [
            x_channel_0_first,
            x_channel_0_rest,
            x_channel_0_rest,
            x_channel_0_rest,
            x_channel_0_rest,
            x_channel_0_rest,
        ],
        [
            x_channel_1_first,
            x_channel_1_rest,
            x_channel_1_rest,
            x_channel_1_rest,
            x_channel_1_rest,
            x_channel_1_rest,
        ],
    ];

    // Expected weight gradient: for each entry along the first weight axis the
    // same 3x5 kernel gradient is repeated three times along the second axis.
    let weight_kernel_0 = [
        [5320., 6040., 6040., 6040., 6040.],
        [6048., 6864., 6864., 6864., 6864.],
        [6048., 6864., 6864., 6864., 6864.],
    ];
    let weight_kernel_1 = [
        [8680., 9880., 9880., 9880., 9880.],
        [10080., 11472., 11472., 11472., 11472.],
        [10080., 11472., 11472., 11472., 11472.],
    ];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [
                [weight_kernel_0, weight_kernel_0, weight_kernel_0],
                [weight_kernel_1, weight_kernel_1, weight_kernel_1],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([896., 896., 896.], &device),
    };

    let case = ConvTranspose2dTestCase {
        batch_size: 2,
        channels: [2, 3],
        kernel_size: [3, 5],
        padding: [1, 2],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [2, 3],
        groups: 1,
        size: [6, 8],
    };
    case.assert_grads(expected);
}
#[test]
fn test_conv_transpose2d_complex_no_groups_2() {
    // Stacks one leading row on top of nine copies of a second row (10 rows total).
    fn plane<T: Copy>(lead: T, rest: T) -> [T; 10] {
        [lead, rest, rest, rest, rest, rest, rest, rest, rest, rest]
    }

    let device = Default::default();

    // Expected gradient for the input, laid out as [1, 4, 10, 10].
    // Within each channel only the first row differs from the other nine.
    let x = TestTensor::from_floats(
        [[
            plane(
                [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.],
                [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
            ),
            plane(
                [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.],
                [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
            ),
            plane(
                [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.],
                [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
            ),
            plane(
                [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.],
                [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
            ),
        ]],
        &device,
    );

    // Expected gradient for the weight, laid out as [4, 2, 2, 3].
    // Both entries along the second axis share the same 2x3 kernel.
    let kernel_0 = [[4455., 4905., 4905.], [4500., 4950., 4950.]];
    let kernel_1 = [[12555., 13905., 13905.], [13500., 14950., 14950.]];
    let kernel_2 = [[20655., 22905., 22905.], [22500., 24950., 24950.]];
    let kernel_3 = [[28755., 31905., 31905.], [31500., 34950., 34950.]];
    let weight = TestTensor::from_floats(
        [
            [kernel_0, kernel_0],
            [kernel_1, kernel_1],
            [kernel_2, kernel_2],
            [kernel_3, kernel_3],
        ],
        &device,
    );

    let bias = TestTensor::from_floats([570., 570.], &device);

    // Asymmetric kernel/stride/padding/dilation with a single group.
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [4, 2],
        kernel_size: [2, 3],
        padding: [1, 2],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [1, 2],
        groups: 1,
        size: [10, 10],
    };
    case.assert_grads(Grads { x, weight, bias });
}
#[test]
fn test_conv_transpose2d_complex_groups() {
    // Stacks one leading row on top of nine copies of a second row (10 rows total).
    fn plane<T: Copy>(lead: T, rest: T) -> [T; 10] {
        [lead, rest, rest, rest, rest, rest, rest, rest, rest, rest]
    }

    let device = Default::default();

    // Expected gradient for the input, laid out as [1, 4, 10, 10].
    // Within each channel only the first row differs from the other nine.
    let x = TestTensor::from_floats(
        [[
            plane(
                [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.],
                [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
            ),
            plane(
                [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.],
                [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
            ),
            plane(
                [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.],
                [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
            ),
            plane(
                [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
                [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
            ),
        ]],
        &device,
    );

    // Expected gradient for the weight, laid out as [4, 1, 2, 3]:
    // with two groups each input channel owns a single 2x3 kernel.
    let weight = TestTensor::from_floats(
        [
            [[[4455., 4905., 4905.], [4500., 4950., 4950.]]],
            [[[12555., 13905., 13905.], [13500., 14950., 14950.]]],
            [[[20655., 22905., 22905.], [22500., 24950., 24950.]]],
            [[[28755., 31905., 31905.], [31500., 34950., 34950.]]],
        ],
        &device,
    );

    let bias = TestTensor::from_floats([570., 570.], &device);

    // Same geometry as the "no groups" variant, but split into two groups.
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [4, 2],
        kernel_size: [2, 3],
        padding: [1, 2],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [1, 2],
        groups: 2,
        size: [10, 10],
    };
    case.assert_grads(Grads { x, weight, bias });
}
/// Configuration of one `conv_transpose2d` gradient test.
///
/// `assert_grads` turns these values into the input/weight/bias tensors and the
/// `ConvTransposeOptions` used for the forward and backward pass.
struct ConvTranspose2dTestCase {
/// First dimension of the generated input tensor.
batch_size: usize,
/// `channels[0]` is the channel dimension of the input and the first dimension of
/// the weight; `channels[1]` is the bias length and, divided by `groups`, the
/// second dimension of the weight.
channels: [usize; 2],
/// Last two dimensions of the generated weight tensor.
kernel_size: [usize; 2],
/// Forwarded as the second argument of `ConvTransposeOptions::new`.
padding: [usize; 2],
/// Forwarded as the third argument of `ConvTransposeOptions::new`.
padding_out: [usize; 2],
/// Forwarded as the first argument of `ConvTransposeOptions::new`.
stride: [usize; 2],
/// Forwarded as the fourth argument of `ConvTransposeOptions::new`.
dilation: [usize; 2],
/// Number of groups; also divides `channels[1]` when sizing the weight.
groups: usize,
/// Last two dimensions of the generated input tensor.
size: [usize; 2],
}
/// Expected gradients that `ConvTranspose2dTestCase::assert_grads` compares against
/// the gradients produced by the backward pass.
struct Grads {
/// Expected gradient of the input tensor (rank 4).
x: TestTensor<4>,
/// Expected gradient of the weight tensor (rank 4).
weight: TestTensor<4>,
/// Expected gradient of the bias tensor (rank 1).
bias: TestTensor<1>,
}
impl ConvTranspose2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels[0],
self.size[0],
self.size[1],
]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv_transpose2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,711 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose3d_basic() {
    let device = Default::default();

    // Expected input gradient, laid out as [2, 2, 4, 4, 4]: every 4x4x4 volume is
    // constant, and both batch entries carry the same pair of channel volumes.
    let x_channel_0 = [[[13.250001; 4]; 4]; 4];
    let x_channel_1 = [[[40.249992; 4]; 4]; 4];
    let x_sample = [x_channel_0, x_channel_1];

    // Expected weight gradient, laid out as [2, 2, 3, 3, 3]: every 3x3x3 kernel is
    // constant, and both kernels of one input channel are equal.
    let w_channel_0 = [[[47.750000; 3]; 3]; 3];
    let w_channel_1 = [[[79.750000; 3]; 3]; 3];

    let expected = Grads {
        x: TestTensor::from_floats([x_sample, x_sample], &device),
        weight: TestTensor::from_floats(
            [[w_channel_0, w_channel_0], [w_channel_1, w_channel_1]],
            &device,
        ),
        bias: TestTensor::from_floats([432., 432.], &device),
    };

    // Unit stride and dilation, no padding, single group.
    let case = ConvTranspose3dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: [3, 3, 3],
        padding: [0, 0, 0],
        padding_out: [0, 0, 0],
        stride: [1, 1, 1],
        dilation: [1, 1, 1],
        groups: 1,
        size: [4, 4, 4],
    };
    case.assert_grads(expected);
}
#[test]
fn test_conv_transpose3d_complex_groups() {
    // One leading item followed by five copies of a second item (6 items total).
    // Used both for rows -> plane and for planes -> volume.
    fn lead_then_five<T: Copy>(lead: T, rest: T) -> [T; 6] {
        [lead, rest, rest, rest, rest, rest]
    }
    // One leading kernel row followed by two copies of a second row (3 rows total).
    fn kernel_plane<T: Copy>(lead: T, rest: T) -> [T; 3] {
        [lead, rest, rest]
    }

    let device = Default::default();

    // Expected input gradient, laid out as [1, 4, 6, 6, 6]. In every channel the first
    // depth slice differs from the remaining five, and inside each slice the first row
    // differs from the remaining five.
    let x_channel_0 = lead_then_five(
        lead_then_five(
            [1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000],
            [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
        ),
        lead_then_five(
            [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
        ),
    );
    let x_channel_1 = lead_then_five(
        lead_then_five(
            [2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000],
            [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
        ),
        lead_then_five(
            [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
        ),
    );
    let x_channel_2 = lead_then_five(
        lead_then_five(
            [4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000],
            [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
        ),
        lead_then_five(
            [7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000],
            [11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000],
        ),
    );
    let x_channel_3 = lead_then_five(
        lead_then_five(
            [5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000],
            [8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500],
        ),
        lead_then_five(
            [10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000],
            [15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000],
        ),
    );
    let x = TestTensor::from_floats(
        [[x_channel_0, x_channel_1, x_channel_2, x_channel_3]],
        &device,
    );

    // Expected weight gradient, laid out as [4, 1, 2, 3, 4]. Each 3x4 kernel plane has
    // a distinct first row followed by two identical rows.
    let weight = TestTensor::from_floats(
        [
            [[
                kernel_plane(
                    [18.663193, 22.309027, 22.309027, 22.309027],
                    [21.875000, 26.145834, 26.145834, 26.145834],
                ),
                kernel_plane(
                    [19.270832, 23.020834, 23.020834, 23.020834],
                    [22.500000, 26.875002, 26.875002, 26.875002],
                ),
            ]],
            [[
                kernel_plane(
                    [49.913193, 59.809029, 59.809029, 59.809029],
                    [59.375000, 71.145836, 71.145836, 71.145836],
                ),
                kernel_plane(
                    [56.770836, 68.020836, 68.020836, 68.020836],
                    [67.500000, 80.875000, 80.875000, 80.875000],
                ),
            ]],
            [[
                kernel_plane(
                    [81.163193, 97.309029, 97.309029, 97.309029],
                    [96.875000, 116.145828, 116.145828, 116.145828],
                ),
                kernel_plane(
                    [94.270828, 113.020828, 113.020828, 113.020828],
                    [112.500000, 134.875000, 134.875000, 134.875000],
                ),
            ]],
            [[
                kernel_plane(
                    [112.413200, 134.809021, 134.809021, 134.809021],
                    [134.375000, 161.145828, 161.145828, 161.145828],
                ),
                kernel_plane(
                    [131.770844, 158.020828, 158.020828, 158.020828],
                    [157.500000, 188.875000, 188.875000, 188.875000],
                ),
            ]],
        ],
        &device,
    );

    let bias = TestTensor::from_floats([5346., 5346.], &device);

    // Asymmetric kernel/stride/padding/dilation on every axis, split into two groups.
    let case = ConvTranspose3dTestCase {
        batch_size: 1,
        channels: [4, 2],
        kernel_size: [2, 3, 4],
        padding: [1, 2, 3],
        padding_out: [1, 2, 3],
        stride: [2, 3, 4],
        dilation: [1, 2, 3],
        groups: 2,
        size: [6, 6, 6],
    };
    case.assert_grads(Grads { x, weight, bias });
}
/// Configuration of one `conv_transpose3d` gradient test.
///
/// `assert_grads` turns these values into the input/weight/bias tensors and the
/// `ConvTransposeOptions` used for the forward and backward pass.
struct ConvTranspose3dTestCase {
/// First dimension of the generated input tensor.
batch_size: usize,
/// `channels[0]` is the channel dimension of the input and the first dimension of
/// the weight; `channels[1]` is the bias length and, divided by `groups`, the
/// second dimension of the weight.
channels: [usize; 2],
/// Last three dimensions of the generated weight tensor.
kernel_size: [usize; 3],
/// Forwarded as the second argument of `ConvTransposeOptions::new`.
padding: [usize; 3],
/// Forwarded as the third argument of `ConvTransposeOptions::new`.
padding_out: [usize; 3],
/// Forwarded as the first argument of `ConvTransposeOptions::new`.
stride: [usize; 3],
/// Forwarded as the fourth argument of `ConvTransposeOptions::new`.
dilation: [usize; 3],
/// Number of groups; also divides `channels[1]` when sizing the weight.
groups: usize,
/// Last three dimensions of the generated input tensor.
size: [usize; 3],
}
/// Expected gradients that `ConvTranspose3dTestCase::assert_grads` compares against
/// the gradients produced by the backward pass.
struct Grads {
/// Expected gradient of the input tensor (rank 5).
x: TestTensor<5>,
/// Expected gradient of the weight tensor (rank 5).
weight: TestTensor<5>,
/// Expected gradient of the bias tensor (rank 1).
bias: TestTensor<1>,
}
impl ConvTranspose3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels[0],
self.size[0],
self.size[1],
self.size[2],
]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
self.kernel_size[2],
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight.clone())
.into_data(),
&device,
)
.div_scalar(shape_weight.num_elements() as f32)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x.clone())
.into_data(),
&device,
)
.div_scalar(shape_x.num_elements() as f32)
.require_grad();
let output = conv_transpose3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,103 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[cfg(feature = "std")]
use burn_backend_tests::might_panic;
#[test]
fn backward_basic() {
    let device = Default::default();

    let lhs_data = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs_data = TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
    let rhs = TestAutodiffTensor::<2>::from_data(rhs_data, &device).require_grad();

    // Row-wise cross product along dim 1; the upstream gradient is all ones.
    let grads = lhs.clone().cross(rhs.clone(), 1).backward();

    let lhs_grad = lhs.grad(&grads).unwrap().to_data();
    let rhs_grad = rhs.grad(&grads).unwrap().to_data();

    // d/d(lhs) = rhs x ones, e.g. [4, 5, 6] x [1, 1, 1] = [-1, 2, -1].
    let lhs_expected = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
    // d/d(rhs) = ones x lhs, e.g. [1, 1, 1] x [1, 2, 3] = [1, -2, 1].
    let rhs_expected = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);

    lhs_grad.assert_approx_eq::<FloatElem>(&lhs_expected, Tolerance::default());
    rhs_grad.assert_approx_eq::<FloatElem>(&rhs_expected, Tolerance::default());
}
#[test]
fn backward_after_sum() {
    let device = Default::default();

    let lhs_data = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs_data = TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
    let rhs = TestAutodiffTensor::<2>::from_data(rhs_data, &device).require_grad();

    // Reducing with `sum` before `backward` must yield the same gradients as
    // back-propagating the unreduced cross product.
    let total = lhs.clone().cross(rhs.clone(), 1).sum();
    let grads = total.backward();

    let lhs_grad = lhs.grad(&grads).unwrap().to_data();
    let rhs_grad = rhs.grad(&grads).unwrap().to_data();

    let lhs_expected = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
    let rhs_expected = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);

    lhs_grad.assert_approx_eq::<FloatElem>(&lhs_expected, Tolerance::default());
    rhs_grad.assert_approx_eq::<FloatElem>(&rhs_expected, Tolerance::default());
}
#[cfg(feature = "std")]
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
#[test]
fn different_dim() {
    // Exercises the cross product along a dimension other than the last one (dim 0),
    // so every *column* of the 3x3 inputs is treated as a 3-vector.
    let device = Default::default();
    let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
    let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];
    let lhs = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);
    let rhs = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);

    // Some backends (for example CubeCL) may not support cross on non-last
    // dimensions and intentionally panic with a message like
    // "Cross product on non-last dimension not yet implemented".
    // Such a panic is treated as a skipped test for that backend.
    let actual = lhs.cross(rhs.clone(), 0);

    // Reference result computed from the raw arrays: output row `i` is
    // a[p] * b[q] - a[q] * b[p] per column, with (p, q) cycling through
    // (1, 2), (2, 0), (0, 1).
    let cross_row = |p: usize, q: usize| {
        [
            a_raw[p][0] * b_raw[q][0] - a_raw[q][0] * b_raw[p][0],
            a_raw[p][1] * b_raw[q][1] - a_raw[q][1] * b_raw[p][1],
            a_raw[p][2] * b_raw[q][2] - a_raw[q][2] * b_raw[p][2],
        ]
    };
    let expected = [cross_row(1, 2), cross_row(2, 0), cross_row(0, 1)];

    actual
        .to_data()
        .assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());
}

View File

@@ -0,0 +1,33 @@
use super::*;
use burn_tensor::{Tensor, TensorData, Tolerance, loss};
#[test]
fn test_cross_entropy_loss_grad() {
    let lhs_data = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
    let rhs_data = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
    let target_data = TensorData::from([[0.8, 0.2], [0.9, 0.1]]);

    let device = Default::default();
    let lhs = Tensor::<TestAutodiffBackend, 2>::from_data(lhs_data, &device).require_grad();
    let rhs = Tensor::<TestAutodiffBackend, 2>::from_data(rhs_data, &device).require_grad();
    let targets =
        Tensor::<TestAutodiffBackend, 2>::from_data(target_data, &device).require_grad();

    // The logits are the matrix product of the two tracked tensors, so the loss
    // gradient has to flow back through `matmul` into both operands.
    let logits = lhs.clone().matmul(rhs.clone());
    let loss_value = loss::cross_entropy_with_logits(logits, targets);
    let grads = loss_value.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();
    let tolerance = Tolerance::permissive();

    let lhs_expected = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_expected, tolerance);

    let rhs_expected = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_expected, tolerance);
}

View File

@@ -0,0 +1,117 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cummax() {
    // Smoke test: gradients flow through a 1D cummax.
    let device = Default::default();
    let input_data = TensorData::from([1.0, 3.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(input_data, &device).require_grad();

    let grads = input.clone().cummax(0).sum().backward();
    let actual = input.grad(&grads).unwrap();

    // cummax is [1, 3, 3]: element 1 supplies two outputs, element 2 supplies none.
    // PyTorch reference: [1.0, 2.0, 0.0]
    let expected = TensorData::from([1.0, 2.0, 0.0]);
    actual
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_2d() {
    // Gradients of cummax taken along dim 1 of a 2D tensor.
    let device = Default::default();
    let input_data = TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]);
    let input = TestAutodiffTensor::<2>::from_data(input_data, &device).require_grad();

    let grads = input.clone().cummax(1).sum().backward();
    let actual = input.grad(&grads).unwrap();

    // In both rows the middle element is the running maximum for the last two outputs.
    // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
    let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
    actual
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_duplicate_values() {
    // Critical edge case: the maximum value appears more than once.
    let device = Default::default();
    let input_data = TensorData::from([1.0, 3.0, 3.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(input_data, &device).require_grad();

    let grads = input.clone().cummax(0).sum().backward();
    let actual = input.grad(&grads).unwrap();

    // input:  [1.0, 3.0, 3.0, 2.0]
    // cummax: [1.0, 3.0, 3.0, 3.0]
    // PyTorch reference: [1.0, 1.0, 2.0, 0.0]
    // Index 2 collects its own gradient plus the one coming from index 3.
    let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
    actual
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
/// Verifies the `cummax` gradient when every element of the input is identical.
#[test]
fn should_diff_cummax_all_same() {
    let device = Default::default();
    let values = TensorData::from([2.0, 2.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummax(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0]
    // Every element equals the running maximum, so each keeps its own gradient.
    let reference = TensorData::from([1.0, 1.0, 1.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummax` gradient on a strictly increasing sequence.
#[test]
fn should_diff_cummax_increasing() {
    let device = Default::default();
    let values = TensorData::from([1.0, 2.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummax(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0, 1.0]
    // Each element sets a new maximum.
    let reference = TensorData::from([1.0, 1.0, 1.0, 1.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummax` gradient along dim 1 of a 2-D input containing repeated maxima.
#[test]
fn should_diff_cummax_2d_duplicates() {
    let device = Default::default();
    let values = TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummax(1).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}

View File

@@ -0,0 +1,117 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies the `cummin` gradient on a small 1-D input.
#[test]
fn should_diff_cummin() {
    let device = Default::default();
    let values = TensorData::from([3.0, 2.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    // Backpropagate from the sum of the running minimum.
    let gradients = input.clone().cummin(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 2.0, 0.0]
    let reference = TensorData::from([1.0, 2.0, 0.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummin` gradient along dim 1 of a 2-D input.
#[test]
fn should_diff_cummin_2d() {
    let device = Default::default();
    let values = TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummin(1).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummin` gradient when the minimum value occurs more than once.
#[test]
fn should_diff_cummin_duplicate_values() {
    let device = Default::default();
    let values = TensorData::from([3.0, 2.0, 2.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummin(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // input:  [3.0, 2.0, 2.0, 4.0]
    // cummin: [3.0, 2.0, 2.0, 2.0]
    // PyTorch reference: [1.0, 1.0, 2.0, 0.0]
    // (index 2 collects its own gradient plus the one from index 3)
    let reference = TensorData::from([1.0, 1.0, 2.0, 0.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummin` gradient when every element of the input is identical.
#[test]
fn should_diff_cummin_all_same() {
    let device = Default::default();
    let values = TensorData::from([2.0, 2.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummin(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0]
    // Every element equals the running minimum, so each keeps its own gradient.
    let reference = TensorData::from([1.0, 1.0, 1.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummin` gradient on a strictly decreasing sequence.
#[test]
fn should_diff_cummin_decreasing() {
    let device = Default::default();
    let values = TensorData::from([5.0, 4.0, 3.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummin(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0, 1.0]
    // Each element sets a new minimum.
    let reference = TensorData::from([1.0, 1.0, 1.0, 1.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cummin` gradient along dim 1 of a 2-D input containing repeated minima.
#[test]
fn should_diff_cummin_2d_duplicates() {
    let device = Default::default();
    let values = TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let gradients = input.clone().cummin(1).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}

View File

@@ -0,0 +1,132 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies the `cumprod` gradient on a 1-D input without zeros.
#[test]
fn should_diff_cumprod() {
    let device = Default::default();
    let values = TensorData::from([2.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    // Backpropagate from the sum of the running product.
    let gradients = input.clone().cumprod(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [16.0, 10.0, 6.0]
    let reference = TensorData::from([16.0, 10.0, 6.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cumprod` gradient along dim 1 of a 2-D input without zeros.
#[test]
fn should_diff_cumprod_2d() {
    let device = Default::default();
    let values = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let gradients = input.clone().cumprod(1).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]
    let reference = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
// TODO: The following tests are currently ignored due to a known limitation
// in the cumprod gradient implementation. The current implementation uses
// division (grad / input), which produces NaN when the input contains zeros.
//
// A proper fix requires implementing a zero-safe algorithm using exclusive
// cumulative products (similar to PyTorch's cumprod_backward or JAX's
// associative_scan approach). This is a non-trivial implementation that
// requires careful handling of cumulative products in both forward and
// reverse directions.
//
// See: https://github.com/tracel-ai/burn/issues/3864
//
// References:
// - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward)
// - JAX PR #2596: Parallel prefix scan implementation
// - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros
/// Verifies the `cumprod` gradient when a zero sits in the middle of the input.
///
/// Currently ignored; see the `#[ignore]` reason below.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_in_middle() {
    let device = Default::default();
    let values = TensorData::from([2.0, 0.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cumprod(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 32.0, 0.0, 0.0]
    let reference = TensorData::from([1.0, 32.0, 0.0, 0.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cumprod` gradient when the first input element is zero.
///
/// Currently ignored; see the `#[ignore]` reason below.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_start() {
    let device = Default::default();
    let values = TensorData::from([0.0, 2.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cumprod(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [33.0, 0.0, 0.0, 0.0]
    let reference = TensorData::from([33.0, 0.0, 0.0, 0.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cumprod` gradient when the last input element is zero.
///
/// Currently ignored; see the `#[ignore]` reason below.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_end() {
    let device = Default::default();
    let values = TensorData::from([2.0, 3.0, 4.0, 0.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cumprod(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [16.0, 10.0, 6.0, 24.0]
    let reference = TensorData::from([16.0, 10.0, 6.0, 24.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the `cumprod` gradient when the input contains more than one zero.
///
/// Currently ignored; see the `#[ignore]` reason below.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_multiple_zeros() {
    let device = Default::default();
    let values = TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let gradients = input.clone().cumprod(0).sum().backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0]
    let reference = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}

View File

@@ -0,0 +1,89 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies gradients through `matmul -> cumsum(0) -> mul -> sum` for both inputs.
#[test]
fn should_diff_cumsum_dim0() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let scanned = lhs.clone().matmul(rhs.clone()).cumsum(0);
    let gradients = lhs.clone().mul(scanned).sum().backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Reference gradients were computed with PyTorch.
    let reference = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies gradients through `matmul -> cumsum(1) -> mul -> sum` for both inputs.
#[test]
fn should_diff_cumsum_dim1() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let scanned = lhs.clone().matmul(rhs.clone()).cumsum(1);
    let gradients = lhs.clone().mul(scanned).sum().backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Reference gradients were computed with PyTorch.
    let reference = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([[18.0, 13.0], [71.0, 58.0]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies gradients when a matmul result is multiplied by its own `cumsum(1)`.
#[test]
fn should_diff_cumsum_complex() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(TensorData::from([[0.0, 1.0], [3.0, 4.0]]), &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[6.0, 7.0], [9.0, 10.0]]), &device)
        .require_grad();

    // The product feeds both the scan and the final multiplication.
    let product = lhs.clone().matmul(rhs.clone());
    let scanned = product.clone().cumsum(1);
    let gradients = scanned.mul(product).sum().backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Reference gradients were computed with PyTorch.
    let reference = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([[507.0, 528.0], [704.0, 733.0]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}

View File

@@ -0,0 +1,105 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies the gradients of element-wise `div` with respect to both operands.
#[test]
fn should_diff_div() {
    let device = Default::default();
    let numerator =
        TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 7.0]), &device).require_grad();
    let denominator =
        TestAutodiffTensor::from_data(TensorData::from([4.0, 7.0]), &device).require_grad();

    let gradients = numerator.clone().div(denominator.clone()).backward();
    let numerator_grad = numerator.grad(&gradients).unwrap();
    let denominator_grad = denominator.grad(&gradients).unwrap();

    let reference = TensorData::from([0.25, 0.14285715]);
    numerator_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([-0.0625, -0.14285715]);
    denominator_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies the gradient of `div_scalar` with respect to its tensor input.
#[test]
fn should_diff_div_scalar() {
    let values = TensorData::from([1.0, 7.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &Default::default()).require_grad();

    let gradients = input.clone().div_scalar(4.0).backward();
    let input_grad = input.grad(&gradients).unwrap();

    let reference = TensorData::from([0.25, 0.25]);
    input_grad.into_data().assert_eq(&reference, false);
}
/// Verifies the gradients of two chained `div` operations for all three operands.
#[test]
fn test_div_complex_1() {
    let device = Default::default();
    let first = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let second = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let third = TestAutodiffTensor::from_data(TensorData::from([[2.0, 2.0], [2.0, 2.0]]), &device)
        .require_grad();

    // (first / second) / third
    let quotient = first.clone().div(second.clone()).div(third.clone());
    let gradients = quotient.backward();
    let first_grad = first.grad(&gradients).unwrap();
    let second_grad = second.grad(&gradients).unwrap();
    let third_grad = third.grad(&gradients).unwrap();

    let reference = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]);
    first_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]);
    second_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]);
    third_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
/// Verifies gradients when a matmul result is divided by one of its own operands.
#[test]
fn test_div_complex_2() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(TensorData::from([[0.0, 1.0], [3.0, 4.0]]), &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[6.0, 7.0], [9.0, 10.0]]), &device)
        .require_grad();

    // `rhs` contributes through both the matmul and the division.
    let quotient = lhs.clone().matmul(rhs.clone()).div(rhs.clone());
    let gradients = quotient.backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_absolute(2e-3);

    let reference = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies gradients through `lhs.matmul(rhs.erf()).matmul(rhs)` for both inputs.
#[test]
fn should_diff_erf() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(TensorData::from([[0.0, 1.0], [3.0, 4.0]]), &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[6.0, 7.0], [9.0, 10.0]]), &device)
        .require_grad();

    let activated = lhs.clone().matmul(rhs.clone().erf());
    let gradients = activated.matmul(rhs.clone()).backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let reference = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());

    let reference = TensorData::from([[8.0, 8.0], [8.0, 8.0]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies gradients through `lhs.matmul(rhs.exp())` for both inputs.
#[test]
fn should_diff_exp() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [-2.0, -3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, -7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let gradients = lhs.clone().matmul(rhs.clone().exp()).backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default();

    let reference = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,39 @@
use super::*;
use burn_tensor::TensorData;
/// Verifies gradients flowing through `expand` and through the tensor it is multiplied with.
#[test]
fn should_diff_expand() {
    // The expected values come from this PyTorch snippet:
    //
    // import torch
    // x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
    // x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
    // y = x1.expand(4, 4)
    // z = (x2 * y).sum()
    // z.backward()
    // print("x1", x1.grad)
    // print("x2", x2.grad)
    let device = Default::default();
    let expanded_src =
        TestAutodiffTensor::<1>::from_data(TensorData::from([4.0, 7.0, 2.0, 3.0]), &device)
            .require_grad();
    let multiplier =
        TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 4.5, 7.0, 3.0]), &device)
            .require_grad();

    let expanded = expanded_src.clone().expand([4, 4]);
    // `unsqueeze` raises `multiplier` to the rank of `expanded` so the two can be multiplied.
    let total = multiplier.clone().unsqueeze().mul(expanded).sum();
    let gradients = total.backward();

    let expanded_src_grad = expanded_src.grad(&gradients).unwrap();
    let multiplier_grad = multiplier.grad(&gradients).unwrap();

    let reference = TensorData::from([8., 18., 28., 12.]);
    expanded_src_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([16., 28., 8., 12.]);
    multiplier_grad.into_data().assert_eq(&reference, false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
/// Verifies gradients through `flip` on dims 1 and 2 followed by a batched matmul.
#[test]
fn should_diff_flip() {
    let device = Default::default();
    // 1x2x2
    let lhs = TestAutodiffTensor::<3>::from_data(TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]), &device)
        .require_grad();
    // 1x2x3
    let rhs = TestAutodiffTensor::from_data(
        TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]),
        &device,
    )
    .require_grad();

    let flipped = rhs.clone().flip([1, 2]);
    let gradients = lhs.clone().matmul(flipped).backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    // 1x2x2
    let reference = TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    // 1x2x3
    let reference = TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,21 @@
use super::*;
use burn_tensor::TensorData;
/// Verifies that the gradient of `floor` is zero for every input element.
#[test]
fn should_diff_floor() {
    let device = Default::default();
    let values = TensorData::from([
        [-1.9751, 0.0714, 0.0643, 0.2406],
        [-1.3172, 0.1252, -0.1119, -0.0127],
    ]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let gradients = input.clone().floor().backward();
    let input_grad = input.grad(&gradients).unwrap();

    let reference = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
    input_grad.into_data().assert_eq(&reference, false);
}

View File

@@ -0,0 +1,99 @@
use super::*;
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
/// Verifies the gradient of a tensor used both in a self-matmul and as the source of `gather`.
#[test]
fn test_gather_grad() {
    let device = Default::default();
    let source = TestAutodiffTensor::from_data(
        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
        &device,
    )
    .require_grad();
    let gather_idx = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
        TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]),
        &device,
    );

    let gram = source.clone().matmul(source.clone().transpose());
    let gathered = source.clone().gather(1, gather_idx);
    let gradients = gram.matmul(gathered).backward();

    let source_grad = source.grad(&gradients).unwrap();
    let reference = TensorData::from([[94., 150., 187.], [242., 305., 304.]]);
    source_grad.into_data().assert_eq(&reference, false);
}
/// Verifies the gradients of additive `scatter` for both the target tensor and the scattered
/// values.
#[test]
fn test_scatter_grad() {
    let device = Default::default();
    let target = TestAutodiffTensor::from_data(
        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
        &device,
    )
    .require_grad();
    let updates = TestAutodiffTensor::from_data(
        TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        &device,
    )
    .require_grad();
    let scatter_idx = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
        TensorData::from([[2, 1, 0], [2, 0, 1]]),
        &device,
    );

    let gram = target.clone().matmul(target.clone().transpose());
    let scattered = target
        .clone()
        .scatter(1, scatter_idx, updates.clone(), IndexingUpdateOp::Add);
    let gradients = gram.matmul(scattered).backward();

    let target_grad = target.grad(&gradients).unwrap();
    let updates_grad = updates.grad(&gradients).unwrap();

    let reference = TensorData::from([[127., 181., 235.], [226., 316., 406.]]);
    target_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([[19., 19., 19.], [64., 64., 64.]]);
    updates_grad.into_data().assert_eq(&reference, false);
}
/// Verifies additive `scatter` gradients when the indices cover only some of the target's
/// positions.
#[test]
fn test_scatter_add_grad_partial_indices() {
    let device = Default::default();
    let base = TestAutodiffTensor::from_data(
        TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]),
        &device,
    )
    .require_grad();
    let scale = TestAutodiffTensor::from_data(
        TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]),
        &device,
    )
    .require_grad();
    let updates =
        TestAutodiffTensor::from_data(TensorData::from([[4.0, 5.0, 6.0]]), &device).require_grad();
    let scatter_idx =
        Tensor::<TestAutodiffBackend, 2, Int>::from_data(TensorData::from([[2, 1, 0]]), &device);

    // Scatter into the element-wise product so `base` receives `scale` as its gradient.
    let scattered = base
        .clone()
        .mul(scale)
        .scatter(1, scatter_idx, updates.clone(), IndexingUpdateOp::Add);
    let gradients = scattered.backward();

    let base_grad = base.grad(&gradients).unwrap();
    let updates_grad = updates.grad(&gradients).unwrap();

    let reference = TensorData::from([[1., 2., 3., 4., 5., 6.]]);
    base_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([[1., 1., 1.]]);
    updates_grad.into_data().assert_eq(&reference, false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance, activation};
/// Verifies gradients through `lhs.matmul(lhs.matmul(gelu(rhs)))` for both inputs.
#[test]
fn should_diff_gelu() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad();
    let rhs = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad();

    let inner = lhs.clone().matmul(activation::gelu(rhs.clone()));
    let outer = lhs.clone().matmul(inner);
    let gradients = outer.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::permissive();

    let reference = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,24 @@
use super::*;
use burn_tensor::{Distribution, activation};
/// Verifies that `grad_replace` overwrites the gradient stored for a tensor.
#[test]
fn should_update_tensor_when_grad_replace() {
    let device = Default::default();
    let tracked =
        TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad();
    let other = TestAutodiffTensor::random([32, 32], Distribution::Default, &device);

    let output = tracked.clone().matmul(activation::gelu(other));
    let mut gradients = output.backward();
    let original_grad = tracked.grad(&gradients).unwrap();

    // Swap the stored gradient for a freshly drawn random tensor.
    let replacement =
        TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad();
    tracked.grad_replace(&mut gradients, replacement.clone().inner());
    let current_grad = tracked.grad(&gradients).unwrap();

    // The stored gradient must differ from the old one and equal the replacement.
    assert_ne!(current_grad.to_data(), original_grad.into_data());
    assert_eq!(current_grad.into_data(), replacement.into_data());
}

View File

@@ -0,0 +1,30 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Verifies gradients through `lhs.matmul(rhs.log()).matmul(rhs)` for both inputs.
#[test]
fn should_diff_log() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(TensorData::from([[0.0, 1.0], [3.0, 4.0]]), &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[6.0, 7.0], [9.0, 10.0]]), &device)
        .require_grad();

    let logged = lhs.clone().matmul(rhs.clone().log());
    let gradients = logged.matmul(rhs.clone()).backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let reference = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,28 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
/// Verifies gradients through `lhs.matmul(rhs.log1p()).matmul(rhs)` for both inputs.
#[test]
fn should_diff_log1p() {
    let lhs = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();
    let rhs = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();

    let logged = lhs.clone().matmul(rhs.clone().log1p());
    let gradients = logged.matmul(rhs.clone()).backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let reference = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]);
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]);
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,19 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{TensorData, activation};
/// Verifies the `log_sigmoid` gradient, including inputs of large magnitude (-300 and 200).
#[test]
fn should_diff_log_sigmoid() {
    let device = Default::default();
    let values = TensorData::from([[0.8762, -0.1423], [-300., 200.]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let gradients = activation::log_sigmoid(input.clone()).backward();
    let input_grad = input.grad(&gradients).unwrap();

    let reference = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]);
    input_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}

View File

@@ -0,0 +1,65 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Bool, Tensor, TensorData};
/// Verifies gradients through `matmul` followed by `mask_fill` for both matmul operands.
#[test]
fn should_diff_mask_fill() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::from_data(TensorData::from([[1.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let mask_data = TensorData::from([[true, false], [false, true]]);
    let fill_mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask_data, &device);

    let filled = lhs.clone().matmul(rhs.clone()).mask_fill(fill_mask, 2.0);
    let gradients = filled.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let reference = TensorData::from([[7.0, 3.0], [4.0, 2.0]]);
    lhs_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([[2.0, 1.0], [3.0, 7.0]]);
    rhs_grad.into_data().assert_eq(&reference, false);
}
/// Verifies gradients through two matmuls followed by `mask_where`, where the third tensor is
/// both a matmul operand and the `mask_where` source.
#[test]
fn should_diff_mask_where() {
    let device = Default::default();
    let first = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad();
    let second = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad();
    let third = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad();
    let select_mask =
        Tensor::<TestAutodiffBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

    let product = first
        .clone()
        .matmul(second.clone())
        .matmul(third.clone());
    let gradients = product.mask_where(select_mask, third.clone()).backward();

    let first_grad = first.grad(&gradients).unwrap();
    let second_grad = second.grad(&gradients).unwrap();
    let third_grad = third.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let reference = TensorData::from([[121.8, 55.0], [110.8, 50.0]]);
    first_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[27.4, 33.4], [95.0, 115.0]]);
    second_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);

    let reference = TensorData::from([[15., 18.], [23., 29.]]);
    third_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&reference, tolerance);
}

View File

@@ -0,0 +1,83 @@
use super::*;
use burn_tensor::TensorData;
/// Verifies both the forward result of `matmul` and its gradients for each operand.
#[test]
fn should_diff_matmul() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(TensorData::from([[1.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    // Keep the product around: its forward value is asserted after the gradients.
    let product = lhs.clone().matmul(rhs.clone());
    let gradients = product.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let reference = TensorData::from([[11.0, 5.0], [11.0, 5.0]]);
    lhs_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([[3.0, 3.0], [10.0, 10.0]]);
    rhs_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([[18.0, 28.0], [14.0, 23.0]]);
    product.into_data().assert_eq(&reference, false);
}
/// Verifies the gradients of the first two operands in a chain of two matmuls.
#[test]
fn test_matmul_complex_1() {
    let device = Default::default();
    let first = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let second = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let third = TestAutodiffTensor::from_data(TensorData::from([[2.0, 2.0], [2.0, 2.0]]), &device)
        .require_grad();

    // (first @ second) @ third
    let gradients = first.clone().matmul(second.clone()).matmul(third).backward();

    let first_grad = first.grad(&gradients).unwrap();
    let second_grad = second.grad(&gradients).unwrap();

    let reference = TensorData::from([[44.0, 20.0], [44.0, 20.0]]);
    first_grad.into_data().assert_eq(&reference, false);

    let reference = TensorData::from([[56.0, 56.0], [16.0, 16.0]]);
    second_grad.into_data().assert_eq(&reference, false);
}
/// Differentiates `tensor_1 @ ((tensor_1 @ tensor_2) @ tensor_3)`, where the
/// first operand appears twice in the graph, and checks the gradients of the
/// first two operands.
#[test]
fn test_matmul_complex_2() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let c = TestAutodiffTensor::from_data(TensorData::from([[2.0, 2.0], [2.0, 2.0]]), &device)
        .require_grad();

    let ab = a.clone().matmul(b.clone());
    let abc = ab.matmul(c.clone());
    // `a` is used a second time as the left operand of the outer product.
    let root = a.clone().matmul(abc);
    let grads = root.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    let expected_a_grad = TensorData::from([[800.0, 792.0], [360.0, 592.0]]);
    let expected_b_grad = TensorData::from([[264., 264.0], [344.0, 344.0]]);
    a_grad.to_data().assert_eq(&expected_a_grad, false);
    b_grad.to_data().assert_eq(&expected_b_grad, false);
}

View File

@@ -0,0 +1,82 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
/// Differentiates `tensor_1 * max_dim(tensor_1 @ tensor_2, 1)` and compares
/// the gradients of both inputs against fixed reference values.
#[test]
fn should_diff_max_dim() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let row_max = product.max_dim(1);
    let scaled = lhs.clone().mul(row_max.unsqueeze());
    let grads = scaled.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let expected_lhs_grad = TensorData::from([[50.0, 34.0], [40.0, -10.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_lhs_grad, Tolerance::default());

    let expected_rhs_grad = TensorData::from([[8.0, 10.0], [56.0, 15.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_rhs_grad, Tolerance::default());
}
/// Differentiates `tensor_1 * min_dim(tensor_1 @ tensor_2, 1)` and compares
/// the gradients of both inputs against fixed reference values.
#[test]
fn should_diff_min_dim() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device)
        .require_grad();
    let rhs = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let row_min = product.min_dim(1);
    let scaled = lhs.clone().mul(row_min.unsqueeze());
    let grads = scaled.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let expected_lhs_grad = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_lhs_grad, Tolerance::default());

    let expected_rhs_grad = TensorData::from([[10.0, 8.0], [15.0, 56.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_rhs_grad, Tolerance::default());
}
/// Takes `min_dim(1)` of an element-wise product of two rank-3 tensors and
/// compares the gradients of both inputs against fixed reference values.
#[test]
fn should_diff_min_dim_3d_dim1() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device)
        .require_grad();
    let rhs = TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device)
        .require_grad();

    let reduced = lhs.clone().mul(rhs.clone()).min_dim(1);
    let grads = reduced.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let expected_lhs_grad = TensorData::from([[[0., -7.], [2., 0.]]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_lhs_grad, Tolerance::default());

    let expected_rhs_grad = TensorData::from([[[0., 7.], [-2., -0.]]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_rhs_grad, Tolerance::default());
}

View File

@@ -0,0 +1,134 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::module::max_pool1d;
/// Backward pass of `max_pool1d` over six elements with kernel 4, stride 1,
/// no padding and no dilation.
///
/// The three windows select elements 0, 1 and 5 as their maxima, which is
/// where the expected gradient is 1.
#[test]
fn test_max_pool1d_simple() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 0, 1);
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad =
        TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device);
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool1d` over 25 elements with kernel 4, stride 1,
/// no padding and dilation 2, compared against a fixed reference gradient.
#[test]
fn test_max_pool1d_with_dilation() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 0, 2);
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
            0.4610, 0.5365, 0.6880,
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<3>::from_floats(
        [[[
            0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
            0., 0., 1.,
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool1d` over 25 elements with kernel 4, stride 1,
/// no padding and no dilation, compared against a fixed reference gradient.
#[test]
fn test_max_pool1d_complex() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 0, 1);
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
            0.4610, 0.5365, 0.6880,
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<3>::from_floats(
        [[[
            0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
            1., 1., 1.,
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool1d` over 25 elements with kernel 4, stride 1,
/// padding 2 and no dilation, compared against a fixed reference gradient.
#[test]
fn test_max_pool1d_complex_with_padding() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 2, 1);
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
            0.4610, 0.5365, 0.6880,
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<3>::from_floats(
        [[[
            1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
            1., 1., 3.,
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}

View File

@@ -0,0 +1,271 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::module::max_pool2d;
/// Backward pass of `max_pool2d` on a 4x4 input with a 3x3 kernel, stride 1,
/// no padding and no dilation, compared against a fixed reference gradient.
#[test]
fn test_max_pool2d_simple_1() {
    // Each pair is given as [first spatial dim, second spatial dim].
    let kernel_size = [3, 3];
    let stride = [1, 1];
    let padding = [0, 0];
    let dilation = [1, 1];
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool2d` on a 4x4 input with a 2x2 kernel, stride 1,
/// padding 1 and no dilation, compared against a fixed reference gradient.
#[test]
fn test_max_pool2d_simple_2() {
    // Each pair is given as [first spatial dim, second spatial dim].
    let kernel_size = [2, 2];
    let stride = [1, 1];
    let padding = [1, 1];
    let dilation = [1, 1];
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [1., 3., 0., 2.],
            [3., 0., 0., 4.],
            [1., 4., 0., 1.],
            [2., 0., 3., 1.],
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool2d` on a 4x4 input with a 2x2 kernel, stride 1,
/// padding 1 and dilation 2, compared against a fixed reference gradient.
#[test]
fn test_max_pool2d_with_dilation() {
    // Each pair is given as [first spatial dim, second spatial dim].
    let kernel_size = [2, 2];
    let stride = [1, 1];
    let padding = [1, 1];
    let dilation = [2, 2];
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0., 0., 0., 0.],
            [1., 1., 1., 2.],
            [0., 4., 4., 0.],
            [0., 1., 2., 0.],
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool2d` on a 5x5 input with asymmetric settings
/// (kernel 4x2, stride 1x2, padding 2x1, no dilation), compared against a
/// fixed reference gradient.
#[test]
fn test_max_pool2d_complex() {
    // Each pair is given as [first spatial dim, second spatial dim].
    let kernel_size = [4, 2];
    let stride = [1, 2];
    let padding = [2, 1];
    let dilation = [1, 1];
    let device = Default::default();

    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
            [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
            [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
            [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
            [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0., 0., 0., 3., 0.],
            [4., 0., 2., 1., 0.],
            [0., 0., 0., 0., 0.],
            [2., 4., 0., 0., 0.],
            [0., 0., 0., 0., 2.],
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
/// Backward pass of `max_pool2d` with `ceil_mode = true`.
///
/// A 1x1x6x6 input pooled with a 3x3 kernel, stride 2x2 and no padding gives
/// a 2x2 output in floor mode and a 3x3 output in ceil mode; this test
/// exercises the ceil-mode variant.
#[test]
fn test_max_pool2d_ceil_mode() {
    // Each pair is given as [first spatial dim, second spatial dim].
    let kernel_size = [3, 3];
    let stride = [2, 2];
    let padding = [0, 0];
    let dilation = [1, 1];
    let device = Default::default();

    // Row-major values 1..=36, so the maximum of any window is its
    // bottom-right in-bounds element.
    let input = TestAutodiffTensor::from_floats(
        [[[
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
            [19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
            [25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
            [31.0, 32.0, 33.0, 34.0, 35.0, 36.0],
        ]]],
        &device,
    )
    .require_grad();

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, true);
    let grads = pooled.backward();
    let actual_grad = input.grad(&grads).unwrap();

    // Where each of the nine ceil-mode outputs takes its maximum; every such
    // input position receives a gradient of 1:
    //   out (0,0) -> in (2,2) = 15    out (0,1) -> in (2,4) = 17    out (0,2) -> in (2,5) = 18
    //   out (1,0) -> in (4,2) = 27    out (1,1) -> in (4,4) = 29    out (1,2) -> in (4,5) = 30
    //   out (2,0) -> in (5,2) = 33    out (2,1) -> in (5,4) = 35    out (2,2) -> in (5,5) = 36
    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 1., 0., 1., 1.],
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 1., 0., 1., 1.],
            [0., 0., 1., 0., 1., 1.],
        ]]],
        &device,
    );
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}

View File

@@ -0,0 +1,290 @@
use super::*;
use burn_tensor::{Tensor, TensorData};
/// Builds two graphs that share no tensors and checks that running
/// `backward()` on the first root does not stop the second graph from
/// producing gradients for all four of its leaves.
#[test]
fn test_mm_independent_trees() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
// Leaves of the second tree are cloned so they can be queried for gradients below.
let tensor_11 = tensor_7.clone() * tensor_8.clone();
let tensor_12 = tensor_9.clone() * tensor_10.clone();
let tensor_13 = tensor_11 * tensor_12;
// Backward on the first tree; its gradients are not inspected.
let _grads = tensor_6.backward();
// Backward on the second tree must still reach every one of its leaves.
let grads = tensor_13.backward();
assert!(tensor_7.grad(&grads).is_some());
assert!(tensor_8.grad(&grads).is_some());
assert!(tensor_9.grad(&grads).is_some());
assert!(tensor_10.grad(&grads).is_some());
}
/// Two graphs that share the intermediate `tensor_4`. After `backward()` on
/// the first root (`tensor_6`), calling `backward()` on the second root
/// (`tensor_10`), which depends on `tensor_4`, is expected to panic.
// NOTE(review): presumably the panic comes from the shared subtree having been
// released by the first backward pass; confirm against the autodiff runtime.
#[test]
#[should_panic]
fn test_mm_crossover_trees_root_unavailable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
// tensor_4 is cloned here because the second tree also consumes it.
let tensor_6 = tensor_4.clone() * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = tensor_7.clone() * tensor_8.clone();
// Crossover: the second root is built from tensor_4 of the first tree.
let tensor_10 = tensor_4 * tensor_9;
let _grads = tensor_6.backward();
// Expected to panic (see #[should_panic]).
let _grads = tensor_10.backward();
}
/// Same crossover layout as `test_mm_crossover_trees_root_unavailable`, but
/// the second backward pass starts from `tensor_9`, which is built only from
/// `tensor_7` and `tensor_8`. Both backward calls must complete without
/// panicking.
#[test]
fn test_mm_crossover_trees_with_referred_subtree() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4.clone() * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = tensor_7.clone() * tensor_8.clone();
// The crossover node is created and kept alive, but never differentiated.
let _tensor_10 = tensor_4 * tensor_9.clone();
let _grads = tensor_6.backward();
// tensor_9 does not depend on the first tree, so this must succeed.
let _grads = tensor_9.backward();
}
/// Three graphs: the first (root `tensor_6`) and the third (root `tensor_13`)
/// are linked only through a middle node `_tensor_14 = tensor_5 * tensor_12`.
/// After `backward()` on the first root, `backward()` on the third root must
/// still complete without panicking.
#[test]
fn test_mm_three_crossover_trees_last_still_usable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5.clone();
// Third tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7 * tensor_8;
let tensor_12 = tensor_9 * tensor_10;
let tensor_13 = tensor_11 * tensor_12.clone();
// Second tree (in between): joins tensor_5 (first tree) with tensor_12 (third tree)
let _tensor_14 = tensor_5 * tensor_12;
let _grads = tensor_6.backward();
// The third tree's root does not depend on the first tree, so this must succeed.
let _grads = tensor_13.backward();
}
/// Same three-graph layout as
/// `test_mm_three_crossover_trees_last_still_usable`, but the second backward
/// pass starts from the middle node `tensor_14`, which depends on `tensor_5`
/// from the first graph. That second call is expected to panic.
#[test]
#[should_panic]
fn test_mm_three_crossover_trees_middle_one_unavailable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5.clone();
// Third tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7 * tensor_8;
let tensor_12 = tensor_9 * tensor_10;
// The third root is created and kept alive, but never differentiated.
let _tensor_13 = tensor_11 * tensor_12.clone();
// Second tree (in between): joins tensor_5 (first tree) with tensor_12 (third tree)
let tensor_14 = tensor_5 * tensor_12;
let _grads = tensor_6.backward();
// Expected to panic (see #[should_panic]).
let _grads = tensor_14.backward();
}
/// Builds a graph in which `tensor_3` feeds the root twice: once directly and
/// once through `tensor_5`. The backward pass must complete without
/// panicking.
#[test]
fn test_mm_self_referencing_tree() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = tensor_0 * tensor_1;
// tensor_3 is used here and again as a direct operand of the root below.
let tensor_5 = tensor_2 * tensor_3.clone();
let tensor_6 = tensor_3 * tensor_5;
let _grads = tensor_6.backward();
}
/// Detaches the intermediate `tensor_4` before multiplying it with
/// `tensor_3`, then checks that `tensor_3` still receives a gradient from the
/// backward pass.
#[test]
fn test_mm_with_non_impacting_detach() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();
let tensor_4 = tensor_1.clone() * tensor_2.clone();
// Only the tensor_4 operand is detached; tensor_3 stays tracked.
let tensor_5 = tensor_4.detach() * tensor_3.clone();
let grads = tensor_5.backward();
assert!(tensor_3.grad(&grads).is_some());
}
/// Only `tensor_1` is marked with `require_grad()`. After an unrelated
/// backward pass has run, `backward()` on `tensor_5` must yield a gradient
/// for `tensor_1` and none for `tensor_2` or `tensor_3`.
#[test]
fn test_mm_with_missing_require_grad_after_cleanup() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
// tensor_2 and tensor_3 deliberately do not call require_grad().
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_none());
assert!(tensor_3.grad(&grads).is_none());
}
/// All three leaves require gradients, but `tensor_3` is detached where it is
/// used. After an unrelated backward pass has run, `backward()` on `tensor_5`
/// must yield gradients for `tensor_1` and `tensor_2` and none for
/// `tensor_3`.
#[test]
fn test_mm_with_detach_after_cleanup() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_1.clone() * tensor_2.clone();
// A detached clone of tensor_3 enters the product.
let tensor_5 = tensor_4 * tensor_3.clone().detach();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_some());
assert!(tensor_3.grad(&grads).is_none());
}
/// Builds the chain `tensor_2 -> tensor_3 (exp) -> _tensor_4 (log)`, runs
/// `backward()` on `tensor_2`, and then expects a second `backward()` on the
/// downstream `tensor_3` to panic.
#[test]
#[should_panic]
fn test_mm_deletables_propagate_well() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_0 * tensor_1;
let tensor_3 = tensor_2.clone().exp();
let _tensor_4 = tensor_3.clone().log();
let _grads = tensor_2.backward();
// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but
// the intermediate tensor_3 as well
// Expected to panic (see #[should_panic]).
let _grads = tensor_3.backward();
}
/// `tensor_2` is reachable from two leaves: `_tensor_4` (one step away) and
/// `tensor_8` (four `exp` steps away). After `backward()` on `tensor_3`, a
/// backward pass from `tensor_8` must still produce a gradient for
/// `tensor_1`. The scenario is repeated 12 times because the traversal order
/// is not deterministic (see the comments below).
#[test]
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative
// By repeating it many times it becomes almost impossible that it passes if it shouldn't
for _ in 0..12 {
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().exp();
let tensor_3 = tensor_0.exp();
// Short path: tensor_2 is one step away from the leaf _tensor_4.
let _tensor_4 = tensor_3.clone() * tensor_2.clone();
// Long path: tensor_2 is four exp() steps away from the leaf tensor_8.
let tensor_5 = tensor_2.exp();
let tensor_6 = tensor_5.exp();
let tensor_7 = tensor_6.exp();
let tensor_8 = tensor_7.exp();
// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8
// which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search
tensor_3.backward();
let grads = tensor_8.backward();
assert!(tensor_1.grad(&grads).is_some());
}
}

View File

@@ -0,0 +1,74 @@
// Test-module index: re-exports everything from the parent module so that the
// test submodules declared below can reach it through `use super::*;`, then
// declares the submodules in alphabetical order.
#[allow(unused_imports)] // required for re-included modules
pub use super::*;
mod abs;
mod adaptive_avgpool1d;
mod adaptive_avgpool2d;
mod add;
mod aggregation;
mod avgpool1d;
mod avgpool2d;
mod backward;
mod bridge;
mod broadcast;
mod cast;
mod cat;
mod ceil;
mod checkpoint;
mod complex;
mod conv1d;
mod conv2d;
mod conv3d;
mod conv_transpose1d;
mod conv_transpose2d;
mod conv_transpose3d;
mod cross;
mod cross_entropy;
mod cummax;
mod cummin;
mod cumprod;
mod cumsum;
mod deform_conv2d;
mod div;
mod erf;
mod exp;
mod expand;
mod flip;
mod floor;
mod gather_scatter;
mod gelu;
mod gradients;
mod log;
mod log1p;
mod log_sigmoid;
mod mask;
mod matmul;
mod maxmin;
mod maxpool1d;
mod maxpool2d;
mod memory_management;
mod mul;
mod multithread;
mod nearest_interpolate;
mod neg;
mod nonzero;
mod permute;
mod pow;
mod recip;
mod relu;
mod remainder;
mod repeat_dim;
mod reshape;
mod round;
mod select;
mod sigmoid;
mod sign;
mod slice;
mod slice_assign;
mod softmax;
mod sort;
mod sqrt;
mod sub;
mod transpose;
mod trig;
mod unfold;

View File

@@ -0,0 +1,68 @@
use super::*;
use burn_tensor::TensorData;
/// Checks the element-wise product of two 1-D tensors and the gradients of
/// both operands.
///
/// For `tensor_3 = tensor_1 * tensor_2`, the gradient of each operand is the
/// other operand's data, so `grad_1` must equal `data_2` and `grad_2` must
/// equal `data_1`.
#[test]
fn should_diff_mul() {
    let data_1 = TensorData::from([1.0, 7.0]);
    let data_2 = TensorData::from([4.0, 7.0]);
    let device = Default::default();
    // The input data is cloned so it can serve as the expected gradient of
    // the opposite operand below.
    let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad();
    let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();

    let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
    let grads = tensor_3.backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    grad_1.to_data().assert_eq(&data_2, false);
    // The symmetric check for the second operand.
    grad_2.to_data().assert_eq(&data_1, false);
    tensor_3
        .into_data()
        .assert_eq(&TensorData::from([4.0, 49.0]), false);
}
/// Multiplies a 1-D tensor by the scalar 4.0 and checks both the forward
/// result and the gradient, which is 4.0 for every element.
#[test]
fn should_diff_mul_scalar() {
    let input = TestAutodiffTensor::<1>::from_data(
        TensorData::from([2.0, 5.0]),
        &Default::default(),
    )
    .require_grad();

    let scaled = input.clone().mul_scalar(4.0);
    let grads = scaled.backward();
    let input_grad = input.grad(&grads).unwrap();

    let expected_output = TensorData::from([8.0, 20.0]);
    let expected_grad = TensorData::from([4.0, 4.0]);
    scaled.into_data().assert_eq(&expected_output, false);
    input_grad.to_data().assert_eq(&expected_grad, false);
}
/// Differentiates the element-wise expression
/// `tensor_1 * ((tensor_1 * tensor_2) * tensor_3)`, where the first operand
/// appears twice, and checks the gradients of the first two operands.
#[test]
fn test_mul_complex_1() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let c = TestAutodiffTensor::from_data(TensorData::from([[2.0, 2.0], [2.0, 2.0]]), &device)
        .require_grad();

    let ab = a.clone().mul(b.clone());
    let abc = ab.mul(c);
    // `a` is used a second time as a factor of the root.
    let root = a.clone().mul(abc);
    let grads = root.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    let expected_a_grad = TensorData::from([[16.0, 196.0], [104.0, -36.0]]);
    let expected_b_grad = TensorData::from([[2.0, 98.0], [338.0, 18.0]]);
    a_grad.to_data().assert_eq(&expected_a_grad, false);
    b_grad.to_data().assert_eq(&expected_b_grad, false);
}

View File

@@ -0,0 +1,88 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
/// Builds the same matmul graph twice: once with two of its branches computed
/// on spawned threads (`with_move`) and once entirely on the calling thread
/// (`without_move`). The gradients of `tensor_1` and `tensor_2` from both
/// variants must agree within the default tolerance.
#[test]
fn should_behave_the_same_with_multithread() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
// Variant A: the two branches below tensor_5 are evaluated on separate threads.
let with_move = || {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
// Task 1
// Clones are made up front so the `move` closure owns its own handles.
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5.clone();
let first_call = move || {
let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned);
tensor_6_1.matmul(tensor_1_cloned)
};
// Task 2
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
// The second task takes tensor_5 itself rather than a clone.
let tensor_5_cloned = tensor_5;
let second_call = move || {
let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned);
tensor_6_2.matmul(tensor_2_cloned)
};
// Both tasks are started before either one is joined.
let tensor_7_1_handle = std::thread::spawn(first_call);
let tensor_7_2_handle = std::thread::spawn(second_call);
let tensor_7_1 = tensor_7_1_handle.join().unwrap();
let tensor_7_2 = tensor_7_2_handle.join().unwrap();
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
(grad_1, grad_2)
};
// Variant B: the same graph, built sequentially on the calling thread.
let without_move = || {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
// Task 1
let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone());
let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone());
// Task 2
let tensor_6_2 = tensor_5.matmul(tensor_1.clone());
let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone());
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
(grad_1, grad_2)
};
// The single-threaded variant runs first and serves as the reference.
let (grad_1, grad_2) = without_move();
let (grad_1_moved, grad_2_moved) = with_move();
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&grad_1_moved.into_data(), Tolerance::default());
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&grad_2_moved.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,97 @@
use super::*;
use burn_tensor::Shape;
use burn_tensor::Tolerance;
use burn_tensor::module::interpolate;
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
/// Upsamples a 2x1x7x5 input to 8x7 and checks the input gradient against a
/// fixed reference that is identical for both batch entries.
#[test]
fn test_upsample_interpolation() {
    let expected_grad = TestTensor::from([
        [[
            [4., 2., 4., 2., 2.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
        ]],
        [[
            [4., 2., 4., 2., 2.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
            [2., 1., 2., 1., 1.],
        ]],
    ]);

    let case = InterpolateTestCase {
        batch_size: 2,
        channels: 1,
        height: 7,
        width: 5,
        height_out: 8,
        width_out: 7,
    };
    case.assert_output(expected_grad);
}
/// Downsamples a 1x1x8x8 input to 4x6 and checks the input gradient against
/// a fixed reference.
#[test]
fn test_downsample_interpolation() {
    let expected_grad = TestTensor::from([[[
        [1., 1., 1., 0., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
    ]]]);

    let case = InterpolateTestCase {
        batch_size: 1,
        channels: 1,
        height: 8,
        width: 8,
        height_out: 4,
        width_out: 6,
    };
    case.assert_output(expected_grad);
}
/// Parameters of one interpolation gradient test: the shape of the input
/// tensor and the spatial size requested from `interpolate`.
struct InterpolateTestCase {
/// First dimension of the input shape.
batch_size: usize,
/// Second dimension of the input shape.
channels: usize,
/// Third dimension of the input shape.
height: usize,
/// Fourth dimension of the input shape.
width: usize,
/// First component of the output size passed to `interpolate`.
height_out: usize,
/// Second component of the output size passed to `interpolate`.
width_out: usize,
}
impl InterpolateTestCase {
    /// Runs nearest-mode `interpolate` on an input filled with
    /// `0..num_elements`, back-propagates from the output, and asserts that
    /// the input gradient matches `x_grad` within `Tolerance::permissive()`.
    fn assert_output(self, x_grad: TestTensor<4>) {
        let device = Default::default();
        let input_shape = Shape::new([self.batch_size, self.channels, self.height, self.width]);

        // The range is created on the expected tensor's device and reshaped
        // to the input shape; the autodiff tensor is then built from its data.
        let element_count = input_shape.num_elements() as i64;
        let input_data = TestTensorInt::arange(0..element_count, &x_grad.device())
            .reshape::<4, _>(input_shape)
            .into_data();
        let x = TestAutodiffTensor::from_data(input_data, &device).require_grad();

        let target_size = [self.height_out, self.width_out];
        let options = InterpolateOptions::new(InterpolateMode::Nearest);
        let output = interpolate(x.clone(), target_size, options);

        let grads = output.backward();
        let x_grad_actual = x.grad(&grads).unwrap();
        x_grad
            .to_data()
            .assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), Tolerance::permissive());
    }
}

View File

@@ -0,0 +1,26 @@
use super::*;
use burn_tensor::TensorData;
/// Differentiates `-(tensor_1 @ (-tensor_2))`. The expected gradients are
/// the same reference values as for the plain matmul of the same inputs.
#[test]
fn should_diff_neg() {
    let device = Default::default();
    let lhs = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [2.0, 3.0]]),
        &device,
    )
    .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let negated_rhs = rhs.clone().neg();
    let product = lhs.clone().matmul(negated_rhs);
    let root = product.neg();
    let grads = root.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let expected_lhs_grad = TensorData::from([[11.0, 5.0], [11.0, 5.0]]);
    let expected_rhs_grad = TensorData::from([[3.0, 3.0], [10.0, 10.0]]);
    lhs_grad.to_data().assert_eq(&expected_lhs_grad, false);
    rhs_grad.to_data().assert_eq(&expected_rhs_grad, false);
}

View File

@@ -0,0 +1,41 @@
use super::*;
use burn_tensor::{Bool, Tensor, TensorData};
/// Selects the elements of a flattened 2x2 tensor at the positions returned
/// by `nonzero()` on a boolean mask, multiplies them with a 1-D tensor, and
/// checks the gradients of both inputs.
#[test]
fn should_diff_nonzero() {
    let device = Default::default();
    let matrix = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
        &device,
    )
    .require_grad();
    let weights =
        TestAutodiffTensor::<1>::from_data(TensorData::from([-1.0, 1.0]), &device).require_grad();

    // Multi-dimensional tensor indexing isn't really supported yet, so the
    // mask and the tensor are both flattened to get proper indexing. The
    // returned tensor would have different dimensions from the input anyway,
    // so this is roughly equivalent.
    let mask_data = TensorData::from([[false, true], [true, false]]);
    let flat_mask =
        Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask_data, &device).flatten::<1>(0, 1);
    let indices = flat_mask.nonzero();
    let flat_matrix = matrix.clone().flatten::<1>(0, 1);
    let selected = flat_matrix.select(0, indices[0].clone());

    // Vector dot products aren't supported (only 2-D matmuls), so both sides
    // are unsqueezed for the purpose of this test.
    let row = weights.clone().unsqueeze_dim::<2>(0);
    let dot = row.matmul(selected.unsqueeze_dim(1));
    let grads = dot.backward();

    let matrix_grad = matrix.grad(&grads).unwrap();
    let weights_grad = weights.grad(&grads).unwrap();

    let expected_matrix_grad = TensorData::from([[0.0, -1.0], [1.0, 0.0]]);
    let expected_weights_grad = TensorData::from([2.0, 3.0]);
    matrix_grad.to_data().assert_eq(&expected_matrix_grad, false);
    weights_grad.to_data().assert_eq(&expected_weights_grad, false);
}

Some files were not shown because too many files have changed in this diff Show More