feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
14
.gitignore
vendored
14
.gitignore
vendored
@@ -1,8 +1,10 @@
|
||||
# Rust build artifacts
|
||||
target/
|
||||
|
||||
# Cargo lock file
|
||||
Cargo.lock
|
||||
|
||||
# Rust IDE files
|
||||
# IDE files
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
@@ -131,3 +133,13 @@ demo/
|
||||
|
||||
# Additional Rust specific
|
||||
*.lock
|
||||
|
||||
# Generated model files
|
||||
*.mpk
|
||||
*.pt
|
||||
*.bin
|
||||
*.safetensors
|
||||
*.ckpt
|
||||
|
||||
# User-specific data
|
||||
user_data/
|
||||
@@ -8,7 +8,7 @@ clap = { version = "4.0", features = ["derive"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
anyhow = "1.0"
|
||||
stablediffusion = { path = "../stable-diffusion-burn" }
|
||||
stablediffusion = { path = "./crates/stable-diffusion-burn" }
|
||||
burn = "0.14.0"
|
||||
burn-autodiff = "0.14.0"
|
||||
burn-ndarray = "0.14.0"
|
||||
|
||||
28
crates/stable-diffusion-burn/Cargo.toml
Normal file
28
crates/stable-diffusion-burn/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "stablediffusion"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
wgpu-backend = ["burn-wgpu"]
|
||||
|
||||
[dependencies.burn-wgpu]
|
||||
package = "burn-wgpu"
|
||||
git = "https://github.com/burn-rs/burn.git"
|
||||
optional = true
|
||||
|
||||
[dependencies]
|
||||
burn = "0.14.0"
|
||||
burn-ndarray = "0.14.0"
|
||||
burn-tch = "0.14.0"
|
||||
burn-autodiff = "0.14.0"
|
||||
tch = "0.15.0"
|
||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||
npy = "0.4.0"
|
||||
num-traits = "0.2.15"
|
||||
rust_tokenizers = "8.1.0"
|
||||
regex = "1.9.1"
|
||||
image = "0.24.6"
|
||||
cfg-if = "0.1"
|
||||
21
crates/stable-diffusion-burn/LICENSE
Normal file
21
crates/stable-diffusion-burn/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Gadersd
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
75
crates/stable-diffusion-burn/README.md
Normal file
75
crates/stable-diffusion-burn/README.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# Stable-Diffusion-Burn
|
||||
|
||||
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
|
||||
|
||||
## How To Use
|
||||
|
||||
### Step 0: Install libtorch v2.4.1
|
||||
|
||||
### Step 1: Download the Model and Set Environment Variables
|
||||
|
||||
Start by downloading the SDv1-4 model provided on HuggingFace.
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
|
||||
```
|
||||
|
||||
### Step 2: Run the Sample Binary
|
||||
|
||||
Invoke the sample binary provided in the Rust code. By default, torch is used. The WGPU backend is unstable for SD but may work well in the future as burn-wgpu is optimized.
|
||||
|
||||
```bash
|
||||
# torch (at least 6 GB VRAM, possibly less)
|
||||
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [cuda, mps, cpu]
|
||||
|
||||
# Cuda
|
||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda
|
||||
|
||||
# Mps(Mac)
|
||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps
|
||||
|
||||
# wgpu (UNSTABLE)
|
||||
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
||||
cargo run --release --features wgpu-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
||||
```
|
||||
|
||||
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
|
||||
|
||||

|
||||
|
||||
### Optional: Extract and Convert a Fine-Tuned Model
|
||||
|
||||
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
|
||||
|
||||
```bash
|
||||
# Step into the Python directory
|
||||
cd python
|
||||
|
||||
# Download the model, this is just the base v1.4 model as an example
|
||||
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
|
||||
# Install tinygrad
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Extract the weights
|
||||
CPU=1 python3 dump.py sd-v1-4.ckpt
|
||||
|
||||
# Move the extracted weight folder out
|
||||
mv params ..
|
||||
|
||||
# Step out of the Python directory
|
||||
cd ..
|
||||
|
||||
# Convert the weights into a usable form
|
||||
cargo run --release --bin convert params SDv1-4
|
||||
```
|
||||
|
||||
The `convert` and `sample` binaries are both written in Rust. `convert` runs on the CPU, whereas `sample` needs CUDA.
|
||||
|
||||
Remember, `convert` should be used if you're planning on using the fine-tuned version of the stable diffusion.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT license.
|
||||
|
||||
We wish you a productive time using this project. Enjoy!
|
||||
262145
crates/stable-diffusion-burn/bpe_simple_vocab_16e6.txt
Normal file
262145
crates/stable-diffusion-burn/bpe_simple_vocab_16e6.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,45 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Automatic differentiation backend for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-autodiff"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-autodiff"
|
||||
documentation = "https://docs.rs/burn-autodiff"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std", "tracing"]
|
||||
std = ["dep:parking_lot"]
|
||||
export_tests = [] # check checkpointer is_empty in tests
|
||||
|
||||
tracing = [
|
||||
"dep:tracing",
|
||||
"burn-std/tracing",
|
||||
"burn-backend/tracing",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
parking_lot = { workspace = true, optional = true }
|
||||
log = { workspace = true }
|
||||
hashbrown = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
portable-atomic = { workspace = true }
|
||||
tracing = { workspace = true, optional = true, features = ["default"] }
|
||||
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["default"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,8 @@
|
||||
# Burn Autodiff
|
||||
|
||||
> [Burn](https://github.com/tracel-ai/burn) autodiff backend
|
||||
|
||||
[](https://crates.io/crates/burn-autodiff)
|
||||
[](https://github.com/tracel-ai/burn-autodiff/blob/master/README.md)
|
||||
|
||||
For now only first order reverse mode autodiff is supported.
|
||||
@@ -0,0 +1,138 @@
|
||||
use crate::{
|
||||
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
|
||||
grads::Gradients,
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use alloc::{format, string::String};
|
||||
use burn_backend::{
|
||||
backend::{AutodiffBackend, Backend, ExecutionError},
|
||||
tensor::{BoolTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Enable auto-differentiation on a backend.
|
||||
///
|
||||
/// This works as a backend decorator, extending the functionality of any backend with
|
||||
/// backpropagation.
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct Autodiff<B, C = NoCheckpointing> {
|
||||
_b: PhantomData<B>,
|
||||
_checkpoint_strategy: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
|
||||
type Device = B::Device;
|
||||
|
||||
type FloatTensorPrimitive = AutodiffTensor<B>;
|
||||
type FloatElem = B::FloatElem;
|
||||
|
||||
type IntTensorPrimitive = B::IntTensorPrimitive;
|
||||
type IntElem = B::IntElem;
|
||||
|
||||
type BoolTensorPrimitive = B::BoolTensorPrimitive;
|
||||
type BoolElem = B::BoolElem;
|
||||
|
||||
type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
format!("autodiff<{}>", B::name(device))
|
||||
}
|
||||
|
||||
fn seed(device: &B::Device, seed: u64) {
|
||||
B::seed(device, seed)
|
||||
}
|
||||
|
||||
fn sync(device: &B::Device) -> Result<(), ExecutionError> {
|
||||
B::sync(device)
|
||||
}
|
||||
|
||||
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
|
||||
device: &Self::Device,
|
||||
input: Input,
|
||||
func: Func,
|
||||
) -> Output {
|
||||
B::memory_persistent_allocations(device, input, func)
|
||||
}
|
||||
|
||||
fn memory_cleanup(device: &Self::Device) {
|
||||
B::memory_cleanup(device)
|
||||
}
|
||||
|
||||
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
|
||||
where
|
||||
Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
|
||||
{
|
||||
B::staging(data, device);
|
||||
}
|
||||
|
||||
fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
|
||||
B::supports_dtype(device, dtype)
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
|
||||
B::dtype_usage(device, dtype)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
|
||||
type InnerBackend = B;
|
||||
type Gradients = Gradients;
|
||||
|
||||
fn backward(tensor: AutodiffTensor<B>) -> Gradients {
|
||||
tensor.backward()
|
||||
}
|
||||
|
||||
fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
|
||||
tensor.grad(grads)
|
||||
}
|
||||
|
||||
fn grad_remove(
|
||||
tensor: &AutodiffTensor<B>,
|
||||
grads: &mut Gradients,
|
||||
) -> Option<B::FloatTensorPrimitive> {
|
||||
tensor.grad_remove(grads)
|
||||
}
|
||||
fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
|
||||
tensor.primitive
|
||||
}
|
||||
|
||||
fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
|
||||
AutodiffTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn grad_replace(
|
||||
tensor: &AutodiffTensor<B>,
|
||||
grads: &mut Self::Gradients,
|
||||
grad: B::FloatTensorPrimitive,
|
||||
) {
|
||||
tensor.grad_replace(grads, grad);
|
||||
}
|
||||
|
||||
fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::{
|
||||
retro_forward::RetroForwards,
|
||||
state::{BackwardStates, State},
|
||||
};
|
||||
use crate::collections::HashMap;
|
||||
use crate::graph::NodeId;
|
||||
|
||||
use alloc::{vec, vec::Vec};
|
||||
|
||||
#[derive(new, Debug)]
|
||||
/// Links a [NodeId] to the [NodeId]s of its parents in the autodiff graph
|
||||
pub(crate) struct NodeTree {
|
||||
map: HashMap<NodeId, Vec<NodeId>>,
|
||||
}
|
||||
|
||||
impl NodeTree {
|
||||
/// Gives the parents of the node in the autodiff graph
|
||||
pub(crate) fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {
|
||||
self.map.get(node_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
/// Struct responsible for fetching the output of a node in the autodiff graph during a backward pass
|
||||
pub struct Checkpointer {
|
||||
backward_states: BackwardStates,
|
||||
retro_forwards: RetroForwards,
|
||||
node_tree: NodeTree,
|
||||
}
|
||||
|
||||
impl Checkpointer {
|
||||
/// Gives the output of the given node, by recursively asking parents to compute themselves
|
||||
/// or give their pre-computed tensors.
|
||||
pub fn retrieve_node_output<T>(&mut self, node_id: NodeId) -> T
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
self.topological_sort(node_id).into_iter().for_each(|node| {
|
||||
self.retro_forwards
|
||||
.execute_retro_forward(node, &mut self.backward_states)
|
||||
});
|
||||
|
||||
self.backward_states.get_state::<T>(&node_id)
|
||||
}
|
||||
|
||||
/// Sorts the ancestors of NodeId in a way such that all parents come before their children
|
||||
/// Useful to avoid recursivity later when mutating the states
|
||||
///
|
||||
/// The sort on a compute bound state or a memory bound that is already computed is trivial.
|
||||
/// The match on State::Computed also serves as a stopping criterion for the sort,
|
||||
/// we don't need to look higher than that during recursivity.
|
||||
fn topological_sort(&self, node_id: NodeId) -> Vec<NodeId> {
|
||||
match self.backward_states.get_state_ref(&node_id) {
|
||||
Some(state) => match state {
|
||||
State::Recompute { n_required: _ } => {
|
||||
let mut sorted = Vec::new();
|
||||
let parents = self.node_tree.parents(&node_id).unwrap();
|
||||
for parent_node in parents {
|
||||
let parent_sorted = self.topological_sort(parent_node);
|
||||
for ps in parent_sorted {
|
||||
if !sorted.contains(&ps) {
|
||||
sorted.push(ps)
|
||||
}
|
||||
}
|
||||
}
|
||||
sorted.push(node_id);
|
||||
sorted
|
||||
}
|
||||
State::Computed {
|
||||
state_content: _,
|
||||
n_required: _,
|
||||
} => vec![node_id],
|
||||
},
|
||||
None => panic!("Node {node_id:?} is not in the backward_states. "),
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if checkpointer has been drained adequately. Useful for testing
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.backward_states.is_empty() && self.retro_forwards.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
use crate::{
|
||||
collections::HashMap,
|
||||
graph::{ComputingProperty, NodeId},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use alloc::{boxed::Box, sync::Arc, vec::Vec};
|
||||
use burn_backend::Backend;
|
||||
use core::any::Any;
|
||||
|
||||
use super::{
|
||||
base::{Checkpointer, NodeTree},
|
||||
retro_forward::{RetroForward, RetroForwards},
|
||||
state::{BackwardStates, State},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Determines if a node should checkpoint its computed output or its retro_forward for recomputation
|
||||
/// The action is normally created by the child of the node, once the node is determined to be needed
|
||||
pub enum CheckpointingAction {
|
||||
/// The node's already computed output should be saved
|
||||
Computed {
|
||||
/// The node
|
||||
node_id: NodeId,
|
||||
/// The node's output
|
||||
state_content: Box<dyn Any + Send>,
|
||||
},
|
||||
/// The node should recompute itself when asked
|
||||
Recompute {
|
||||
/// The node
|
||||
node_id: NodeId,
|
||||
/// How the node should recompute itself
|
||||
retro_forward: Arc<dyn RetroForward>,
|
||||
},
|
||||
}
|
||||
|
||||
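// TODO: Remove this manual `Send` impl once there is a proper client/server architecture.
// NOTE(review): `Arc<dyn RetroForward>` is not `Send` on its own because `RetroForward`
// only requires `Send`, not `Sync`; this impl asserts it by hand — confirm that is sound.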
|
||||
unsafe impl Send for CheckpointingAction {}
|
||||
|
||||
impl CheckpointingAction {
|
||||
/// Utility function to access the id of the node of the checkpointing action
|
||||
pub fn id(&self) -> NodeId {
|
||||
match self {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: node_ref,
|
||||
state_content: _,
|
||||
} => *node_ref,
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: node_ref,
|
||||
retro_forward: _,
|
||||
} => *node_ref,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug, Default)]
|
||||
/// Accumulates checkpoints as checkpointing actions during the forward pass,
|
||||
/// and builds a checkpointer right before the backward pass
|
||||
pub struct CheckpointerBuilder {
|
||||
explicit_actions: Vec<CheckpointingAction>,
|
||||
backup_actions: Vec<CheckpointingAction>,
|
||||
}
|
||||
|
||||
/// Determines if a checkpoint should impact the n_required values (Explicit)
|
||||
/// or if it should just keep the state in case it's required (Backup)
|
||||
///
|
||||
pub(crate) enum ActionType {
|
||||
/// Explicit actions have been explicitly requested by some operation to retrieve their state
|
||||
Explicit,
|
||||
/// Backup actions are not always needed. They exist to save the output of an operation
|
||||
/// whose child is memory bound, in case the state is indirectly needed when computing
|
||||
/// the child's retro_forward. If no explicit action ever asks for the child's output, then
|
||||
/// the backup output will go out of scope when the checkpointer is built.
|
||||
Backup,
|
||||
}
|
||||
|
||||
impl CheckpointerBuilder {
|
||||
pub(crate) fn checkpoint<B: Backend>(
|
||||
&mut self,
|
||||
tensor: &AutodiffTensor<B>,
|
||||
action_type: ActionType,
|
||||
) {
|
||||
let action_list = match action_type {
|
||||
ActionType::Explicit => &mut self.explicit_actions,
|
||||
ActionType::Backup => &mut self.backup_actions,
|
||||
};
|
||||
match &tensor.node.properties {
|
||||
ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
|
||||
action_list.push(CheckpointingAction::Computed {
|
||||
node_id: tensor.node.id,
|
||||
state_content: Box::new(tensor.primitive.clone()),
|
||||
})
|
||||
}
|
||||
ComputingProperty::MemoryBound { retro_forward } => {
|
||||
action_list.push(CheckpointingAction::Recompute {
|
||||
node_id: tensor.node.id,
|
||||
retro_forward: retro_forward.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Moves every action accumulated by `other` into this builder.
///
/// Explicit actions are appended to the explicit list and backup actions to the backup
/// list, both in their original order.
pub(crate) fn extend(&mut self, other: CheckpointerBuilder) {
    // `Vec::extend` can reserve once from the size hint instead of growing per push.
    self.explicit_actions.extend(other.explicit_actions);
    self.backup_actions.extend(other.backup_actions);
}
|
||||
|
||||
pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {
|
||||
let mut backward_states_map = HashMap::new();
|
||||
let mut retro_forwards_map = HashMap::new();
|
||||
|
||||
// Find recursion stopping points
|
||||
let stop_nodes: Vec<NodeId> = self.find_stop_nodes();
|
||||
|
||||
// We start by identifying how many times each node will be required.
|
||||
let n_required_map = self.build_n_required_map(&node_tree, stop_nodes);
|
||||
|
||||
// Then we checkpoint the nodes with the corresponding n_required value
|
||||
self.insert_checkpoints(
|
||||
&mut backward_states_map,
|
||||
&mut retro_forwards_map,
|
||||
n_required_map,
|
||||
);
|
||||
|
||||
Checkpointer::new(
|
||||
BackwardStates::new(backward_states_map),
|
||||
RetroForwards::new(retro_forwards_map),
|
||||
node_tree,
|
||||
)
|
||||
}
|
||||
|
||||
fn find_stop_nodes(&self) -> Vec<NodeId> {
|
||||
let mut stop_nodes = Vec::default();
|
||||
for action in self
|
||||
.explicit_actions
|
||||
.iter()
|
||||
.chain(self.backup_actions.iter())
|
||||
{
|
||||
match action {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: node_ref,
|
||||
state_content: _,
|
||||
} => stop_nodes.push(*node_ref),
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: _,
|
||||
retro_forward: _,
|
||||
} => {}
|
||||
}
|
||||
}
|
||||
stop_nodes
|
||||
}
|
||||
|
||||
fn build_n_required_map(
|
||||
&self,
|
||||
node_tree: &NodeTree,
|
||||
stop_nodes: Vec<NodeId>,
|
||||
) -> HashMap<NodeId, usize> {
|
||||
let mut n_required_map = HashMap::<NodeId, usize>::default();
|
||||
|
||||
for action in self.explicit_actions.iter() {
|
||||
match action {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: node_ref,
|
||||
state_content: _,
|
||||
} => {
|
||||
let id = *node_ref;
|
||||
match n_required_map.remove(&id) {
|
||||
Some(n) => {
|
||||
n_required_map.insert(id, n + 1);
|
||||
}
|
||||
None => {
|
||||
n_required_map.insert(id, 1);
|
||||
}
|
||||
};
|
||||
}
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: node_ref,
|
||||
retro_forward: _,
|
||||
} => {
|
||||
let id = *node_ref;
|
||||
Self::update_n_required_of_parents(
|
||||
id,
|
||||
&mut n_required_map,
|
||||
node_tree,
|
||||
&stop_nodes,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n_required_map
|
||||
}
|
||||
|
||||
fn insert_checkpoints(
|
||||
mut self,
|
||||
backward_states_map: &mut HashMap<NodeId, State>,
|
||||
retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
|
||||
n_required_map: HashMap<NodeId, usize>,
|
||||
) {
|
||||
// We do not loop over checkpointing actions anymore because they can contain
|
||||
// duplicates or miss some that are in backup. We loop over the n_required_map
|
||||
// from which we use the ids to find them again in the checkpointing actions
|
||||
for (node_id, n_required) in n_required_map {
|
||||
// We find the checkpointing action for node_id. It's likely in checkpointing_actions
|
||||
// so we check there first, otherwise it will be in backup.
|
||||
// Technically it can be there several times but can never be of both types, so we can assume the first we find is fine
|
||||
|
||||
let action = match self
|
||||
.explicit_actions
|
||||
.iter()
|
||||
.position(|action| action.id() == node_id)
|
||||
{
|
||||
Some(pos) => self.explicit_actions.remove(pos),
|
||||
None => {
|
||||
let pos = self
|
||||
.backup_actions
|
||||
.iter()
|
||||
.position(|action| action.id() == node_id);
|
||||
self.backup_actions.remove(pos.unwrap_or_else(|| {
|
||||
panic!("Node {:?} is needed but never checkpointed", &node_id)
|
||||
}))
|
||||
}
|
||||
};
|
||||
|
||||
match action {
|
||||
CheckpointingAction::Computed {
|
||||
node_id: _,
|
||||
state_content,
|
||||
} => {
|
||||
self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
|
||||
}
|
||||
CheckpointingAction::Recompute {
|
||||
node_id: _,
|
||||
retro_forward,
|
||||
} => self.checkpoint_lazy(
|
||||
backward_states_map,
|
||||
retro_forward_map,
|
||||
node_id,
|
||||
retro_forward,
|
||||
n_required,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn update_n_required_of_parents(
|
||||
id: NodeId,
|
||||
n_required_map: &mut HashMap<NodeId, usize>,
|
||||
node_tree: &NodeTree,
|
||||
stop_nodes: &Vec<NodeId>,
|
||||
) {
|
||||
match n_required_map.remove(&id) {
|
||||
Some(n) => {
|
||||
n_required_map.insert(id, n + 1);
|
||||
}
|
||||
None => {
|
||||
n_required_map.insert(id, 1);
|
||||
if !stop_nodes.contains(&id)
|
||||
&& let Some(parents) = node_tree.parents(&id)
|
||||
{
|
||||
for p in parents {
|
||||
Self::update_n_required_of_parents(
|
||||
p,
|
||||
n_required_map,
|
||||
node_tree,
|
||||
stop_nodes,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn checkpoint_compute(
|
||||
&self,
|
||||
backward_states_map: &mut HashMap<NodeId, State>,
|
||||
node_id: NodeId,
|
||||
state_content: Box<dyn Any + Send>,
|
||||
n_required: usize,
|
||||
) {
|
||||
backward_states_map.insert(
|
||||
node_id,
|
||||
State::Computed {
|
||||
state_content,
|
||||
n_required,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn checkpoint_lazy(
|
||||
&self,
|
||||
backward_states_map: &mut HashMap<NodeId, State>,
|
||||
retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,
|
||||
node_id: NodeId,
|
||||
retro_forward: Arc<dyn RetroForward>,
|
||||
n_required: usize,
|
||||
) {
|
||||
retro_forward_map.insert(node_id, retro_forward);
|
||||
backward_states_map.insert(node_id, State::Recompute { n_required });
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
/// Checkpointer module
|
||||
pub mod base;
|
||||
pub(crate) mod builder;
|
||||
/// RetroForward module
|
||||
pub mod retro_forward;
|
||||
/// BackwardStates module
|
||||
pub mod state;
|
||||
/// CheckpointStrategy module
|
||||
pub mod strategy;
|
||||
@@ -0,0 +1,116 @@
|
||||
use crate::collections::HashMap;
|
||||
use crate::graph::NodeId;
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use core::fmt::Debug;
|
||||
|
||||
use super::state::{BackwardStates, State};
|
||||
|
||||
/// Definition of the forward function of a node, called during retropropagation only.
|
||||
/// This is different from the normal forward function because it reads and writes from
|
||||
/// the [BackwardStates] map instead of having a clear function signature.
|
||||
pub trait RetroForward: Debug + Send + 'static {
|
||||
/// Applies the forward pass for retropropagation.
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId);
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
/// Links [NodeId]s to their corresponding [RetroForward]
|
||||
pub(crate) struct RetroForwards {
|
||||
map: HashMap<NodeId, Arc<dyn RetroForward>>,
|
||||
}
|
||||
|
||||
impl RetroForwards {
|
||||
/// Executes the [RetroForward] for a given [NodeId] if the node's
|
||||
/// [State] is [State::Recompute], otherwise does nothing.
|
||||
pub(crate) fn execute_retro_forward(
|
||||
&mut self,
|
||||
node_id: NodeId,
|
||||
backward_states: &mut BackwardStates,
|
||||
) {
|
||||
if let State::Recompute { n_required: _ } = backward_states
|
||||
.get_state_ref(&node_id)
|
||||
.unwrap_or_else(|| panic!("Should find node {node_id:?}"))
|
||||
{
|
||||
// Retro forwards are always used only once because afterwards their state is computed
|
||||
let retro_forward = self.map.remove(&node_id).unwrap();
|
||||
retro_forward.forward(backward_states, node_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Creates a RetroForward struct for unary scalar operations
|
||||
macro_rules! retro_unary_scalar {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct $name<B: Backend> {
|
||||
lhs_id: NodeId,
|
||||
rhs: Scalar,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RetroForward for $name<B> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
|
||||
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
|
||||
let out = $ops(lhs, self.rhs);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Creates a RetroForward struct for unary operations
|
||||
macro_rules! retro_unary {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct $name<B: Backend> {
|
||||
input_id: NodeId,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RetroForward for $name<B> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
|
||||
let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);
|
||||
let out = $ops(input);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Creates a RetroForward struct for binary operations
|
||||
macro_rules! retro_binary {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug, Clone)]
|
||||
struct $name<B: Backend> {
|
||||
lhs_id: NodeId,
|
||||
rhs_id: NodeId,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> RetroForward for $name<B> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {
|
||||
let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);
|
||||
let rhs = states.get_state::<B::FloatTensorPrimitive>(&self.rhs_id);
|
||||
let out = $ops(lhs, rhs);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
use core::any::Any;
|
||||
|
||||
use crate::collections::HashMap;
|
||||
use crate::graph::NodeId;
|
||||
use alloc::boxed::Box;
|
||||
|
||||
/// Node outputs of arbitrary types are stored in the same hashmap, so they are upcast to
/// `Any`.
pub(crate) type StateContent = Box<dyn Any + Send>;

#[derive(Debug)]
/// The state held for one node: either the node output itself, or a marker saying the
/// output has to be recomputed from the node's parents.
/// Both variants carry the number of times the state is still required, so that it can be
/// removed from the map of states on its last use.
pub(crate) enum State {
    /// The state was not checkpointed, will need to recompute it from the node's parents
    Recompute { n_required: usize },
    /// The state was checkpointed or computed during retropropagation and can be directly accessed
    Computed {
        state_content: StateContent,
        n_required: usize,
    },
}

impl State {
    /// Borrows the (not yet downcasted) node output of a `Computed` state.
    ///
    /// Panics on a `Recompute` state.
    pub(crate) fn to_state_content(&self) -> &StateContent {
        if let State::Computed { state_content, .. } = self {
            return state_content;
        }
        unreachable!(
            "Can't get state content of recompute state. A child has likely been accessed before its parents."
        )
    }

    /// Consumes a `Computed` state and returns its (not yet downcasted) node output.
    ///
    /// Panics on a `Recompute` state.
    pub(crate) fn into_state_content(self) -> StateContent {
        if let State::Computed { state_content, .. } = self {
            return state_content;
        }
        unreachable!(
            "Can't get state content of recompute state. A child has likely been accessed before its parents."
        )
    }

    /// Returns the number of times the state is required, whatever the variant.
    pub(crate) fn n_required(&self) -> usize {
        match self {
            State::Recompute { n_required } | State::Computed { n_required, .. } => *n_required,
        }
    }
}
|
||||
|
||||
#[derive(new, Default, Debug)]
|
||||
/// Links [NodeId]s to their current state
|
||||
pub struct BackwardStates {
|
||||
map: HashMap<NodeId, State>,
|
||||
}
|
||||
|
||||
impl BackwardStates {
|
||||
/// Returns the output in the state of the given [NodeId],
|
||||
/// and decrements the number of times this state is required.
|
||||
/// This function always gives ownership of the output, but will clone it if needed for further uses.
|
||||
pub fn get_state<T>(&mut self, node_id: &NodeId) -> T
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
// Fetch the state and decrement its number of required
|
||||
let state = self.map.remove(node_id).unwrap();
|
||||
let remaining_n_required = state.n_required() - 1;
|
||||
|
||||
// Downcast the state to whatever it is supposed to be
|
||||
// If still needed after giving ownership, we copy it back to the hashmap
|
||||
if remaining_n_required > 0 {
|
||||
let new_stored_state = match state {
|
||||
State::Recompute { n_required: _ } => unreachable!(),
|
||||
State::Computed {
|
||||
state_content,
|
||||
n_required: _,
|
||||
} => State::Computed {
|
||||
state_content,
|
||||
n_required: remaining_n_required,
|
||||
},
|
||||
};
|
||||
|
||||
let downcasted = new_stored_state
|
||||
.to_state_content()
|
||||
.downcast_ref::<T>()
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
self.insert_state(*node_id, new_stored_state);
|
||||
|
||||
downcasted
|
||||
} else {
|
||||
let downcasted = state.into_state_content().downcast::<T>().unwrap();
|
||||
*downcasted
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the [State] of the given node
|
||||
/// Useful when we need [State] information without needing the underlying tensor
|
||||
pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> {
|
||||
self.map.get(node_id)
|
||||
}
|
||||
|
||||
/// Associates a [State] to its [NodeId]
|
||||
pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) {
|
||||
self.map.insert(node_id, state);
|
||||
}
|
||||
|
||||
/// Saves the output to the state of the given [NodeId].
|
||||
pub fn save<T>(&mut self, node_id: NodeId, saved_output: T)
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
let n_required = self.get_state_ref(&node_id).unwrap().n_required();
|
||||
self.insert_state(
|
||||
node_id,
|
||||
State::Computed {
|
||||
state_content: Box::new(saved_output),
|
||||
n_required,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
use core::fmt::Debug;
|
||||
|
||||
use burn_backend::Backend;
|
||||
|
||||
use crate::{graph::ComputingProperty, tensor::AutodiffTensor};
|
||||
use alloc::sync::Arc;
|
||||
|
||||
use super::{
|
||||
builder::{ActionType, CheckpointerBuilder},
|
||||
retro_forward::RetroForward,
|
||||
};
|
||||
|
||||
/// Strategy for the amount of checkpointing to do during autodiff
|
||||
pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'static {
|
||||
/// May modify the compute property depending on the strategy
|
||||
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty;
|
||||
|
||||
/// Checkpoints parents if necessary in the strategy
|
||||
fn checkpoint_parents<'a, B2, A>(
|
||||
parents: A,
|
||||
builder: &mut CheckpointerBuilder,
|
||||
) -> Result<(), CheckpointingError>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>;
|
||||
}
|
||||
|
||||
/// Error that can happen when trying to checkpoint a tensor.
#[derive(Debug)]
pub enum CheckpointingError {
    /// A parent is untracked: its state can't easily be checkpointed, since its
    /// requirements are not known in advance.
    UntrackedParent,
}
|
||||
|
||||
/// Strategy where every operation is considered compute bound, regardless of how it is
/// marked.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoCheckpointing {}
|
||||
|
||||
impl CheckpointStrategy for NoCheckpointing {
|
||||
/// An operation marked as memory bound is actually compute bound.
|
||||
fn compute_property<R: RetroForward>(_retro_forward: R) -> ComputingProperty {
|
||||
ComputingProperty::ComputeBound
|
||||
}
|
||||
|
||||
/// An operation marked as memory bound is actually compute bound.
|
||||
/// It's therefore useless to checkpoint the parents
|
||||
fn checkpoint_parents<'a, B2, A>(
|
||||
_parents: A,
|
||||
_builder: &mut CheckpointerBuilder,
|
||||
) -> Result<(), CheckpointingError>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
|
||||
{
|
||||
// Nothing to do here
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy where operations keep the property they are marked with (compute bound or
/// memory bound).
#[derive(Clone, Copy, Debug, Default)]
pub struct BalancedCheckpointing {}
|
||||
|
||||
impl CheckpointStrategy for BalancedCheckpointing {
|
||||
/// An operation marked as memory bound is memory bound.
|
||||
/// When memory bound, an operation needs to save its RetroForward
|
||||
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty {
|
||||
ComputingProperty::MemoryBound {
|
||||
retro_forward: Arc::new(retro_forward),
|
||||
}
|
||||
}
|
||||
|
||||
/// An operation marked as memory bound is really memory bound.
|
||||
/// Since the operation may not checkpoint its parents but may need them indirectly
|
||||
/// if asked to recompute itself, the method needs to know the parent tensors to maybe checkpoint them
|
||||
fn checkpoint_parents<'a, B2, A>(
|
||||
parents: A,
|
||||
builder: &mut CheckpointerBuilder,
|
||||
) -> Result<(), CheckpointingError>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
|
||||
{
|
||||
let mut can_checkpoint = true;
|
||||
|
||||
for tensor in parents.into_iter() {
|
||||
if let crate::graph::Requirement::None = tensor.node.requirement {
|
||||
can_checkpoint = false;
|
||||
} else {
|
||||
builder.checkpoint(tensor, ActionType::Backup);
|
||||
}
|
||||
}
|
||||
|
||||
if !can_checkpoint {
|
||||
*builder = CheckpointerBuilder::default();
|
||||
return Err(CheckpointingError::UntrackedParent);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
use burn_backend::{
|
||||
Backend, TensorMetadata, TensorPrimitive,
|
||||
tensor::{FloatTensor, TensorContainer},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
NodeId,
|
||||
graph::{NodeRef, Requirement},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
|
||||
/// Identifier under which a gradient is stored in the gradients container.
pub type GradID = u64;
|
||||
|
||||
/// Gradients container used during the backward pass.
|
||||
pub struct Gradients {
|
||||
container: TensorContainer<GradID>,
|
||||
}
|
||||
|
||||
impl Gradients {
|
||||
/// Creates a new gradients container.
|
||||
pub fn new<B: Backend>(root_node: NodeRef, root_tensor: FloatTensor<B>) -> Self {
|
||||
let mut gradients = Self {
|
||||
container: TensorContainer::new(),
|
||||
};
|
||||
gradients.register::<B>(
|
||||
root_node.id,
|
||||
B::float_ones(
|
||||
root_tensor.shape(),
|
||||
&B::float_device(&root_tensor),
|
||||
root_tensor.dtype().into(),
|
||||
),
|
||||
);
|
||||
gradients
|
||||
}
|
||||
|
||||
/// Consumes the gradients for a given tensor.
|
||||
///
|
||||
/// Each tensor should be consumed exactly 1 time if its gradients are only required during the
|
||||
/// backward pass, otherwise, it may be consume multiple times.
|
||||
pub fn consume<B: Backend>(&mut self, node: &NodeRef) -> FloatTensor<B> {
|
||||
match node.requirement {
|
||||
Requirement::Grad => self
|
||||
.container
|
||||
.get::<B>(&node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
.expect("Can't consume the gradients before they are registered at least once."),
|
||||
Requirement::GradInBackward => self
|
||||
.container
|
||||
.remove::<B>(&node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
.expect("Can't consume the gradients before they are registered at least once."),
|
||||
Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a grad tensor from the container.
|
||||
pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
|
||||
self.container
|
||||
.remove::<B>(&tensor.node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
}
|
||||
|
||||
/// Gets a grad tensor from the container.
|
||||
pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
|
||||
self.container
|
||||
.get::<B>(&tensor.node.id.value)
|
||||
.map(|tensor| tensor.tensor())
|
||||
}
|
||||
|
||||
/// Register a grad tensor in the container.
|
||||
///
|
||||
/// If the tensor already exists, add both tensors together before saving the result.
|
||||
pub fn register<B: Backend>(&mut self, node_id: NodeId, value: FloatTensor<B>) {
|
||||
if let Some(tensor_old) = self.container.remove::<B>(&node_id.value) {
|
||||
self.container.register::<B>(
|
||||
node_id.value,
|
||||
TensorPrimitive::Float(B::float_add(value, tensor_old.tensor())),
|
||||
);
|
||||
} else {
|
||||
self.container
|
||||
.register::<B>(node_id.value, TensorPrimitive::Float(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
use super::NodeId;
|
||||
use crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent};
|
||||
use alloc::boxed::Box;
|
||||
|
||||
/// Backward step for reverse mode autodiff.
|
||||
pub trait Step: Send + core::fmt::Debug {
|
||||
/// Executes the step and consumes it.
|
||||
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
|
||||
/// Depth of the operation relative to the first node added to a graph.
|
||||
fn depth(&self) -> usize;
|
||||
/// The node associated to the step.
|
||||
fn node(&self) -> NodeId;
|
||||
/// The parents of the node associated to the step.
|
||||
fn parents(&self) -> &[Parent];
|
||||
}
|
||||
|
||||
/// A boxed, type-erased backward [Step].
pub type StepBoxed = Box<dyn Step>;
|
||||
@@ -0,0 +1,9 @@
|
||||
mod base;
|
||||
mod node;
|
||||
mod requirement;
|
||||
|
||||
pub mod traversal;
|
||||
|
||||
pub use base::*;
|
||||
pub use node::*;
|
||||
pub use requirement::*;
|
||||
@@ -0,0 +1,87 @@
|
||||
use alloc::{sync::Arc, vec::Vec};
|
||||
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
use core::sync::atomic::{AtomicU64, Ordering};
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
use portable_atomic::{AtomicU64, Ordering};
|
||||
|
||||
use crate::checkpoint::retro_forward::RetroForward;
|
||||
use crate::runtime::AutodiffClientImpl;
|
||||
|
||||
use super::Requirement;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ComputingProperty {
|
||||
ComputeBound,
|
||||
MemoryBound {
|
||||
retro_forward: Arc<dyn RetroForward>,
|
||||
},
|
||||
Ambiguous, // Maybe autotune someday
|
||||
}
|
||||
|
||||
/// This is safe only because we only call RetroForward on the autodiff server.
|
||||
/// Therefore, the trait will never be used by multiple threads at the same time.
|
||||
///
|
||||
/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the
|
||||
/// Arc, which will make (dyn RetroForward) safely implement Send.
|
||||
unsafe impl Send for ComputingProperty {}
|
||||
/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.
|
||||
unsafe impl Sync for ComputingProperty {}
|
||||
|
||||
/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
|
||||
#[derive(new, Debug)]
|
||||
pub struct Node {
|
||||
pub parents: Vec<Parent>,
|
||||
pub order: usize,
|
||||
pub id: NodeId,
|
||||
pub requirement: Requirement,
|
||||
pub properties: ComputingProperty,
|
||||
pub client: AutodiffClientImpl,
|
||||
}
|
||||
/// A [Node] behind an [Arc], cheap to clone.
pub type NodeRef = Arc<Node>;
|
||||
|
||||
#[derive(new, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Parent {
|
||||
pub id: NodeId,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Returns the [node](Node) only if gradients are required.
|
||||
pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
|
||||
match self.requirement.is_none() {
|
||||
true => None,
|
||||
false => Some(self.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier generated for each node.
#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
pub struct NodeId {
    /// The id as an integer.
    pub value: u64,
}
|
||||
|
||||
impl core::fmt::Display for NodeId {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_fmt(format_args!("NodeId({})", self.value))
|
||||
}
|
||||
}
|
||||
|
||||
impl NodeId {
|
||||
/// Create a unique [node id](NodeId).
|
||||
pub fn new() -> Self {
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
let value = COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
if value == u64::MAX {
|
||||
panic!("NodeId overflowed");
|
||||
}
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use super::NodeRef;
|
||||
|
||||
/// Gradient requirement of each tensor in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Requirement {
    /// Gradients are required.
    Grad,
    /// Gradients are required only for backprop.
    GradInBackward,
    /// Gradients are not needed, so the operation is not to be included in the graph.
    None,
}
|
||||
|
||||
impl Requirement {
|
||||
/// Returns true if gradients are not required.
|
||||
pub fn is_none(&self) -> bool {
|
||||
matches!(self, Self::None)
|
||||
}
|
||||
/// Returns the right requirement from a list of nodes.
|
||||
pub fn from_nodes(nodes: &[NodeRef]) -> Self {
|
||||
if nodes.len() == 1 {
|
||||
return nodes[0].requirement.infer(&Requirement::None);
|
||||
}
|
||||
|
||||
nodes
|
||||
.iter()
|
||||
.map(|node| node.requirement)
|
||||
.reduce(|acc, requirement| requirement.infer(&acc))
|
||||
.unwrap_or(Requirement::None)
|
||||
}
|
||||
|
||||
fn infer(&self, other: &Self) -> Self {
|
||||
match self.is_none() && other.is_none() {
|
||||
true => Self::None,
|
||||
false => Self::GradInBackward,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
use super::{Step, StepBoxed};
|
||||
use crate::{
|
||||
NodeId,
|
||||
collections::{HashMap, HashSet},
|
||||
graph::Parent,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Breadth-first search algorithm.
pub struct BreadthFirstSearch;
|
||||
|
||||
pub trait TraversalItem {
|
||||
fn id(&self) -> NodeId;
|
||||
fn parents(&self) -> &[Parent];
|
||||
fn parent_nodes(&self) -> Vec<NodeId> {
|
||||
self.parents().iter().map(|p| p.id).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl BreadthFirstSearch {
|
||||
/// Traverse the graph of backward steps from a root node.
|
||||
pub fn traverse<F, I>(
|
||||
&self,
|
||||
root_id: NodeId,
|
||||
root_step: I,
|
||||
steps: &mut HashMap<NodeId, I>,
|
||||
mut callback: F,
|
||||
) where
|
||||
F: FnMut(NodeId, I),
|
||||
I: TraversalItem,
|
||||
{
|
||||
let mut visited = HashSet::new();
|
||||
let mut parents = Vec::new();
|
||||
|
||||
visited.insert(root_id);
|
||||
parents.append(&mut root_step.parent_nodes());
|
||||
|
||||
callback(root_id, root_step);
|
||||
|
||||
while let Some(id) = parents.pop() {
|
||||
let step = match steps.remove(&id) {
|
||||
Some(step) => step,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let step_node = step.id();
|
||||
let step_parents = step.parent_nodes();
|
||||
|
||||
if visited.contains(&step_node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
visited.insert(step_node);
|
||||
|
||||
for id in step_parents.iter() {
|
||||
if !visited.contains(id) {
|
||||
parents.push(*id);
|
||||
}
|
||||
}
|
||||
|
||||
callback(step_node, step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TraversalItem for StepBoxed {
|
||||
fn id(&self) -> NodeId {
|
||||
Step::node(self.as_ref())
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
Step::parents(self.as_ref())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! # Burn Autodiff
|
||||
//!
|
||||
//! This autodiff library is a part of the Burn project. It is a standalone crate
|
||||
//! that can be used to perform automatic differentiation on tensors. It is
|
||||
//! designed to be used with the Burn Tensor crate, but it can be used with any
|
||||
//! tensor library that implements the `Backend` trait.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Checkpoint module.
|
||||
pub mod checkpoint;
|
||||
/// Gradients module.
|
||||
pub mod grads;
|
||||
/// Operation module.
|
||||
pub mod ops;
|
||||
|
||||
pub(crate) mod graph;
|
||||
// Exported for backend extension
|
||||
pub use graph::NodeId;
|
||||
pub(crate) mod tensor;
|
||||
pub(crate) mod utils;
|
||||
|
||||
mod backend;
|
||||
|
||||
pub(crate) mod runtime;
|
||||
|
||||
pub use backend::*;
|
||||
|
||||
/// A facade around for HashMap and HashSet.
|
||||
/// This avoids elaborate import wrangling having to happen in every module.
|
||||
mod collections {
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub use hashbrown::{HashMap, HashSet};
|
||||
#[cfg(feature = "std")]
|
||||
pub use std::collections::{HashMap, HashSet};
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
Autodiff,
|
||||
checkpoint::{
|
||||
base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,
|
||||
strategy::CheckpointStrategy,
|
||||
},
|
||||
grads::Gradients,
|
||||
graph::NodeId,
|
||||
ops::{Backward, Ops, OpsKind, unary},
|
||||
retro_unary,
|
||||
};
|
||||
use burn_backend::{Backend, ops::ActivationOps, tensor::FloatTensor};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodiff<B, C> {
|
||||
fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct Gelu;
|
||||
|
||||
retro_unary!(RetroGelu, B::gelu);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for Gelu {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let input = checkpointer.retrieve_node_output(ops.state);
|
||||
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::gelu_backward(input, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match Gelu
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroGelu::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::gelu(tensor.primitive.clone()))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
|
||||
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct Relu;
|
||||
|
||||
retro_unary!(RetroRelu, B::relu);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for Relu {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let state = checkpointer.retrieve_node_output(ops.state);
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::relu_backward(state, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match Relu
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroRelu::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::relu(tensor.primitive))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::relu(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct Sigmoid;
|
||||
|
||||
retro_unary!(RetroSigmoid, B::sigmoid);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for Sigmoid {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let input = checkpointer.retrieve_node_output(ops.state);
|
||||
let output = B::sigmoid(input);
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::sigmoid_backward(output, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match Sigmoid
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSigmoid::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::sigmoid(tensor.primitive))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(Debug)]
|
||||
struct LogSigmoid;
|
||||
|
||||
retro_unary!(RetroLogSigmoid, B::log_sigmoid);
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for LogSigmoid {
|
||||
type State = NodeId;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let input = checkpointer.retrieve_node_output(ops.state);
|
||||
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::log_sigmoid_backward(input, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match LogSigmoid
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroLogSigmoid::<B>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let state = prep.checkpoint(&tensor);
|
||||
prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use super::{Ops, OpsPrep};
|
||||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, NodeRef, Requirement},
|
||||
utils::duplicate,
|
||||
};
|
||||
use burn_backend::Backend;
|
||||
|
||||
/// Trait for all operations.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Concrete types implementing this trait should not have any state.
|
||||
/// If a state is necessary during the backward pass,
|
||||
/// they should be declared with the associated type 'State'.
|
||||
pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
/// Associated type to compute the backward pass.
|
||||
type State: Clone + Send + core::fmt::Debug + 'static;
|
||||
|
||||
/// The backward pass.
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, N>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
);
|
||||
|
||||
/// Prepare the backward ops.
|
||||
fn prepare<C: CheckpointStrategy>(
|
||||
self,
|
||||
nodes: [NodeRef; N],
|
||||
) -> OpsPrep<Self, B, Self::State, C, N> {
|
||||
let requirement = Requirement::from_nodes(&nodes);
|
||||
OpsPrep::new(
|
||||
nodes,
|
||||
requirement,
|
||||
self,
|
||||
ComputingProperty::Ambiguous, // If not specified we start with ambiguous
|
||||
CheckpointerBuilder::default(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a binary operation during the backward step.
|
||||
pub fn binary<B, FLhs, FRhs>(
|
||||
parents: [Option<NodeRef>; 2],
|
||||
node: NodeRef,
|
||||
grads: &mut Gradients,
|
||||
func_lhs: FLhs,
|
||||
func_rhs: FRhs,
|
||||
) where
|
||||
B: Backend,
|
||||
FLhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
|
||||
FRhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
|
||||
{
|
||||
let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::<B>(&node)));
|
||||
let [node_lhs, node_rhs] = parents;
|
||||
|
||||
if let Some(node) = node_lhs {
|
||||
let grad = func_lhs(grad_4lhs.unwrap());
|
||||
grads.register::<B>(node.id, grad)
|
||||
}
|
||||
|
||||
if let Some(node) = node_rhs {
|
||||
let grad = func_rhs(grad_4rhs.unwrap());
|
||||
grads.register::<B>(node.id, grad)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a unary operation during the backward step.
|
||||
pub fn unary<B, F>(parents: [Option<NodeRef>; 1], node: NodeRef, grads: &mut Gradients, func: F)
|
||||
where
|
||||
B: Backend,
|
||||
F: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,
|
||||
{
|
||||
let [parent_node] = parents;
|
||||
let grad = grads.consume::<B>(&node);
|
||||
|
||||
if let Some(node) = parent_node {
|
||||
let grad = func(grad);
|
||||
grads.register::<B>(node.id, grad)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
use super::Backward;
|
||||
use crate::{
|
||||
checkpoint::{
|
||||
base::Checkpointer,
|
||||
builder::{ActionType, CheckpointerBuilder},
|
||||
retro_forward::RetroForward,
|
||||
strategy::CheckpointStrategy,
|
||||
},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use alloc::boxed::Box;
|
||||
use burn_backend::{Backend, TensorMetadata, tensor::FloatTensor};
|
||||
use burn_std::Shape;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Operation in preparation.
|
||||
///
|
||||
/// Each mode has its own set of functions to minimize cloning for unused backward states.
|
||||
#[derive(new)]
|
||||
pub struct OpsPrep<Backward, B, S, C, const N: usize, Mode = Init> {
|
||||
nodes: [NodeRef; N],
|
||||
requirement: Requirement,
|
||||
backward: Backward,
|
||||
compute_property: ComputingProperty,
|
||||
checkpointer_builder: CheckpointerBuilder,
|
||||
checkpoint_strategy: PhantomData<C>,
|
||||
phantom_backend: PhantomData<B>,
|
||||
phantom_state: PhantomData<S>,
|
||||
marker: PhantomData<Mode>,
|
||||
}
|
||||
|
||||
/// Mode: the operation has just been initialized.
pub struct Init;

/// Mode: the operation has been tagged as memory bound.
pub struct MemoryBound;

/// Mode: the memory bound operation has received its RetroForward.
pub struct MemoryBoundRetroForward;

/// Mode: the compute property of the operation is fixed.
pub struct ComputePropertyDone;

/// Mode: the operation is tracked.
pub struct Tracked;

/// Mode: the operation is untracked.
pub struct UnTracked;
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Init>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Indicates that the operation is compute bound, meaning its computation
|
||||
/// is heavy and should not be recomputed
|
||||
pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
ComputingProperty::ComputeBound,
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
|
||||
/// Indicates that the operation is memory bound, meaning its computation
|
||||
/// is light and can be recomputed
|
||||
pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, N, MemoryBound> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBound>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = S>,
|
||||
C: CheckpointStrategy,
|
||||
{
|
||||
/// Registers the retro forward, if needed
|
||||
pub fn retro_forward<R: RetroForward>(
|
||||
self,
|
||||
retro_forward: R,
|
||||
) -> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
C::compute_property(retro_forward),
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = S>,
|
||||
C: CheckpointStrategy,
|
||||
{
|
||||
/// Checkpoints the parents, if needed
|
||||
pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
|
||||
where
|
||||
B2: Backend,
|
||||
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
|
||||
{
|
||||
let compute_property = match C::checkpoint_parents(parents, &mut self.checkpointer_builder)
|
||||
{
|
||||
Ok(..) => self.compute_property,
|
||||
Err(..) => ComputingProperty::ComputeBound,
|
||||
};
|
||||
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
compute_property,
|
||||
self.checkpointer_builder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, C, const N: usize> OpsPrep<BO, B, (), C, N, ComputePropertyDone>
|
||||
where
|
||||
B: Backend,
|
||||
BO: Backward<B, N, State = ()>,
|
||||
{
|
||||
/// Prepare a stateless operation.
|
||||
pub fn stateless(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
|
||||
match self.stateful() {
|
||||
OpsKind::Tracked(prep) => prep.finish((), output),
|
||||
OpsKind::UnTracked(prep) => prep.finish(output),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + core::fmt::Debug + 'static,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Prepare an operation that requires a state during the backward pass.
|
||||
pub fn stateful(self) -> OpsKind<BO, B, S, C, N> {
|
||||
match self.requirement.is_none() {
|
||||
false => OpsKind::Tracked(OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
)),
|
||||
true => OpsKind::UnTracked(OpsPrep::new(
|
||||
self.nodes,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + core::fmt::Debug + 'static,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Finish the preparation of an untracked operation and returns the output tensor.
|
||||
pub fn finish(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
|
||||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&self.nodes,
|
||||
self.requirement,
|
||||
self.compute_property,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, output.node.clone(), ());
|
||||
|
||||
// We register the ops in the graph even if untracked, otherwise memory bound operations
|
||||
// that have an untracked parent would not be able to retrieve it
|
||||
output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + core::fmt::Debug + 'static,
|
||||
BO: Backward<B, N, State = S>,
|
||||
{
|
||||
/// Finish the preparation of a tracked operation and returns the output tensor.
|
||||
pub fn finish(self, state: S, output: FloatTensor<B>) -> AutodiffTensor<B> {
|
||||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&self.nodes,
|
||||
self.requirement,
|
||||
self.compute_property,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, output.node.clone(), state);
|
||||
|
||||
output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)
|
||||
}
|
||||
|
||||
/// Checkpoints the tensor
|
||||
pub fn checkpoint(&mut self, tensor: &AutodiffTensor<B>) -> NodeId {
|
||||
self.checkpointer_builder
|
||||
.checkpoint(tensor, ActionType::Explicit);
|
||||
|
||||
tensor.node.id
|
||||
}
|
||||
}
|
||||
|
||||
/// Enum used before finishing tracked and untracked operations.
|
||||
pub enum OpsKind<BO, B, S, C, const N: usize> {
|
||||
/// Tracked operation preparation.
|
||||
Tracked(OpsPrep<BO, B, S, C, N, Tracked>),
|
||||
/// Untracked operation preparation.
|
||||
UnTracked(OpsPrep<BO, B, S, C, N, UnTracked>),
|
||||
}
|
||||
|
||||
/// Operation containing its parent nodes, its own node and the backward step state.
|
||||
#[derive(new, Debug)]
|
||||
pub struct Ops<S, const N: usize> {
|
||||
/// Parents nodes.
|
||||
pub parents: [Option<NodeRef>; N],
|
||||
/// The node.
|
||||
pub node: NodeRef,
|
||||
/// The state.
|
||||
pub state: S,
|
||||
}
|
||||
|
||||
/// Operation implementing backward [step](Step) with type erasing.
|
||||
#[derive(new, Debug)]
|
||||
struct OpsStep<B, T, SB, const N: usize>
|
||||
where
|
||||
B: Backend,
|
||||
T: Backward<B, N, State = SB>,
|
||||
SB: Clone + Send + core::fmt::Debug + 'static,
|
||||
{
|
||||
ops: Ops<SB, N>,
|
||||
backward: T,
|
||||
phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>
|
||||
where
|
||||
B: Backend,
|
||||
T: Backward<B, N, State = SB>,
|
||||
SB: Clone + Send + core::fmt::Debug + 'static,
|
||||
{
|
||||
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
|
||||
self.backward.backward(self.ops, grads, checkpointer);
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeId {
|
||||
self.ops.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
&self.ops.node.parents
|
||||
}
|
||||
|
||||
fn depth(&self) -> usize {
|
||||
self.ops.node.order
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct UntrackedOpsStep<const N: usize> {
|
||||
ops: Ops<(), N>,
|
||||
}
|
||||
|
||||
impl<const N: usize> Step for UntrackedOpsStep<N> {
|
||||
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeId {
|
||||
self.ops.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
&self.ops.node.parents
|
||||
}
|
||||
fn depth(&self) -> usize {
|
||||
self.ops.node.order
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure the grad tensor has the given shape.
|
||||
///
|
||||
/// If broadcasting happened during the forward pass, the gradients will be sum along the
|
||||
/// broadcasted dimension.
|
||||
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
|
||||
let shape_grad = grad.shape();
|
||||
let ndims = shape_grad.num_dims();
|
||||
|
||||
for i in 0..ndims {
|
||||
if shape_grad[i] != shape[i] {
|
||||
if shape[i] != 1 {
|
||||
panic!(
|
||||
"Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}",
|
||||
shape, shape_grad, "Expected the shape of the next grad to be 1."
|
||||
);
|
||||
}
|
||||
grad = B::float_sum_dim(grad, i);
|
||||
}
|
||||
}
|
||||
|
||||
grad
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn_backend::{
|
||||
Backend, ExecutionError, Scalar, TensorData,
|
||||
ops::BoolTensorOps,
|
||||
tensor::{BoolTensor, Device, IntTensor},
|
||||
};
|
||||
use burn_std::Shape;
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
|
||||
fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_from_data(data, device)
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: BoolTensor<B>) -> Result<TensorData, ExecutionError> {
|
||||
B::bool_into_data(tensor).await
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B> {
|
||||
B::bool_into_int(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &BoolTensor<B>) -> Device<B> {
|
||||
B::bool_device(tensor)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
|
||||
B::bool_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_std::Slice]) -> BoolTensor<B> {
|
||||
B::bool_slice(tensor, slices)
|
||||
}
|
||||
|
||||
fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_empty(shape, device)
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_zeros(shape, device)
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
|
||||
B::bool_ones(shape, device)
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
slices: &[burn_std::Slice],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
|
||||
B::bool_cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_not(tensor)
|
||||
}
|
||||
|
||||
fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_and(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_or(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
|
||||
B::bool_xor(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
|
||||
AutodiffTensor::new(B::bool_into_float(tensor))
|
||||
}
|
||||
|
||||
fn bool_swap_dims(
|
||||
tensor: <Autodiff<B> as Backend>::BoolTensorPrimitive,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive {
|
||||
B::bool_swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
B::bool_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B> {
|
||||
B::bool_flip(tensor, axes)
|
||||
}
|
||||
|
||||
async fn bool_argwhere(tensor: BoolTensor<B>) -> IntTensor<B> {
|
||||
B::bool_argwhere(tensor).await
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
|
||||
B::bool_expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
|
||||
B::bool_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
source: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_mask_where(tensor, mask, source)
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_mask_fill(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
B::bool_scatter_or(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
B::bool_equal_elem(lhs, rhs)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn_backend::{
|
||||
Backend, Distribution, ExecutionError, Scalar, TensorData,
|
||||
ops::IntTensorOps,
|
||||
tensor::{BoolTensor, Device, IntTensor},
|
||||
};
|
||||
use burn_std::{IntDType, Shape};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
|
||||
fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
|
||||
B::int_from_data(data, device)
|
||||
}
|
||||
|
||||
async fn int_into_data(tensor: IntTensor<B>) -> Result<TensorData, ExecutionError> {
|
||||
B::int_into_data(tensor).await
|
||||
}
|
||||
|
||||
fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {
|
||||
B::int_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
|
||||
B::int_device(tensor)
|
||||
}
|
||||
|
||||
fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
|
||||
B::int_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_slice(tensor: IntTensor<B>, slices: &[burn_std::Slice]) -> IntTensor<B> {
|
||||
B::int_slice(tensor, slices)
|
||||
}
|
||||
|
||||
fn int_empty(
|
||||
shape: Shape,
|
||||
device: &<Autodiff<B> as Backend>::Device,
|
||||
dtype: IntDType,
|
||||
) -> IntTensor<B> {
|
||||
B::int_empty(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: IntTensor<B>,
|
||||
slices: &[burn_std::Slice],
|
||||
value: IntTensor<B>,
|
||||
) -> IntTensor<B> {
|
||||
B::int_slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_equal_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_add(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_add_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
|
||||
B::int_clamp_min(tensor, min)
|
||||
}
|
||||
|
||||
fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
|
||||
B::int_clamp_max(tensor, max)
|
||||
}
|
||||
|
||||
fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
|
||||
B::int_clamp(tensor, min, max)
|
||||
}
|
||||
|
||||
fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_sub_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_mul_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_div(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_div_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_remainder(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
|
||||
B::int_remainder_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_matmul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_neg(tensor)
|
||||
}
|
||||
|
||||
fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
|
||||
B::int_zeros(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
|
||||
B::int_ones(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_full(
|
||||
shape: Shape,
|
||||
fill_value: Scalar,
|
||||
device: &Device<Self>,
|
||||
dtype: IntDType,
|
||||
) -> IntTensor<B> {
|
||||
B::int_full(shape, fill_value, device, dtype)
|
||||
}
|
||||
|
||||
fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_sum(tensor)
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_sum_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_mean(tensor)
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cumsum(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cumprod(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cummin(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_cummax(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
|
||||
B::int_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_greater(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_greater_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_greater_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_greater_equal_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_lower(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_lower_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
|
||||
B::int_lower_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
|
||||
B::int_lower_equal_elem(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter_add(
|
||||
dim: usize,
|
||||
tensor: IntTensor<B>,
|
||||
indices: IntTensor<B>,
|
||||
value: IntTensor<B>,
|
||||
) -> IntTensor<B> {
|
||||
B::int_scatter_add(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {
|
||||
B::int_select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_select_add(
|
||||
tensor: IntTensor<B>,
|
||||
dim: usize,
|
||||
indices: IntTensor<B>,
|
||||
value: IntTensor<B>,
|
||||
) -> IntTensor<B> {
|
||||
B::int_select_add(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn int_mask_where(
|
||||
tensor: IntTensor<B>,
|
||||
mask: BoolTensor<B>,
|
||||
value: IntTensor<B>,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
|
||||
B::int_mask_where(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn int_mask_fill(
|
||||
tensor: IntTensor<B>,
|
||||
mask: BoolTensor<B>,
|
||||
value: Scalar,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
|
||||
B::int_mask_fill(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_argmax(tensor, dim)
|
||||
}
|
||||
fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
|
||||
B::int_argmin(tensor, dim)
|
||||
}
|
||||
fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
|
||||
B::int_max(tensor)
|
||||
}
|
||||
fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
|
||||
B::int_max_dim(tensor, dim)
|
||||
}
|
||||
fn int_max_dim_with_indices(
|
||||
tensor: B::IntTensorPrimitive,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
|
||||
B::int_max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
|
||||
B::int_min(tensor)
|
||||
}
|
||||
fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
|
||||
B::int_min_dim(tensor, dim)
|
||||
}
|
||||
fn int_min_dim_with_indices(
|
||||
tensor: B::IntTensorPrimitive,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
|
||||
B::int_min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
|
||||
B::int_abs(tensor)
|
||||
}
|
||||
fn int_into_float(
|
||||
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
|
||||
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
|
||||
AutodiffTensor::new(B::int_into_float(tensor))
|
||||
}
|
||||
|
||||
fn int_swap_dims(
|
||||
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
|
||||
B::int_swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
B::int_random(shape, distribution, device)
|
||||
}
|
||||
|
||||
fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
|
||||
B::int_arange(range, device)
|
||||
}
|
||||
|
||||
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
B::int_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
B::int_flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::int_sign(tensor)
|
||||
}
|
||||
|
||||
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::int_prod(tensor)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
B::int_prod_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
|
||||
B::int_expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
B::int_sort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn int_sort_with_indices(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (IntTensor<Self>, IntTensor<Self>) {
|
||||
B::int_sort_with_indices(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
B::int_argsort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_and(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_and_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_or(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_or_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_xor(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_xor_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_not(tensor)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_left_shift(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_left_shift_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
B::bitwise_right_shift(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
B::bitwise_right_shift_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
B::int_cast(tensor, dtype)
|
||||
}
|
||||
|
||||
fn int_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
B::int_unfold(tensor, dim, size, step)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::{Backward, Ops, unary};
|
||||
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
|
||||
use burn_backend::{Backend, TensorMetadata};
|
||||
use burn_std::Shape;
|
||||
|
||||
/// Backward marker whose gradient is scattered back at saved indices along one dimension.
#[derive(Debug)]
pub(crate) struct MaxMinDim;
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for MaxMinDim {
|
||||
type State = (B::IntTensorPrimitive, Shape, usize);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
_checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let (indices, shape, dim) = ops.state;
|
||||
let device = B::float_device(&grad);
|
||||
let dtype = grad.dtype();
|
||||
let zeros = B::float_zeros(shape, &device, dtype.into());
|
||||
|
||||
B::float_scatter_add(dim, zeros, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
mod activation;
|
||||
mod backward;
|
||||
mod base;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
|
||||
pub(crate) mod maxmin;
|
||||
pub(crate) mod sort;
|
||||
|
||||
pub use backward::*;
|
||||
pub use base::*;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,106 @@
|
||||
use burn_backend::{
|
||||
Backend, ExecutionError, TensorData,
|
||||
ops::QTensorOps,
|
||||
tensor::{
|
||||
Device, FloatTensor, IntTensor, QuantizedTensor,
|
||||
quantization::QuantizationParametersPrimitive,
|
||||
},
|
||||
};
|
||||
use burn_std::{QuantScheme, Shape};
|
||||
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
|
||||
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
_qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
todo!() // required for QAT
|
||||
}
|
||||
|
||||
fn quantize_dynamic(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
|
||||
B::q_device(tensor)
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_device: &Device<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
B::q_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
B::q_into_data(tensor).await
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim1: usize,
|
||||
_dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_gather(
|
||||
_dim: usize,
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_slice(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_slices: &[burn_std::Slice],
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
B::q_argmax(tensor, dim)
|
||||
}
|
||||
|
||||
fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
B::q_argmin(tensor, dim)
|
||||
}
|
||||
|
||||
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::{Backward, Ops, unary};
|
||||
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
|
||||
use burn_backend::{Backend, TensorMetadata};
|
||||
use burn_std::Shape;
|
||||
|
||||
/// Backward marker whose gradient is scattered back at saved indices along one dimension.
#[derive(Debug)]
pub(crate) struct SortDim;
|
||||
|
||||
impl<B: Backend> Backward<B, 1> for SortDim {
|
||||
type State = (B::IntTensorPrimitive, Shape, usize);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
_checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let (indices, shape, dim) = ops.state;
|
||||
let device = B::float_device(&grad);
|
||||
let dtype = grad.dtype();
|
||||
let zeros = B::float_zeros(shape, &device, dtype.into());
|
||||
|
||||
B::float_scatter_add(dim, zeros, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
use burn_backend::{
|
||||
Backend, ExecutionError,
|
||||
ops::{TransactionOps, TransactionPrimitive},
|
||||
};
|
||||
|
||||
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C> {
|
||||
async fn tr_execute(
|
||||
transaction: TransactionPrimitive<Self>,
|
||||
) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
|
||||
B::tr_execute(TransactionPrimitive::new(
|
||||
transaction
|
||||
.read_floats
|
||||
.into_iter()
|
||||
.map(|t| t.primitive)
|
||||
.collect(),
|
||||
transaction.read_qfloats,
|
||||
transaction.read_ints,
|
||||
transaction.read_bools,
|
||||
))
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use crate::{
|
||||
checkpoint::builder::CheckpointerBuilder,
|
||||
grads::Gradients,
|
||||
graph::StepBoxed,
|
||||
tensor::{AutodiffTensor, NodeRefCount},
|
||||
};
|
||||
use burn_backend::Backend;
|
||||
|
||||
/// Client used to communicate with the autodiff server.
|
||||
pub trait AutodiffClient: Send + Clone {
|
||||
/// Register a new step.
|
||||
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
|
||||
/// Call backpropagation from the given tensor.
|
||||
fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
|
||||
}
|
||||
|
||||
/// Client implementation in used.
|
||||
pub type AutodiffClientImpl = super::graph::GraphMutexClient;
|
||||
@@ -0,0 +1,335 @@
|
||||
use super::{AutodiffClient, server::AutodiffServer};
|
||||
use crate::{
|
||||
NodeId,
|
||||
checkpoint::builder::CheckpointerBuilder,
|
||||
grads::Gradients,
|
||||
graph::{Parent, StepBoxed},
|
||||
runtime::server::NodeCleaner,
|
||||
tensor::{AutodiffTensor, NodeRefCount},
|
||||
};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::Backend;
|
||||
use hashbrown::{HashMap, HashSet};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use spin::{Mutex, MutexGuard};
|
||||
|
||||
/// A client for managing multiple graphs using mutex-based synchronization.
|
||||
///
|
||||
/// The biggest benefit of using this client implementation is that each graph can modify its own
|
||||
/// data without blocking other graphs, which is essential for multi-device training.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The [AutodiffServer] fully supports multiple graphs with sharing nodes, however those type of
|
||||
/// graphs will be stored under a single mutex-protected graph by the client, limiting
|
||||
/// parallelisation.
|
||||
#[derive(Clone, new, Debug)]
|
||||
pub struct GraphMutexClient;
|
||||
|
||||
/// Manages a collection of graphs, mapping [node ids](NodeId) to their respective graph.
|
||||
///
|
||||
/// The `GraphLocator` is responsible for selecting and merging graphs based on their IDs and parent
|
||||
/// dependencies, ensuring proper synchronization and server allocation.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Multiple node ids can point to the same graph, where the autodiff graph is stored.
|
||||
#[derive(Default)]
|
||||
pub struct GraphLocator {
|
||||
graphs: HashMap<NodeId, Arc<Graph>>,
|
||||
/// We keep a mapping of each original node id (graph id) => all nodes that point to that graph.
|
||||
/// This is to ensure that when merging graphs, we correctly move all previous graphs to
|
||||
/// the new merged one.
|
||||
keys: HashMap<NodeId, HashSet<NodeId>>,
|
||||
}
|
||||
|
||||
/// Represents a single computation graph with a mutex-protected server.
|
||||
///
|
||||
/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
|
||||
/// first created.
|
||||
pub(crate) struct Graph {
|
||||
origin: NodeId,
|
||||
state: Mutex<GraphState>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GraphState {
|
||||
server: AutodiffServer,
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for Graph {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.debug_struct("Graph")
|
||||
.field("origin", &self.origin)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Global, lazily created graph locator; it stays `None` until a graph is first requested.
static STATE: Mutex<Option<GraphLocator>> = Mutex::new(None);
|
||||
|
||||
impl GraphMutexClient {
|
||||
/// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `node`: The unique identifier for the stream.
|
||||
/// - `parents`: A slice of parent nodes that the stream depends on.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `Arc<Graph>` representing the selected or newly created stream.
|
||||
fn graph(node: NodeId, parents: &[Parent]) -> Arc<Graph> {
|
||||
let mut state = STATE.lock();
|
||||
|
||||
match state.as_mut() {
|
||||
Some(locator) => locator.select(node, parents),
|
||||
None => {
|
||||
let mut locator = GraphLocator::default();
|
||||
let stream = locator.select(node, parents);
|
||||
*state = Some(locator);
|
||||
stream
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutodiffClient for GraphMutexClient {
|
||||
fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
|
||||
let node_id = *node_id_ref;
|
||||
let graph = GraphMutexClient::graph(node_id, step.parents());
|
||||
let mut state = graph.state.lock();
|
||||
|
||||
state.server.register(node_id_ref, step, actions);
|
||||
}
|
||||
|
||||
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
|
||||
let node_id = root.node.id;
|
||||
let graph = GraphMutexClient::graph(root.node.id, &[]);
|
||||
|
||||
let grads = Gradients::new::<B>(root.node, root.primitive);
|
||||
let grads = {
|
||||
let mut state = graph.state.lock();
|
||||
state.server.backward::<GraphCleaner>(grads, node_id)
|
||||
}; // lock released
|
||||
|
||||
GraphCleaner::cleanup_orphaned_entries();
|
||||
|
||||
grads
|
||||
}
|
||||
}
|
||||
|
||||
struct GraphCleaner<'a> {
|
||||
guard: MutexGuard<'a, Option<GraphLocator>>,
|
||||
}
|
||||
|
||||
impl<'a> GraphCleaner<'a> {
|
||||
fn cleanup_orphaned_entries() {
|
||||
let graphs = {
|
||||
// Get the available graphs and release the lock
|
||||
match STATE.lock().as_ref() {
|
||||
Some(state) => state.graphs.clone(),
|
||||
None => return,
|
||||
}
|
||||
};
|
||||
|
||||
let mut should_remove = Vec::new();
|
||||
for graph in graphs.values() {
|
||||
{
|
||||
let mut guard = graph.state.lock();
|
||||
// Double safety: in case it was marked as no longer useful, but other
|
||||
// nodes are still relevant, we only check which nodes can safely be removed.
|
||||
if !guard.server.maybe_useful() {
|
||||
guard
|
||||
.server
|
||||
.free_unused_roots(|node| should_remove.push(*node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !should_remove.is_empty() {
|
||||
let mut state = STATE.lock();
|
||||
if let Some(state) = state.as_mut() {
|
||||
for node in should_remove {
|
||||
state.remove_entry(&node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> NodeCleaner for GraphCleaner<'a> {
|
||||
fn init() -> Self {
|
||||
let guard = STATE.lock();
|
||||
Self { guard }
|
||||
}
|
||||
|
||||
fn clean(&mut self, node: &NodeId) {
|
||||
if let Some(state) = self.guard.as_mut() {
|
||||
state.remove_entry(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphLocator {
|
||||
/// Selects a single graph for the given [NodeId], considering parent dependencies.
|
||||
///
|
||||
/// If multiple graphs are found, they are merged into a single one.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `node`: The node ID of the graph to select.
|
||||
/// - `parents`: A slice of parent nodes that the graph depends on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An `Arc<Graph>` representing the selected or merged graph.
|
||||
pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
|
||||
match self.analyse(node, parents) {
|
||||
GraphAnalysis::NoCollision(graph) => {
|
||||
if graph.origin != node {
|
||||
self.graphs.insert(node, graph.clone());
|
||||
self.register_key(graph.origin, node);
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
GraphAnalysis::Collisions(graphs) => self.merge(node, graphs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyses the graph for a given node and its parents, returning the associated `GraphAnalysis`.
|
||||
fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalysis {
|
||||
// If no parents, there is no collision, therefore a single graph is ok.
|
||||
if parents.is_empty() {
|
||||
let graph = match self.graphs.get(&node) {
|
||||
Some(val) => val.clone(),
|
||||
None => self.new_graph(node),
|
||||
};
|
||||
return GraphAnalysis::NoCollision(graph);
|
||||
};
|
||||
|
||||
// We collect all graphs of parents and of the current node based on their origin node id.
|
||||
let mut graphs = HashMap::<NodeId, Arc<Graph>>::new();
|
||||
|
||||
if let Some(val) = self.graphs.get(&node) {
|
||||
graphs.insert(val.origin, val.clone());
|
||||
}
|
||||
|
||||
for parent in parents {
|
||||
match self.graphs.get(&parent.id) {
|
||||
Some(graph) => graphs.insert(graph.origin, graph.clone()),
|
||||
None => continue,
|
||||
};
|
||||
}
|
||||
|
||||
if graphs.is_empty() {
|
||||
return match self.graphs.get(&node) {
|
||||
Some(old) => GraphAnalysis::NoCollision(old.clone()),
|
||||
None => GraphAnalysis::NoCollision(self.new_graph(node)),
|
||||
};
|
||||
}
|
||||
|
||||
if graphs.len() == 1 {
|
||||
return GraphAnalysis::NoCollision(graphs.drain().next().unwrap().1);
|
||||
}
|
||||
|
||||
GraphAnalysis::Collisions(graphs)
|
||||
}
|
||||
|
||||
/// Merges multiple graphs associated with a node into a single graph.
|
||||
fn merge(&mut self, node: NodeId, mut graphs: HashMap<NodeId, Arc<Graph>>) -> Arc<Graph> {
|
||||
let mut graphs = graphs.drain().map(|g| g.1);
|
||||
|
||||
let main = graphs.next().expect("At least one graph");
|
||||
self.register_key(main.origin, node);
|
||||
|
||||
let mut state = main.state.lock();
|
||||
|
||||
for graph in graphs {
|
||||
self.merge_two(&mut state, &main, graph);
|
||||
}
|
||||
|
||||
self.graphs.insert(main.origin, main.clone());
|
||||
self.graphs.insert(node, main.clone());
|
||||
|
||||
core::mem::drop(state);
|
||||
|
||||
main
|
||||
}
|
||||
|
||||
/// Registers a key for a given origin node.
|
||||
fn register_key(&mut self, origin: NodeId, key: NodeId) {
|
||||
if !self.keys.contains_key(&origin) {
|
||||
// Ensure an entry exists for this origin
|
||||
self.keys.insert(origin, HashSet::new());
|
||||
}
|
||||
|
||||
if origin != key {
|
||||
// Register this node to point to the origin graph
|
||||
self.keys.get_mut(&origin).unwrap().insert(key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Merges two graphs by combining their states and updating graph mappings.
|
||||
fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>, merged: Arc<Graph>) {
|
||||
let mut locked = merged.state.lock();
|
||||
let mut state_old = GraphState::default();
|
||||
core::mem::swap(&mut state_old, &mut locked);
|
||||
main_state.server.extend(state_old.server);
|
||||
|
||||
// Re-map merged origin to the main graph
|
||||
self.graphs.insert(merged.origin, main.clone());
|
||||
|
||||
// Move all keys (node IDs) from the merged graph to the main graph
|
||||
if let Some(locator_keys) = self.keys.remove(&merged.origin) {
|
||||
for k in locator_keys.iter() {
|
||||
self.graphs.insert(*k, main.clone());
|
||||
}
|
||||
|
||||
let locator_keys_main = self
|
||||
.keys
|
||||
.get_mut(&main.origin)
|
||||
.expect("Should be init before the merge.");
|
||||
locator_keys_main.extend(locator_keys);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new graph for a given node.
|
||||
fn new_graph(&mut self, origin: NodeId) -> Arc<Graph> {
|
||||
let graph = Arc::new(Graph {
|
||||
origin,
|
||||
state: Mutex::new(GraphState::default()),
|
||||
});
|
||||
self.graphs.insert(origin, graph.clone());
|
||||
self.keys.insert(origin, HashSet::new());
|
||||
graph
|
||||
}
|
||||
|
||||
fn remove_entry(&mut self, node: &NodeId) {
|
||||
if let Some(graph) = self.graphs.remove(node) {
|
||||
let mut remove = false;
|
||||
|
||||
if let Some(entry) = self.keys.get_mut(&graph.origin) {
|
||||
entry.remove(node);
|
||||
if entry.is_empty() {
|
||||
remove = true;
|
||||
}
|
||||
}
|
||||
|
||||
if remove {
|
||||
self.keys.remove(&graph.origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the analysis result of graph operations for a given node and its parents.
|
||||
#[derive(Debug)]
|
||||
enum GraphAnalysis {
|
||||
/// No collision detected, contains the graph associated with the node.
|
||||
NoCollision(Arc<Graph>),
|
||||
/// Collision detected, contains a map of node IDs to their associated graphs.
|
||||
Collisions(HashMap<NodeId, Arc<Graph>>),
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
use crate::{
|
||||
NodeId,
|
||||
collections::{HashMap, HashSet},
|
||||
graph::Parent,
|
||||
tensor::NodeRefCount,
|
||||
};
|
||||
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
|
||||
use core::mem;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct GraphMemoryManagement {
|
||||
nodes: HashMap<NodeRefCount, Vec<NodeId>>,
|
||||
leaves: HashSet<NodeId>,
|
||||
statuses: HashMap<NodeId, NodeMemoryStatus>,
|
||||
}
|
||||
|
||||
/// Status assigned to a node while looking for memory that can be freed.
#[derive(Debug, Clone, PartialEq)]
enum NodeMemoryStatus {
    /// The node is still referenced, or is an ancestor of a node tagged useful.
    Useful,
    /// The node is no longer registered, or one of its ancestors is unavailable.
    Unavailable,
    /// The node is registered and none of its ancestors is known to be unavailable.
    Unknown,
}
|
||||
|
||||
impl GraphMemoryManagement {
|
||||
pub fn extend(&mut self, other: Self) {
|
||||
self.nodes.extend(other.nodes);
|
||||
self.leaves.extend(other.leaves);
|
||||
self.statuses.extend(other.statuses);
|
||||
}
|
||||
|
||||
/// Register a new node with its parent.
|
||||
pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
|
||||
let node_id = *node.as_ref();
|
||||
|
||||
for parent in parents.iter() {
|
||||
self.leaves.remove(&parent.id);
|
||||
}
|
||||
|
||||
self.leaves.insert(node_id);
|
||||
self.nodes
|
||||
.insert(node, parents.iter().map(|p| p.id).collect());
|
||||
}
|
||||
|
||||
/// Free the node from the state.
|
||||
pub fn consume_node(&mut self, node_id: NodeId) {
|
||||
if !self.is_referenced(node_id) {
|
||||
self.leaves.remove(&node_id);
|
||||
self.nodes.remove(&node_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Free all nodes whose backward call has become impossible
|
||||
///
|
||||
/// This function goes into three steps, which must happen for all leaves
|
||||
/// before going into the next step. Then it deletes what can be safely deleted
|
||||
pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
|
||||
let leaves = self.leaves.clone();
|
||||
let mut new_leaves = HashSet::new();
|
||||
let mut deletables = Vec::new();
|
||||
|
||||
// When consuming nodes with a backward pass, some other backward passes become
|
||||
// unavailable because some of their parents have been consumed. They are
|
||||
// identified here.
|
||||
for leaf in leaves.clone() {
|
||||
self.unavailable_propagation(leaf);
|
||||
}
|
||||
|
||||
// Among the available nodes that remain, some may be useless if no
|
||||
// available node with a tensor reference exist in their descendance.
|
||||
// But some may seem useless from some leaf but be useful from another one,
|
||||
// hence the need to iterate on all leaves.
|
||||
self.useful_propagation(leaves.clone());
|
||||
|
||||
// New leaves are the roots of a useful backward sub-tree.
|
||||
// Deletables are everything not marked as useful.
|
||||
for leaf in leaves {
|
||||
self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
|
||||
}
|
||||
|
||||
// Replace leaves by the new ones and delete everything not useful anymore
|
||||
mem::swap(&mut self.leaves, &mut new_leaves);
|
||||
|
||||
self.clear_unused_roots(&mut deletables);
|
||||
|
||||
self.statuses.clear();
|
||||
for node_to_delete in deletables {
|
||||
self.nodes.remove(&node_to_delete);
|
||||
on_free_graph(&node_to_delete)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
|
||||
let mut deletables = Vec::new();
|
||||
self.clear_unused_roots(&mut deletables);
|
||||
|
||||
for node_id in deletables {
|
||||
self.nodes.remove(&node_id);
|
||||
on_free_graph(&node_id);
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
|
||||
for (id, parents) in self.nodes.iter() {
|
||||
let is_useful = matches!(
|
||||
self.statuses.get(id.as_ref()),
|
||||
Some(NodeMemoryStatus::Useful)
|
||||
);
|
||||
|
||||
// Check if parents are either empty or absent from self.nodes
|
||||
let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));
|
||||
|
||||
if !is_useful && Arc::strong_count(id) == 1 && parents_absent {
|
||||
to_delete.push(*id.as_ref())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {
|
||||
// If already visited
|
||||
if let Some(status) = self.statuses.get(&node_id) {
|
||||
return status.clone();
|
||||
}
|
||||
|
||||
match self.nodes.get(&node_id).cloned() {
|
||||
// If node exists and any of its parents is unavailable, it is unavailable as well
|
||||
// If node exists but the parents vec is empty, it is a tensor that never had parents;
|
||||
// the status remains unknown
|
||||
Some(parents) => {
|
||||
let mut node_status = NodeMemoryStatus::Unknown;
|
||||
for parent in parents {
|
||||
let parent_status = self.unavailable_propagation(parent);
|
||||
if let NodeMemoryStatus::Unavailable = parent_status {
|
||||
node_status = NodeMemoryStatus::Unavailable;
|
||||
}
|
||||
}
|
||||
self.statuses.insert(node_id, node_status.clone());
|
||||
node_status
|
||||
}
|
||||
// If node does not exist, it was
|
||||
// deleted, so this and all its descendants are unavailable
|
||||
None => {
|
||||
self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
|
||||
NodeMemoryStatus::Unavailable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
|
||||
// Accumulate visited nodes
|
||||
let mut explored = HashSet::new();
|
||||
let mut tagged_useful = HashSet::new();
|
||||
|
||||
// Queue of nodes to visit
|
||||
let mut to_tag_useful = PopNodeSet::default();
|
||||
let mut to_explore = PopNodeSet::new(leaves);
|
||||
|
||||
// Utility function to iterate over a node's parents
|
||||
let parents = |node_id| {
|
||||
self.nodes
|
||||
.get(&node_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
};
|
||||
|
||||
loop {
|
||||
// Pop a node id, greedily looking at tag_useful ones first
|
||||
let (node_id, status) = match to_tag_useful.pop() {
|
||||
Some(node_id) => (node_id, NodeMemoryStatus::Useful),
|
||||
None => match to_explore.pop() {
|
||||
Some(node_id) => {
|
||||
let node_status = self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("All nodes should have received a status during unavailable_propagation")
|
||||
.to_owned();
|
||||
|
||||
if let NodeMemoryStatus::Unknown = node_status {
|
||||
match self.is_referenced(node_id) {
|
||||
true => (node_id, NodeMemoryStatus::Useful),
|
||||
false => (node_id, NodeMemoryStatus::Unknown),
|
||||
}
|
||||
} else {
|
||||
(node_id, node_status)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// There are no nodes in the queues anymore
|
||||
break;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
match status {
|
||||
NodeMemoryStatus::Useful => {
|
||||
tagged_useful.insert(node_id);
|
||||
for parent in parents(node_id) {
|
||||
// The node can be explored, as long as it's not already tagged useful
|
||||
if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
|
||||
to_tag_useful.insert(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
explored.insert(node_id);
|
||||
for parent in parents(node_id) {
|
||||
if !(explored.contains(&parent) || to_explore.contains(&parent)) {
|
||||
to_explore.insert(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.statuses.insert(node_id, status);
|
||||
}
|
||||
}
|
||||
|
||||
fn identify_leaves_and_deletables(
|
||||
&self,
|
||||
leaf_id: NodeId,
|
||||
new_leaves: &mut HashSet<NodeId>,
|
||||
to_delete: &mut Vec<NodeId>,
|
||||
) {
|
||||
let mut visited = HashSet::new();
|
||||
let mut to_visit = vec![leaf_id];
|
||||
|
||||
while let Some(node_id) = to_visit.pop() {
|
||||
visited.insert(node_id);
|
||||
|
||||
match self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("Node should have status")
|
||||
{
|
||||
NodeMemoryStatus::Useful => {
|
||||
new_leaves.insert(node_id);
|
||||
}
|
||||
_ => {
|
||||
to_delete.push(node_id);
|
||||
|
||||
for parent in self
|
||||
.nodes
|
||||
.get(&node_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
{
|
||||
if !visited.contains(&parent) {
|
||||
to_visit.push(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn is_referenced(&self, node_id: NodeId) -> bool {
|
||||
match self.nodes.get_key_value(&node_id) {
|
||||
Some((key, _value)) => Arc::strong_count(key) > 1,
|
||||
None => panic!("Node should be in the nodes map"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_useful(&self) -> bool {
|
||||
self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper over hash set for fast popping of any node
|
||||
#[derive(new, Default)]
|
||||
struct PopNodeSet {
|
||||
hash_set: HashSet<NodeId>,
|
||||
}
|
||||
|
||||
impl PopNodeSet {
|
||||
#[inline(always)]
|
||||
fn pop(&mut self) -> Option<NodeId> {
|
||||
self.hash_set
|
||||
.iter()
|
||||
.next()
|
||||
.copied()
|
||||
.and_then(|node_id| self.hash_set.take(&node_id))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn contains(&self, node_id: &NodeId) -> bool {
|
||||
self.hash_set.contains(node_id)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn insert(&mut self, node_id: NodeId) {
|
||||
self.hash_set.insert(node_id);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
mod client;
|
||||
mod memory_management;
|
||||
mod server;
|
||||
|
||||
pub mod graph;
|
||||
pub use client::*;
|
||||
@@ -0,0 +1,143 @@
|
||||
use super::memory_management::GraphMemoryManagement;
|
||||
use crate::{
|
||||
NodeId,
|
||||
checkpoint::{
|
||||
base::{Checkpointer, NodeTree},
|
||||
builder::CheckpointerBuilder,
|
||||
},
|
||||
collections::HashMap,
|
||||
grads::Gradients,
|
||||
graph::{StepBoxed, traversal::BreadthFirstSearch},
|
||||
tensor::NodeRefCount,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AutodiffServer {
|
||||
steps: HashMap<NodeId, StepBoxed>,
|
||||
actions_builder: HashMap<NodeId, CheckpointerBuilder>,
|
||||
memory_management: GraphMemoryManagement,
|
||||
}
|
||||
|
||||
/// Defines how nodes are clean.
|
||||
pub trait NodeCleaner {
|
||||
/// Initialize a new cleaner.
|
||||
fn init() -> Self;
|
||||
/// Cleans a single [node](NodeId).
|
||||
fn clean(&mut self, node: &NodeId);
|
||||
}
|
||||
|
||||
impl AutodiffServer {
|
||||
pub fn extend(&mut self, other: AutodiffServer) {
|
||||
self.steps.extend(other.steps);
|
||||
self.actions_builder.extend(other.actions_builder);
|
||||
self.memory_management.extend(other.memory_management);
|
||||
}
|
||||
|
||||
pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
|
||||
let parents = step.parents();
|
||||
let node_id = *rc.as_ref();
|
||||
|
||||
self.memory_management.register(rc, parents);
|
||||
|
||||
self.steps.insert(node_id, step);
|
||||
self.actions_builder.insert(node_id, actions);
|
||||
}
|
||||
|
||||
pub fn backward<NC: NodeCleaner>(&mut self, grads: Gradients, node_id: NodeId) -> Gradients {
|
||||
let step = self.steps.remove(&node_id).expect(
|
||||
"Node should have a step registered, did you forget to call \
|
||||
`Tensor::register_grad` on the tensor where you need gradients?",
|
||||
);
|
||||
let builder = self.actions_builder.remove(&node_id).unwrap();
|
||||
|
||||
let mut consumed = Vec::new();
|
||||
let (tape, checkpointer) = self.build_tape(node_id, step, builder, &mut consumed);
|
||||
|
||||
let gradients = Self::execute_steps(tape, grads, checkpointer);
|
||||
|
||||
// Cleanup
|
||||
let mut cleaner = NC::init();
|
||||
self.memory_management
|
||||
.free_unavailable_nodes(|node_id: &NodeId| {
|
||||
self.steps.remove(node_id);
|
||||
self.actions_builder.remove(node_id);
|
||||
NC::clean(&mut cleaner, node_id);
|
||||
});
|
||||
for node_id in consumed {
|
||||
cleaner.clean(&node_id)
|
||||
}
|
||||
|
||||
gradients
|
||||
}
|
||||
|
||||
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
|
||||
self.memory_management.free_unused_roots(|node_id| {
|
||||
self.steps.remove(node_id);
|
||||
self.actions_builder.remove(node_id);
|
||||
on_free_graph(node_id);
|
||||
});
|
||||
}
|
||||
|
||||
fn build_tape(
|
||||
&mut self,
|
||||
node: NodeId,
|
||||
node_step: StepBoxed,
|
||||
mut builder: CheckpointerBuilder,
|
||||
consumed: &mut Vec<NodeId>,
|
||||
) -> (Vec<Vec<StepBoxed>>, Checkpointer) {
|
||||
let mut tape = (0..node_step.depth())
|
||||
.map(|_| Vec::with_capacity(1))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut tree = HashMap::default();
|
||||
|
||||
BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {
|
||||
self.memory_management.consume_node(id);
|
||||
// Clean up consumed node
|
||||
consumed.push(id);
|
||||
|
||||
let depth = step.depth();
|
||||
|
||||
if depth == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(steps) = tape.get_mut(depth - 1) {
|
||||
let parents = step.parents().iter().map(|p| p.id).filter(|s| *s != id);
|
||||
tree.insert(id, parents.collect());
|
||||
steps.push(step);
|
||||
}
|
||||
|
||||
if let Some(node_builder) = self.actions_builder.remove(&id) {
|
||||
builder.extend(node_builder);
|
||||
}
|
||||
});
|
||||
|
||||
let checkpointer = builder.build(NodeTree::new(tree));
|
||||
|
||||
(tape, checkpointer)
|
||||
}
|
||||
|
||||
fn execute_steps(
|
||||
tape: Vec<Vec<StepBoxed>>,
|
||||
mut grads: Gradients,
|
||||
mut checkpointer: Checkpointer,
|
||||
) -> Gradients {
|
||||
tape.into_iter().rev().for_each(|steps| {
|
||||
steps
|
||||
.into_iter()
|
||||
.for_each(|step| step.step(&mut grads, &mut checkpointer))
|
||||
});
|
||||
|
||||
// For checkpointing tests
|
||||
#[cfg(feature = "export_tests")]
|
||||
assert!(checkpointer.is_empty());
|
||||
|
||||
grads
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_useful(&self) -> bool {
|
||||
self.memory_management.maybe_useful()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, Node, NodeId, NodeRef, Parent, Requirement, Step},
|
||||
runtime::{AutodiffClient, AutodiffClientImpl},
|
||||
};
|
||||
use alloc::{boxed::Box, sync::Arc, vec};
|
||||
use burn_backend::{Backend, TensorMetadata};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutodiffTensor<B: Backend> {
|
||||
pub primitive: B::FloatTensorPrimitive,
|
||||
pub node: NodeRef,
|
||||
pub rc: NodeRefCount,
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorMetadata for AutodiffTensor<B> {
|
||||
fn dtype(&self) -> burn_std::DType {
|
||||
self.primitive.dtype()
|
||||
}
|
||||
|
||||
fn shape(&self) -> burn_std::Shape {
|
||||
self.primitive.shape()
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.primitive.rank()
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared, reference-counted handle on a node id.
pub type NodeRefCount = Arc<NodeId>;
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub(crate) struct RootStep {
|
||||
node: NodeRef,
|
||||
}
|
||||
|
||||
impl Step for RootStep {
|
||||
fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeId {
|
||||
self.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> &[Parent] {
|
||||
&self.node.parents
|
||||
}
|
||||
|
||||
fn depth(&self) -> usize {
|
||||
self.node.order
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> AutodiffTensor<B> {
|
||||
/// Create a new leaf tensor.
|
||||
pub fn new(primitive: B::FloatTensorPrimitive) -> Self {
|
||||
let id = NodeId::new();
|
||||
let node: NodeRef = Node::new(
|
||||
vec![],
|
||||
0,
|
||||
id,
|
||||
Requirement::None,
|
||||
ComputingProperty::Ambiguous,
|
||||
AutodiffClientImpl::new(),
|
||||
)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
rc: Arc::new(node.id),
|
||||
primitive,
|
||||
node: node.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tracked(&self) -> bool {
|
||||
!self.node.requirement.is_none()
|
||||
}
|
||||
|
||||
/// Mark the tensor as requiring gradients.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// It panics if the tensor is not a leaf.
|
||||
pub fn require_grad(mut self) -> Self {
|
||||
match self.node.requirement {
|
||||
Requirement::Grad => self,
|
||||
Requirement::GradInBackward => {
|
||||
panic!("Can't convert a non leaf tensor into a tracked tensor")
|
||||
}
|
||||
Requirement::None => {
|
||||
self.node = Node::new(
|
||||
vec![],
|
||||
0,
|
||||
self.node.id,
|
||||
Requirement::Grad,
|
||||
self.node.properties.clone(),
|
||||
self.node.client.clone(),
|
||||
)
|
||||
.into();
|
||||
let step = RootStep::new(self.node.clone());
|
||||
|
||||
self.register_step(step, CheckpointerBuilder::default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor from parent infos.
|
||||
pub fn from_parents(
|
||||
primitive: B::FloatTensorPrimitive,
|
||||
parent_nodes: &[NodeRef],
|
||||
requirement: Requirement,
|
||||
computing_properties: ComputingProperty,
|
||||
) -> Self {
|
||||
let order = parent_nodes
|
||||
.iter()
|
||||
.map(|node| node.order)
|
||||
.reduce(usize::max)
|
||||
.unwrap_or(0)
|
||||
+ 1;
|
||||
|
||||
let client = parent_nodes
|
||||
.first()
|
||||
.map(|node| node.client.clone())
|
||||
.unwrap_or_else(AutodiffClientImpl::new);
|
||||
|
||||
let node: NodeRef = Node::new(
|
||||
parent_nodes
|
||||
.iter()
|
||||
.filter_map(|node| node.clone_if_require_grad())
|
||||
.map(|node| Parent::new(node.id))
|
||||
.collect(),
|
||||
order,
|
||||
NodeId::new(),
|
||||
requirement,
|
||||
computing_properties,
|
||||
client,
|
||||
)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
rc: Arc::new(node.id),
|
||||
primitive,
|
||||
node,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a step into a graph for that tensor.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This should be called only once per tensor.
|
||||
pub fn register_step<S: Step + 'static>(
|
||||
self,
|
||||
step_that_created_the_tensor: S,
|
||||
actions: CheckpointerBuilder,
|
||||
) -> Self {
|
||||
self.node.client.register(
|
||||
self.rc.clone(),
|
||||
Box::new(step_that_created_the_tensor),
|
||||
actions,
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn into_primitive(self) -> B::FloatTensorPrimitive {
|
||||
self.primitive
|
||||
}
|
||||
|
||||
pub fn backward(self) -> Gradients {
|
||||
let client = self.node.client.clone();
|
||||
|
||||
AutodiffClient::backward::<B>(&client, self)
|
||||
}
|
||||
|
||||
pub fn grad(&self, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
|
||||
grads.get::<B>(self)
|
||||
}
|
||||
|
||||
pub fn grad_remove(&self, grads: &mut Gradients) -> Option<B::FloatTensorPrimitive> {
|
||||
grads.remove::<B>(self)
|
||||
}
|
||||
|
||||
pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPrimitive) {
|
||||
grads.remove::<B>(self);
|
||||
grads.register::<B>(self.node.id, grad);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use crate::graph::NodeRef;
|
||||
/// Duplicate the given object for each node that requires gradients.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This is useful since you don't have to keep N cloned references alive event if just 1 node
|
||||
/// will be updated.
|
||||
///
|
||||
/// If the object is a tensor and if one reference exists, it can be updated inplace.
|
||||
pub fn duplicate<T: Clone + core::fmt::Debug, const N: usize>(
|
||||
nodes: &[Option<NodeRef>; N],
|
||||
obj: Option<T>,
|
||||
) -> [Option<T>; N] {
|
||||
nodes
|
||||
.iter()
|
||||
.map(|node| match node {
|
||||
Some(_) => obj.clone(),
|
||||
None => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
[alias]
|
||||
test-cpu = "test --release --no-default-features --features cpu,std"
|
||||
test-cuda = "test --release --no-default-features --features cuda,std"
|
||||
test-ndarray = "test --release --no-default-features --features ndarray,std"
|
||||
test-rocm = "test --release --no-default-features --features rocm,std"
|
||||
test-router = "test --release --no-default-features --features router,std"
|
||||
test-tch = "test --release --no-default-features --features tch,std"
|
||||
test-wgpu = "test --release --no-default-features --features wgpu,std"
|
||||
test-vulkan = "test --release --no-default-features --features vulkan,std"
|
||||
test-metal = "test --release --no-default-features --features metal,std"
|
||||
@@ -0,0 +1,120 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Tensor tests for Burn backends"
|
||||
documentation = "https://docs.rs/burn-backend-tests"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-backend-tests"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-backend-tests"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"burn-tensor/default",
|
||||
"burn-autodiff/default",
|
||||
# Backends (default not enabled for CubeCL backends as it activates fusion)
|
||||
"burn-cpu?/default",
|
||||
"burn-ndarray?/default",
|
||||
"burn-tch?/default",
|
||||
# Default
|
||||
"ndarray",
|
||||
"std",
|
||||
]
|
||||
std = [
|
||||
"burn-tensor/std",
|
||||
"burn-autodiff/std",
|
||||
# Backends
|
||||
"burn-cpu?/std",
|
||||
"burn-ndarray?/std",
|
||||
"burn-wgpu?/std",
|
||||
"burn-router?/std",
|
||||
"burn-cuda?/std",
|
||||
"burn-rocm?/std",
|
||||
]
|
||||
|
||||
tracing = [
|
||||
"cubecl?/tracing",
|
||||
"burn-tensor/tracing",
|
||||
"burn-autodiff/tracing",
|
||||
# Backends
|
||||
"burn-cpu?/tracing",
|
||||
"burn-ndarray?/tracing",
|
||||
"burn-wgpu?/tracing",
|
||||
"burn-router?/tracing",
|
||||
"burn-cuda?/tracing",
|
||||
"burn-rocm?/tracing",
|
||||
]
|
||||
|
||||
# Backends
|
||||
cuda = ["burn-cuda", "quantization", "cube"]
|
||||
rocm = ["burn-rocm", "quantization", "cube"]
|
||||
ndarray = ["burn-ndarray", "quantization"]
|
||||
tch = ["burn-tch"]
|
||||
vulkan = ["wgpu", "burn-wgpu/vulkan"]
|
||||
webgpu = ["wgpu", "burn-wgpu/webgpu"]
|
||||
metal = ["wgpu", "burn-wgpu/metal"]
|
||||
wgpu = ["burn-wgpu", "quantization", "cube"]
|
||||
cpu = ["burn-cpu", "cube"]
|
||||
router = ["burn-router", "ndarray", "burn-wgpu"]
|
||||
|
||||
autotune = [
|
||||
"burn-wgpu?/autotune",
|
||||
"burn-cuda?/autotune",
|
||||
"burn-rocm?/autotune",
|
||||
"burn-cpu?/autotune",
|
||||
]
|
||||
autotune-checks = [
|
||||
"burn-wgpu?/autotune-checks",
|
||||
"burn-cuda?/autotune-checks",
|
||||
"burn-rocm?/autotune-checks",
|
||||
"burn-cpu?/autotune-checks",
|
||||
]
|
||||
|
||||
# CubeCL backends
|
||||
cube = [
|
||||
"cubecl",
|
||||
"cubek",
|
||||
"autotune",
|
||||
"burn-fusion",
|
||||
"burn-cubecl",
|
||||
"burn-ndarray",
|
||||
]
|
||||
|
||||
# Test configs
|
||||
quantization = []
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2" }
|
||||
|
||||
# Backends
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", default-features = false, features = [
|
||||
"export_tests",
|
||||
] }
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false, features = [
|
||||
"export_tests",
|
||||
] }
|
||||
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
|
||||
# To wrap `Fusion<CubeBackend>`
|
||||
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true, features = [
|
||||
"fusion",
|
||||
] }
|
||||
|
||||
num-traits = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
|
||||
cubecl = { workspace = true, optional = true }
|
||||
cubek = { workspace = true, features = ["random"], optional = true }
|
||||
@@ -0,0 +1,111 @@
|
||||
# Burn Backend Tests
|
||||
|
||||
This crate provides a comprehensive suite of tests for Burn backends, covering:
|
||||
|
||||
- Tensor operations: [tests/tensor/](./tests/tensor/)
|
||||
- Autodiff: [tests/autodiff/](./tests/autodiff/)
|
||||
- (Optional) CubeCL kernel correctness: [tests/cubecl/](./tests/cubecl/)
|
||||
|
||||
## Running Tests
|
||||
|
||||
The `TestBackend` is selected via feature flags. Use the provided shorthand commands for
|
||||
convenience:
|
||||
|
||||
```sh
|
||||
# Cpu
|
||||
cargo test-cpu
|
||||
# Cuda
|
||||
cargo test-cuda
|
||||
# Rocm
|
||||
cargo test-rocm
|
||||
# Wgpu / WebGpu
|
||||
cargo test-wgpu
|
||||
# Vulkan
|
||||
cargo test-vulkan
|
||||
# Metal
|
||||
cargo test-metal
|
||||
# Router
|
||||
cargo test-router
|
||||
|
||||
# NdArray
|
||||
cargo test-ndarray
|
||||
# LibTorch
|
||||
cargo test-tch
|
||||
```
|
||||
|
||||
By default, `cargo test` fails fast across integration test binaries. When one integration test
|
||||
binary fails, Cargo does not run the remaining test binaries. If you want to run all test binaries
|
||||
regardless of failures, pass `--no-fail-fast`, for example:
|
||||
|
||||
```sh
|
||||
cargo test-cuda --no-fail-fast
|
||||
```
|
||||
|
||||
## Structure
|
||||
|
||||
- `tests/tensor.rs`: Tensor tests
|
||||
- `tests/autodiff.rs`: Autodiff tests
|
||||
- `tests/fusion.rs`: Fusion backend tests wrapping tensor and autodiff tests
|
||||
- `tests/cubecl.rs`: CubeCL kernel tests
|
||||
|
||||
Each test module assumes exactly one `FloatElemType`, `IntElemType`, and `TestBackend` in scope.
|
||||
|
||||
### Common Modules
|
||||
|
||||
- `common/backend.rs`: Backend type definitions
|
||||
- `common/tensor.rs`: Reusable tensor test suite, split across float, int and bool tensor kinds
|
||||
- `common/autodiff.rs`: Reusable autodiff test suite, with and without checkpointing
|
||||
|
||||
### Test Reusability
|
||||
|
||||
This crate uses a pattern of parameterized test modules to run the same tests with different
|
||||
configurations (backends, dtypes, etc.):
|
||||
|
||||
1. **Type aliases define the configuration**: Each test scope declares `FloatElemType`,
|
||||
`IntElemType`, and `TestBackend`
|
||||
1. **`#[path = "..."]` references shared modules**: Points to test files outside the normal module
|
||||
hierarchy, e.g. `"common/tensor.rs"`
|
||||
1. **`include!()` imports test code**: Test modules are included multiple times with different type
|
||||
configurations
|
||||
1. **`use super::*;`** propagates types down the module tree: Each level re-exports parent types so
|
||||
deeply nested tests have access to the configured types
|
||||
|
||||
For example, `common/tensor.rs` can be included with `FloatElemType = f32` for base tests, then
|
||||
included again with `FloatElemType = f16` for half-precision tests, running the same test suite
|
||||
twice with different dtypes.
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
Add test modules under `tests/tensor/`, `tests/autodiff/`, or `tests/cubecl/` as appropriate. They will
|
||||
automatically run for all required configurations.
|
||||
|
||||
For tensor tests, make sure to add the test to each relevant tensor kind:
|
||||
|
||||
- `tensor/bool`: boolean tensor tests
|
||||
- `tensor/float`: float tensor tests
|
||||
- `tensor/int`: integer tensor tests
|
||||
|
||||
**Guidelines:**
|
||||
|
||||
Import types with `use super::*;` at the top of each module and use the types defined in
|
||||
`common/backend.rs`:
|
||||
|
||||
```rust
|
||||
/// Collection of types used across tests
|
||||
pub use burn_autodiff::Autodiff;
|
||||
pub use burn_tensor::Tensor;
|
||||
pub type TestBackend = ...;
|
||||
|
||||
pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;
|
||||
pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;
|
||||
pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;
|
||||
|
||||
pub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;
|
||||
pub type IntElem = burn_tensor::ops::IntElem<TestBackend>;
|
||||
|
||||
pub type TestAutodiffBackend = Autodiff<TestBackend>;
|
||||
pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;
|
||||
```
|
||||
|
||||
Tests will automatically run with default dtypes and any variants (f16, bf16, etc.) based on the
|
||||
backend configuration.
|
||||
@@ -0,0 +1,22 @@
|
||||
extern crate alloc;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub use burn_tensor_testgen::might_panic;
|
||||
|
||||
/// Generate a test module with custom floating element types.
|
||||
#[macro_export]
|
||||
macro_rules! test_float_elem_variant {
|
||||
($modname:ident, $float:ty, $module:literal, [$($feat:literal),* $(,)?]) => {
|
||||
#[cfg(all(test, any($(feature = $feat),*)))]
|
||||
mod $modname {
|
||||
pub type FloatElemType = $float;
|
||||
#[allow(unused)]
|
||||
pub use super::IntElemType;
|
||||
|
||||
mod ty {
|
||||
include!("backend.rs");
|
||||
include!($module);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
//! Burn autodiff tests.
|
||||
|
||||
#![allow(
|
||||
clippy::single_range_in_vec_init,
|
||||
clippy::duplicate_mod,
|
||||
reason = "false positive"
|
||||
)]
|
||||
extern crate alloc;
|
||||
|
||||
pub type FloatElemType = f32;
|
||||
#[allow(unused)]
|
||||
pub type IntElemType = i32;
|
||||
|
||||
#[path = "common/backend.rs"]
|
||||
mod backend;
|
||||
pub use backend::*;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
#[path = "common/autodiff.rs"]
|
||||
mod autodiff;
|
||||
@@ -0,0 +1,58 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance, cast::ToElement};
|
||||
|
||||
#[test]
|
||||
fn should_diff_abs() {
|
||||
let data_1 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_abs_no_nans() {
|
||||
let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
|
||||
let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let contains_nan = grad_2.contains_nan();
|
||||
assert!(!contains_nan.into_scalar().to_bool());
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::adaptive_avg_pool1d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_simple() {
|
||||
let test = AdaptiveAvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
length: 5,
|
||||
output_size: 3,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
|
||||
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
length: usize,
|
||||
output_size: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool1dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = adaptive_avg_pool1d(x.clone(), self.output_size);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&x_grad_actual.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::adaptive_avg_pool2d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_simple() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 5,
|
||||
width: 3,
|
||||
output_size_1: 3,
|
||||
output_size_2: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.16667, 0.33333, 0.16667],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
],
|
||||
[
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.16667, 0.33333, 0.16667],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_output_1() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 4,
|
||||
width: 8,
|
||||
output_size_1: 1,
|
||||
output_size_2: 1,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
output_size_1: usize,
|
||||
output_size_2: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&x_grad_actual.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_add() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone() + tensor_2.clone();
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
tensor_3
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([6.0, 6.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_add_scalar() {
|
||||
let data = TensorData::from([2.0, 10.0]);
|
||||
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
|
||||
let tensor_out = tensor.clone().add_scalar(5.0);
|
||||
let grads = tensor_out.backward();
|
||||
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
tensor_out
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([7.0, 15.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().add(tensor_2.clone());
|
||||
let tensor_5 = tensor_4
|
||||
.add(tensor_3)
|
||||
.add_scalar(5.0)
|
||||
.add(tensor_1.clone())
|
||||
.add(tensor_2.clone());
|
||||
let tensor_6 = tensor_1.clone().add(tensor_5);
|
||||
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_mean() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[3.5, 9.5], [3.5, 9.5]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[-0.75, -0.75], [3.0, 3.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sum_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[14.0, 38.0], [14.0, 38.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[-3.0, -3.0], [12.0, 12.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sum_2() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.clone().sum_dim(1);
|
||||
let tensor_5 = tensor_4.mul(tensor_3);
|
||||
|
||||
let grads = tensor_5.sum().backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[690.0, 690.0], [958.0, 958.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_mean_dim() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[4.0, 36.0], [3.0, -17.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[9.0, 9.0], [35.5, 35.5]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sum_dim() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[8.0, 72.0], [6.0, -34.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[18.0, 18.0], [71.0, 71.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::avg_pool1d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_simple() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
length: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_complex() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
length: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_complex_dont_count_pad() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
length: 6,
|
||||
count_include_pad: false,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AvgPool1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
length: usize,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool1dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = avg_pool1d(
|
||||
x.clone(),
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.count_include_pad,
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::avg_pool2d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_simple() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
|
||||
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
|
||||
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
|
||||
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
|
||||
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
|
||||
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_complex() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
height: 4,
|
||||
width: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
|
||||
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
|
||||
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
|
||||
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_complex_dont_include_pad() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
height: 4,
|
||||
width: 6,
|
||||
count_include_pad: false,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
|
||||
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
|
||||
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
|
||||
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool2dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = avg_pool2d(
|
||||
x.clone(),
|
||||
[self.kernel_size_1, self.kernel_size_2],
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
self.count_include_pad,
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&x_grad_actual.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Int, Tensor, TensorData, module::embedding};
|
||||
|
||||
#[test]
|
||||
fn test_embedding_backward() {
|
||||
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indices = TensorData::from([[0, 1], [1, 1]]);
|
||||
let x = TensorData::from([
|
||||
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
|
||||
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weights = Tensor::<TestAutodiffBackend, 2>::from_data(weights, &device).require_grad();
|
||||
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(indices, &device);
|
||||
let x = Tensor::<TestAutodiffBackend, 3>::from_data(x, &device).require_grad();
|
||||
|
||||
let output = embedding(weights.clone(), indices);
|
||||
let output = output.matmul(x);
|
||||
let grads = output.backward();
|
||||
|
||||
let grad = weights.grad(&grads).unwrap();
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::*;
|
||||
use burn_tensor::{DType, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_full_precision() {
|
||||
let device = Default::default();
|
||||
let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
|
||||
.require_grad();
|
||||
let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
|
||||
.require_grad();
|
||||
let dtype = x1.dtype();
|
||||
|
||||
let x3 = x1.clone().cast(DType::F32);
|
||||
let x4 = x2.clone().cast(DType::F32);
|
||||
|
||||
let x5 = x3.matmul(x4);
|
||||
let x6 = x5.cast(dtype);
|
||||
let x7 = x6 * x1.clone() / x2.clone();
|
||||
|
||||
let grads = x7.backward();
|
||||
|
||||
let x1_grad = x1.grad(&grads);
|
||||
let x2_grad = x2.grad(&grads);
|
||||
|
||||
assert!(x1_grad.is_some());
|
||||
assert!(x2_grad.is_some());
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mul_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x * y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn div_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x / y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x - y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x + y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matmul_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x.matmul(y));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| {
|
||||
let cond = y.clone().equal_elem(4);
|
||||
x.mask_where(cond, y)
|
||||
});
|
||||
}
|
||||
|
||||
fn test_ops_broadcast_backward<F>(func: F)
|
||||
where
|
||||
F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();
|
||||
let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();
|
||||
|
||||
// Slice isn't a broadcastable operation, so it will fail when the previous backward pass
|
||||
// of an operation that supports broadcast doesn't support it during the backward pass.
|
||||
let y = func(w.clone().slice([0..1]), x.clone());
|
||||
|
||||
// Will panic if broadcast isn't supported!
|
||||
let grads = y.backward();
|
||||
|
||||
let w_grad = w.grad(&grads).unwrap();
|
||||
let x_grad = x.grad(&grads).unwrap();
|
||||
|
||||
assert_eq!(w_grad.shape(), w.shape());
|
||||
assert_eq!(x_grad.shape(), x.shape());
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
// Skip on metal - F64 not supported
|
||||
#![cfg(all(feature = "std", not(feature = "metal")))]
|
||||
|
||||
use super::*;
|
||||
use burn_backend_tests::might_panic;
|
||||
use burn_tensor::{DType, Tensor, TensorData};
|
||||
|
||||
#[might_panic(reason = "Unsupported precision for fusion")]
|
||||
#[test]
|
||||
fn cast_keeps_gradient_flow() {
|
||||
let device = Default::default();
|
||||
|
||||
let x = Tensor::<TestAutodiffBackend, 2>::from_data(
|
||||
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let y = x.clone().cast(DType::F64);
|
||||
let z = y.sum();
|
||||
|
||||
let grads = z.backward();
|
||||
let grad_x = x.grad(&grads).unwrap();
|
||||
|
||||
grad_x
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
use super::*;
|
||||
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
fn should_diff_cat() {
    // Reference run: plain matmul of the two inputs.
    let device = Default::default();
    let tensor_1 =
        TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
    let tensor_2 =
        TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();

    let reference_grads = tensor_1.clone().matmul(tensor_2.clone()).backward();
    let grad_1 = tensor_1.grad(&reference_grads).unwrap();
    let grad_2 = tensor_2.grad(&reference_grads).unwrap();

    // Second run: split each input into single-row slices, then glue them back with `cat`.
    let (rows_1, rows_2): (Vec<_>, Vec<_>) = (0..2)
        .map(|i| {
            (
                tensor_1.clone().slice([i..i + 1]),
                tensor_2.clone().slice([i..i + 1]),
            )
        })
        .unzip();

    let tensor_1_cat = TestAutodiffTensor::cat(rows_1.clone(), 0);
    let tensor_2_cat = TestAutodiffTensor::cat(rows_2.clone(), 0);
    let cat_grads = tensor_1_cat.clone().matmul(tensor_2_cat.clone()).backward();

    // Gradients obtained through slice + cat must match the reference, row by row.
    for row in 0..2 {
        let through_cat = tensor_1.grad(&cat_grads).unwrap().slice([row..row + 1]);
        grad_1
            .clone()
            .slice([row..row + 1])
            .to_data()
            .assert_approx_eq::<FloatElem>(&through_cat.to_data(), Tolerance::default());
    }
    for row in 0..2 {
        let through_cat = tensor_2.grad(&cat_grads).unwrap().slice([row..row + 1]);
        grad_2
            .clone()
            .slice([row..row + 1])
            .to_data()
            .assert_approx_eq::<FloatElem>(&through_cat.to_data(), Tolerance::default());
    }
}
|
||||
|
||||
#[test]
fn should_diff_cat_more_than_1_dim() {
    let device = Default::default();
    let top =
        TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
    let bottom =
        TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device)
            .require_grad();

    // Joining a [2, 2] tensor with a [3, 2] tensor along dim 0 must give [5, 2].
    let stacked = TestAutodiffTensor::cat(vec![top.clone(), bottom.clone()], 0);
    assert_eq!(stacked.dims(), [5, 2]);

    let grads = stacked.backward();
    let top_grad = top.grad(&grads).unwrap();
    let bottom_grad = bottom.grad(&grads).unwrap();

    // Each gradient must have the dimensions of its own input.
    assert_eq!(top.dims(), top_grad.dims());
    assert_eq!(bottom.dims(), bottom_grad.dims());
}
|
||||
|
||||
#[test]
fn should_slice_grads_correctly_when_some_inputs_not_tracked() {
    let device = Default::default();
    // Only the first and last inputs require gradients; the middle one does not.
    let head = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad();
    let middle = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device);
    let tail =
        TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad();

    let joined = TestAutodiffTensor::cat(vec![head.clone(), middle.clone(), tail.clone()], 1);

    // A distinct weight per column makes a mis-sliced gradient visible in the values.
    let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device);
    let grads = (joined * weights).sum().backward();

    let head_grad = head.grad(&grads).unwrap();
    let tail_grad = tail.grad(&grads).unwrap();

    head_grad
        .to_data()
        .assert_eq(&burn_tensor::TensorData::from([[1.0]]), false);
    tail_grad
        .to_data()
        .assert_eq(&burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]), false);
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
fn should_diff_ceil() {
    let input = TensorData::from([
        [-1.9751, 0.0714, 0.0643, 0.2406],
        [-1.3172, 0.1252, -0.1119, -0.0127],
    ]);
    let device = Default::default();
    let x = TestAutodiffTensor::<2>::from_data(input, &device).require_grad();

    let grads = x.clone().ceil().backward();
    let x_grad = x.grad(&grads).unwrap();

    // The gradient of `ceil` is expected to be zero for every element.
    let zeros = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
    x_grad.to_data().assert_eq(&zeros, false);
}
|
||||
@@ -0,0 +1,215 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Bool, Tensor, TensorData};
|
||||
|
||||
#[test]
fn test_autodiff_checkpoint_complicated_computation() {
    let device = Default::default();
    // Leaves differ only in their top-left element.
    let leaf = |first: f64| {
        TestAutodiffTensor::<2>::from_data(TensorData::from([[first, 7.0], [7.0, 7.0]]), &device)
    };

    let t0 = leaf(0.0).require_grad();
    let t1 = leaf(0.1).require_grad();
    let t2 = leaf(0.2).require_grad();
    let t3 = leaf(0.3).require_grad();
    let t4 = leaf(0.4).require_grad();

    let t5 = compute_bound_eager(t0, t1);
    let t6 = compute_bound_lazy(t2, t3.clone());
    let t7 = memory_bound_eager(t3, t4);
    let t8 = compute_bound_lazy(t6, t7.clone());
    let t9 = memory_bound_eager_scalar(t7, 11.);
    let t10 = memory_bound_lazy(t5, t8.clone());
    let t11 = memory_bound_lazy(t8, t9);
    let t12 = compute_bound_lazy(t10, t11);

    assert_checkpoint(t12);
}
|
||||
|
||||
#[test]
fn test_autodiff_checkpoint_with_missing_requirement() {
    let device = Default::default();
    let leaf = |first: f64| {
        TestAutodiffTensor::<2>::from_data(TensorData::from([[first, 7.0], [7.0, 7.0]]), &device)
    };

    let t0 = leaf(0.0).require_grad();
    // Deliberately left without `require_grad`.
    let t1 = leaf(0.1);

    let t2 = memory_bound_eager(t0, t1);
    let t3 = memory_bound_eager_scalar(t2.clone(), 11.);
    let t4 = memory_bound_eager_scalar(t2.clone(), 11.);
    let t5 = compute_bound_lazy(t3, t4);
    let t6 = compute_bound_eager_scalar(t5.clone(), 11.);
    let t7 = memory_bound_eager(t5, t2);
    let t8 = memory_bound_eager(t6, t7);

    assert_checkpoint(t8);
}
|
||||
|
||||
#[test]
fn test_autodiff_checkpoint_with_many_duplicates() {
    let device = Default::default();
    let root_data = TensorData::from([[4.0, 7.0], [7.0, 7.0]]);
    let t0 = TestAutodiffTensor::<2>::from_data(root_data, &device).require_grad();

    // The single leaf is fed to every kind of operation, on both sides.
    let t1 = memory_bound_eager(t0.clone(), t0.clone());
    let t2 = compute_bound_eager(t0.clone(), t0.clone());
    let t3 = memory_bound_lazy(t0.clone(), t0.clone());
    let t4 = compute_bound_lazy(t0.clone(), t0.clone());

    let t5 = memory_bound_eager(t1.clone(), t0.clone());
    let t6 = memory_bound_eager(t0.clone(), t5.clone());
    let t7 = compute_bound_lazy(t3.clone(), t5.clone());
    let t8 = compute_bound_eager(t4.clone(), t2.clone());
    let t9 = memory_bound_lazy(t6, t7);
    let t10 = memory_bound_eager(t0, t9);
    let t11 = memory_bound_eager_scalar(t10, 9.);
    let t12 = compute_bound_lazy(t8, t11);

    assert_checkpoint(t12);
}
|
||||
|
||||
#[test]
fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {
    let device = Default::default();
    let leaf = |first: f64| {
        TestAutodiffTensor::<2>::from_data(TensorData::from([[first, 7.0], [7.0, 7.0]]), &device)
    };

    let t0 = leaf(0.0).require_grad();
    // `t1` is the only leaf without `require_grad`.
    let t1 = leaf(0.1);
    let t2 = leaf(0.2).require_grad();
    let t3 = leaf(0.3).require_grad();
    let t4 = leaf(0.4).require_grad();

    // Four chained additions followed by one multiplication.
    let t5 = memory_bound_eager(t0, t1.clone());
    let t6 = memory_bound_eager(t5, t2);
    let t7 = memory_bound_eager(t6, t3);
    let t8 = memory_bound_eager(t7, t4);
    let t9 = memory_bound_lazy(t8, t1);

    assert_checkpoint(t9)
}
|
||||
|
||||
#[test]
fn test_autodiff_checkpoint_half_sub_graph_not_tracked() {
    let device = Default::default();
    let leaf = |first: f64| {
        TestAutodiffTensor::<2>::from_data(TensorData::from([[first, 7.0], [7.0, 7.0]]), &device)
    };

    // First three leaves never call `require_grad`.
    let t0 = leaf(0.0);
    let t1 = leaf(0.1);
    let t2 = leaf(0.2);
    // Last three leaves do.
    let t3 = leaf(0.3).require_grad();
    let t4 = leaf(0.4).require_grad();
    let t5 = leaf(0.5).require_grad();

    // Branch built only from the leaves without `require_grad`.
    let t6 = memory_bound_lazy(t0, t1);
    let t7 = compute_bound_eager(t6, t2);

    // Branch built only from the leaves with `require_grad`.
    let t8 = memory_bound_eager(t3, t4);
    let t9 = compute_bound_lazy(t8, t5);

    let t10 = compute_bound_lazy(t7, t9);

    assert_checkpoint(t10);
}
|
||||
|
||||
#[test]
fn test_autodiff_checkpoint_very_complex() {
    let device = Default::default();
    let leaf = |first: f64| {
        TestAutodiffTensor::<2>::from_data(TensorData::from([[first, 7.0], [7.0, 7.0]]), &device)
    };

    let t0 = leaf(0.0).require_grad();
    // `t1` is the only leaf without `require_grad`.
    let t1 = leaf(0.1);
    let t2 = leaf(0.2).require_grad();
    let t3 = leaf(0.3).require_grad();
    let t4 = leaf(0.4).require_grad();

    let t5 = memory_bound_eager_scalar(t0, 8.);
    let t6 = memory_bound_lazy(t5.clone(), t1.clone());
    let t7 = compute_bound_lazy(t6.clone(), t6);
    let t8 = memory_bound_lazy(t1.clone(), t5.clone());
    let t9 = memory_bound_eager_scalar(t7.clone(), 7.);
    let t10 = compute_bound_eager(t5, t8);
    let t11 = memory_bound_eager(t2.clone(), t9);
    let t12 = memory_bound_lazy(t2.clone(), t2);
    let t13 = compute_bound_eager(t10.clone(), t11);
    let t14 = compute_bound_eager_scalar(t3, 8.);
    let t15 = compute_bound_lazy(t4, t12);
    let t16 = memory_bound_lazy(t10, t7);
    let t17 = compute_bound_lazy(t13, t1);
    let t18 = memory_bound_eager(t15, t16);
    let t19 = compute_bound_eager(t14, t17);
    let t20 = memory_bound_lazy(t18, t19);
    let t21 = memory_bound_eager_scalar(t20, 8.);

    assert_checkpoint(t21)
}
|
||||
|
||||
fn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {
|
||||
// Assert is not explicit here, but the test can fail
|
||||
// - when a tensor is actually required more than n_required, it won't be found and will panic
|
||||
// - when a tensor is actually required less than n_required, the backward states map won't be
|
||||
// empty and will fail the assertion within the backward code, same for retro_forwards
|
||||
tensor.backward();
|
||||
}
|
||||
|
||||
// Does not save its state and does not need its parents
|
||||
fn memory_bound_eager<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.add(tensor_b)
|
||||
}
|
||||
// Adds the scalar `b` to `tensor_a` through `add_scalar`.
fn memory_bound_eager_scalar<const D: usize>(
    tensor_a: TestAutodiffTensor<D>,
    b: f32,
) -> TestAutodiffTensor<D> {
    let shifted = tensor_a.add_scalar(b);
    shifted
}
|
||||
|
||||
// Saves its own state and does not need its parents
|
||||
fn compute_bound_eager<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
|
||||
tensor_a.mask_where(mask, tensor_b)
|
||||
}
|
||||
fn compute_bound_eager_scalar<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
b: f32,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
|
||||
tensor_a.mask_fill(mask, b)
|
||||
}
|
||||
|
||||
// Does not save its state and needs its parents
|
||||
fn memory_bound_lazy<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.mul(tensor_b)
|
||||
}
|
||||
|
||||
// Saves its own state and needs its parents
|
||||
fn compute_bound_lazy<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.matmul(tensor_b)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
fn should_diff_full_complex_1() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    // out = ((a @ b) @ a) * b
    let out = a
        .clone()
        .matmul(b.clone())
        .matmul(a.clone())
        .mul(b.clone());
    let grads = out.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    a_grad
        .to_data()
        .assert_eq(&TensorData::from([[593., 463.0], [487.0, 539.0]]), false);
    b_grad
        .to_data()
        .assert_eq(&TensorData::from([[734.0, 294.0], [1414.0, 242.0]]), false);
}
|
||||
|
||||
#[test]
fn should_diff_full_complex_2() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    // out = ((a @ b) @ a + 17) + b
    let out = a
        .clone()
        .matmul(b.clone())
        .matmul(a.clone())
        .add_scalar(17.0)
        .add(b.clone());
    let grads = out.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    a_grad
        .to_data()
        .assert_eq(&TensorData::from([[166.0, 110.0], [212.0, 156.0]]), false);
    b_grad
        .to_data()
        .assert_eq(&TensorData::from([[113.0, 141.0], [33.0, 41.0]]), false);
}
|
||||
|
||||
#[test]
fn should_diff_full_complex_3() {
    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    // chained = (a @ b) @ a, which is then used twice: out = (chained - b) + chained
    let chained = a.clone().matmul(b.clone()).matmul(a.clone());
    let out = chained.clone().sub(b.clone()).add(chained);
    let grads = out.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    a_grad
        .to_data()
        .assert_eq(&TensorData::from([[332.0, 220.0], [424.0, 312.0]]), false);
    b_grad
        .to_data()
        .assert_eq(&TensorData::from([[223.0, 279.0], [63.0, 79.0]]), false);
}
|
||||
@@ -0,0 +1,277 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};
|
||||
|
||||
#[test]
fn test_conv1d_basic() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // Expected values repeat for every batch entry and every output channel.
    let x_per_batch = [[14., 24., 24., 18.], [26., 42., 42., 30.]];
    let weight_per_out = [[30., 44., 36.], [54., 76., 60.]];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([8., 8.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv1d_different_channels() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 3,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // Expected values repeat for every batch entry and every output channel.
    let x_per_batch = [[39., 63., 63., 45.], [57., 90., 90., 63.]];
    let weight_per_out = [[30., 44., 36.], [54., 76., 60.]];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats(
            [weight_per_out, weight_per_out, weight_per_out],
            &device,
        ),
        bias: TestTensor::from_floats([8., 8., 8.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv1d_with_padding() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 2,
        stride: 1,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // With padding 2 the expected gradients are constant along the last axis.
    let x_per_batch = [[24.; 4], [42.; 4]];
    let weight_per_out = [[44.; 3], [76.; 3]];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([12., 12.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv1d_with_stride() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 2,
        dilation: 1,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // Expected values repeat for every batch entry and every output channel.
    let x_per_batch = [[8., 16., 8., 10.], [14., 28., 14., 16.]];
    let weight_per_out = [[10., 20., 24.], [18., 36., 40.]];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([4., 4.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv1d_dilation() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 2,
        groups: 1,
        length: 4,
    };
    let device = Default::default();

    // Expected values repeat for every batch entry and every output channel.
    let x_per_batch = [[6., 8., 8., 10.], [12., 14., 14., 16.]];
    let weight_per_out = [[8., 22., 14.], [16., 38., 22.]];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([4., 4.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv1d_groups() {
    let case = Conv1dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size: 3,
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 2,
        length: 4,
    };
    let device = Default::default();

    // Expected input gradient repeats for every batch entry.
    let x_per_batch = [[1., 3., 3., 3.], [7., 12., 12., 9.]];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        // With two groups the weight has a single input channel per output channel.
        weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device),
        bias: TestTensor::from_floats([8., 8.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
/// Parameters of one conv1d gradient test.
///
/// The input is built with shape `[batch_size, channels_in, length]` and the weight with
/// shape `[channels_out, channels_in / groups, kernel_size]`.
struct Conv1dTestCase {
    /// First dimension of the input.
    batch_size: usize,
    /// Second dimension of the input.
    channels_in: usize,
    /// First dimension of the weight; also the number of bias elements.
    channels_out: usize,
    /// Last dimension of the weight.
    kernel_size: usize,
    /// Padding handed to `ConvOptions`.
    padding: usize,
    /// Stride handed to `ConvOptions`.
    stride: usize,
    /// Dilation handed to `ConvOptions`.
    dilation: usize,
    /// Group count handed to `ConvOptions`; also divides `channels_in` in the weight shape.
    groups: usize,
    /// Last dimension of the input.
    length: usize,
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<3>,
|
||||
weight: TestTensor<3>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl Conv1dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = conv1d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,962 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions};
|
||||
|
||||
#[test]
fn test_conv2d_basic() {
    let case = Conv2dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // Expected values repeat for every batch entry and every output channel.
    let x_per_batch = [
        [
            [88., 138., 138., 96.],
            [150., 234., 234., 162.],
            [150., 234., 234., 162.],
            [112., 174., 174., 120.],
        ],
        [
            [160., 246., 246., 168.],
            [258., 396., 396., 270.],
            [258., 396., 396., 270.],
            [184., 282., 282., 192.],
        ],
    ];
    let weight_per_out = [
        [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
        [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
    ];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([32., 32.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv2d_different_channels() {
    let case = Conv2dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 3,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // Expected values repeat for every batch entry and every output channel.
    let x_per_batch = [
        [
            [240., 369., 369., 252.],
            [387., 594., 594., 405.],
            [387., 594., 594., 405.],
            [276., 423., 423., 288.],
        ],
        [
            [348., 531., 531., 360.],
            [549., 837., 837., 567.],
            [549., 837., 837., 567.],
            [384., 585., 585., 396.],
        ],
    ];
    let weight_per_out = [
        [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
        [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
    ];
    let expected = Grads {
        x: TestTensor::from_floats([x_per_batch, x_per_batch], &device),
        weight: TestTensor::from_floats(
            [weight_per_out, weight_per_out, weight_per_out],
            &device,
        ),
        bias: TestTensor::from_floats([32., 32., 32.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv2d_different_kernel_size() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 4,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // Single batch entry; the weight gradient repeats for every output channel.
    let x_only_batch = [
        [
            [116., 180., 192., 132.],
            [198., 306., 324., 222.],
            [198., 306., 324., 222.],
            [148., 228., 240., 164.],
        ],
        [
            [212., 324., 336., 228.],
            [342., 522., 540., 366.],
            [342., 522., 540., 366.],
            [244., 372., 384., 260.],
        ],
    ];
    let weight_per_out = [
        [
            [27., 45., 54., 39.],
            [52., 84., 96., 68.],
            [51., 81., 90., 63.],
        ],
        [
            [123., 189., 198., 135.],
            [180., 276., 288., 196.],
            [147., 225., 234., 159.],
        ],
    ];
    let expected = Grads {
        x: TestTensor::from_floats([x_only_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([12., 12.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv2d_different_padding() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 2,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 4,
    };
    let device = Default::default();

    // With padding 2 on the second axis the expected gradients are constant along each row.
    let x_only_batch = [
        [[138.; 4], [234.; 4], [234.; 4], [174.; 4]],
        [[246.; 4], [396.; 4], [396.; 4], [282.; 4]],
    ];
    let weight_per_out = [
        [[66.; 3], [120.; 3], [114.; 3]],
        [[258.; 3], [376.; 3], [306.; 3]],
    ];
    let expected = Grads {
        x: TestTensor::from_floats([x_only_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([24., 24.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
#[test]
fn test_conv2d_different_width() {
    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 4,
        width: 5,
    };
    let device = Default::default();

    // Single batch entry; the weight gradient repeats for every output channel.
    let x_only_batch = [
        [
            [88., 138., 138., 138., 96.],
            [150., 234., 234., 234., 162.],
            [150., 234., 234., 234., 162.],
            [112., 174., 174., 174., 120.],
        ],
        [
            [160., 246., 246., 246., 168.],
            [258., 396., 396., 396., 270.],
            [258., 396., 396., 396., 270.],
            [184., 282., 282., 282., 192.],
        ],
    ];
    let weight_per_out = [
        [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
        [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
    ];
    let expected = Grads {
        x: TestTensor::from_floats([x_only_batch], &device),
        weight: TestTensor::from_floats([weight_per_out, weight_per_out], &device),
        bias: TestTensor::from_floats([20., 20.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
/// Gradient check for a 3x3 `conv2d` with stride 2 on both axes of a 6x6 input.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_stride_2() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 6x6 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [26., 52., 26., 52., 26., 28.],
                [52., 104., 52., 104., 52., 56.],
                [26., 52., 26., 52., 26., 28.],
                [52., 104., 52., 104., 52., 56.],
                [26., 52., 26., 52., 26., 28.],
                [32., 64., 32., 64., 32., 34.],
            ],
            [
                [44., 88., 44., 88., 44., 46.],
                [88., 176., 88., 176., 88., 92.],
                [44., 88., 44., 88., 44., 46.],
                [88., 176., 88., 176., 88., 92.],
                [44., 88., 44., 88., 44., 46.],
                [50., 100., 50., 100., 50., 52.],
            ],
        ]],
        &device,
    );

    // Both output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
        [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
    ];
    let weight_grad: TestTensor<4> =
        TestTensor::from_floats([per_out_channel, per_out_channel], &device);
    let bias_grad: TestTensor<1> = TestTensor::from_floats([9., 9.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 6,
        width: 6,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for a 3x3 `conv2d` whose stride differs per axis (3 and 1) on an 8x8 input.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_different_stride() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 8x8 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [50., 78., 78., 78., 78., 78., 78., 54.],
                [62., 96., 96., 96., 96., 96., 96., 66.],
                [38., 60., 60., 60., 60., 60., 60., 42.],
                [50., 78., 78., 78., 78., 78., 78., 54.],
                [62., 96., 96., 96., 96., 96., 96., 66.],
                [38., 60., 60., 60., 60., 60., 60., 42.],
                [50., 78., 78., 78., 78., 78., 78., 54.],
                [62., 96., 96., 96., 96., 96., 96., 66.],
            ],
            [
                [86., 132., 132., 132., 132., 132., 132., 90.],
                [98., 150., 150., 150., 150., 150., 150., 102.],
                [74., 114., 114., 114., 114., 114., 114., 78.],
                [86., 132., 132., 132., 132., 132., 132., 90.],
                [98., 150., 150., 150., 150., 150., 150., 102.],
                [74., 114., 114., 114., 114., 114., 114., 78.],
                [86., 132., 132., 132., 132., 132., 132., 90.],
                [98., 150., 150., 150., 150., 150., 150., 102.],
            ],
        ]],
        &device,
    );

    // Both output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
        [
            [1330., 1528., 1344.],
            [1911., 2196., 1932.],
            [2079., 2388., 2100.],
        ],
    ];
    let weight_grad: TestTensor<4> =
        TestTensor::from_floats([per_out_channel, per_out_channel], &device);
    let bias_grad: TestTensor<1> = TestTensor::from_floats([24., 24.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 3,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 1,
        height: 8,
        width: 8,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for a 3x3 `conv2d` with dilation 2 on both axes of a 6x6 input.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_dilation_2() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 6x6 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [18., 38., 38., 42., 42., 22.],
                [42., 88., 88., 96., 96., 50.],
                [42., 88., 88., 96., 96., 50.],
                [54., 112., 112., 120., 120., 62.],
                [54., 112., 112., 120., 120., 62.],
                [30., 62., 62., 66., 66., 34.],
            ],
            [
                [36., 74., 74., 78., 78., 40.],
                [78., 160., 160., 168., 168., 86.],
                [78., 160., 160., 168., 168., 86.],
                [90., 184., 184., 192., 192., 98.],
                [90., 184., 184., 192., 192., 98.],
                [48., 98., 98., 102., 102., 52.],
            ],
        ]],
        &device,
    );

    // Both output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
        [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
    ];
    let weight_grad: TestTensor<4> =
        TestTensor::from_floats([per_out_channel, per_out_channel], &device);
    let bias_grad: TestTensor<1> = TestTensor::from_floats([16., 16.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 2,
        dilation_2: 2,
        groups: 1,
        height: 6,
        width: 6,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for a 3x3 `conv2d` whose dilation differs per axis (2 and 3) on a 6x6 input.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_different_dilation() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 6x6 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [18., 0., 20., 20., 0., 22.],
                [42., 0., 46., 46., 0., 50.],
                [42., 0., 46., 46., 0., 50.],
                [54., 0., 58., 58., 0., 62.],
                [54., 0., 58., 58., 0., 62.],
                [30., 0., 32., 32., 0., 34.],
            ],
            [
                [36., 0., 38., 38., 0., 40.],
                [78., 0., 82., 82., 0., 86.],
                [78., 0., 82., 82., 0., 86.],
                [90., 0., 94., 94., 0., 98.],
                [90., 0., 94., 94., 0., 98.],
                [48., 0., 50., 50., 0., 52.],
            ],
        ]],
        &device,
    );

    // Both output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
        [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
    ];
    let weight_grad: TestTensor<4> =
        TestTensor::from_floats([per_out_channel, per_out_channel], &device);
    let bias_grad: TestTensor<1> = TestTensor::from_floats([8., 8.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 2,
        dilation_2: 3,
        groups: 1,
        height: 6,
        width: 6,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for an unpadded 3x3 `conv2d` with two groups on a 5x5 input.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_groups() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 5x5 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [0., 1., 3., 3., 2.],
                [3., 8., 15., 12., 7.],
                [9., 21., 36., 27., 15.],
                [9., 20., 33., 24., 13.],
                [6., 13., 21., 15., 8.],
            ],
            [
                [9., 19., 30., 21., 11.],
                [21., 44., 69., 48., 25.],
                [36., 75., 117., 81., 42.],
                [27., 56., 87., 60., 31.],
                [15., 31., 48., 33., 17.],
            ],
        ]],
        &device,
    );

    // With two groups each output channel holds a single 3x3 kernel.
    let weight_grad: TestTensor<4> = TestTensor::from_floats(
        [
            [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],
            [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],
        ],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([9., 9.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 2,
        height: 5,
        width: 5,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for a 3x3 `conv2d` with four groups and stride 2 on a 4x4 input.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_groups_stride_2() {
    let device = Default::default();

    // Expected gradient of the input: one batch, four channels, 4x4 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [4., 8., 4., 5.],
                [8., 16., 8., 10.],
                [4., 8., 4., 5.],
                [7., 14., 7., 8.],
            ],
            [
                [13., 26., 13., 14.],
                [26., 52., 26., 28.],
                [13., 26., 13., 14.],
                [16., 32., 16., 17.],
            ],
            [
                [22., 44., 22., 23.],
                [44., 88., 44., 46.],
                [22., 44., 22., 23.],
                [25., 50., 25., 26.],
            ],
            [
                [31., 62., 31., 32.],
                [62., 124., 62., 64.],
                [31., 62., 31., 32.],
                [34., 68., 34., 35.],
            ],
        ]],
        &device,
    );

    // With four groups each output channel holds a single 3x3 kernel.
    let weight_grad: TestTensor<4> = TestTensor::from_floats(
        [
            [[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
            [[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
            [[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
            [[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
        ],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([4., 4., 4., 4.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 4,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 4,
        height: 4,
        width: 4,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for an unpadded 3x3 `conv2d` with three groups, 3 input and 6 output channels.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_groups_different_channels() {
    let device = Default::default();

    // Expected gradient of the input: one batch, three channels, 4x4 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [9., 20., 24., 13.],
                [24., 52., 60., 32.],
                [36., 76., 84., 44.],
                [21., 44., 48., 25.],
            ],
            [
                [45., 92., 96., 49.],
                [96., 196., 204., 104.],
                [108., 220., 228., 116.],
                [57., 116., 120., 61.],
            ],
            [
                [81., 164., 168., 85.],
                [168., 340., 348., 176.],
                [180., 364., 372., 188.],
                [93., 188., 192., 97.],
            ],
        ]],
        &device,
    );

    // The six expected weight gradients come in three identical consecutive pairs,
    // so each distinct value is written once and listed twice.
    let first_pair = [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]];
    let second_pair = [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]];
    let third_pair = [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]];
    let weight_grad: TestTensor<4> = TestTensor::from_floats(
        [
            first_pair,
            first_pair,
            second_pair,
            second_pair,
            third_pair,
            third_pair,
        ],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 3,
        channels_out: 6,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        groups: 3,
        height: 4,
        width: 4,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for a `conv2d` where kernel size, padding, stride and dilation all differ per axis.
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_complex() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 4 rows of 5 values.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [36., 39., 0., 39., 42.],
                [81., 87., 0., 87., 93.],
                [81., 87., 0., 87., 93.],
                [45., 48., 0., 48., 51.],
            ],
            [
                [54., 57., 0., 57., 60.],
                [117., 123., 0., 123., 129.],
                [117., 123., 0., 123., 129.],
                [63., 66., 0., 66., 69.],
            ],
        ]],
        &device,
    );

    // All three output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [[15., 42., 27.], [30., 72., 42.]],
        [[75., 162., 87.], [90., 192., 102.]],
    ];
    let weight_grad: TestTensor<4> = TestTensor::from_floats(
        [per_out_channel, per_out_channel, per_out_channel],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([8., 8., 8.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 3,
        kernel_size_1: 2,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 2,
        stride_1: 1,
        stride_2: 2,
        dilation_1: 2,
        dilation_2: 3,
        groups: 1,
        height: 4,
        width: 5,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for an unpadded 3x3 `conv2d` with two groups and stride 2 (4 input, 2 output channels).
///
/// The expected gradients are hard-coded and compared by `Conv2dTestCase::assert_grads`.
#[test]
fn test_conv2d_groups_stride_2_no_pad() {
    let device = Default::default();

    // Expected gradient of the input: one batch, four channels, 4x4 each.
    let x_grad: TestTensor<4> = TestTensor::from_floats(
        [[
            [
                [0., 1., 2., 0.],
                [3., 4., 5., 0.],
                [6., 7., 8., 0.],
                [0., 0., 0., 0.],
            ],
            [
                [9., 10., 11., 0.],
                [12., 13., 14., 0.],
                [15., 16., 17., 0.],
                [0., 0., 0., 0.],
            ],
            [
                [18., 19., 20., 0.],
                [21., 22., 23., 0.],
                [24., 25., 26., 0.],
                [0., 0., 0., 0.],
            ],
            [
                [27., 28., 29., 0.],
                [30., 31., 32., 0.],
                [33., 34., 35., 0.],
                [0., 0., 0., 0.],
            ],
        ]],
        &device,
    );

    // Two output channels, each with two 3x3 kernels (4 input channels / 2 groups).
    let weight_grad: TestTensor<4> = TestTensor::from_floats(
        [
            [
                [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
                [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
            ],
            [
                [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
                [[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]],
            ],
        ],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([1., 1.], &device);

    let case = Conv2dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 0,
        padding_2: 0,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 2,
        height: 4,
        width: 4,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Configuration of one 2D convolution gradient test.
///
/// `assert_grads` turns these values into the input shape
/// `[batch_size, channels_in, height, width]`, the weight shape
/// `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
/// and the convolution options.
struct Conv2dTestCase {
    /// First dimension of the input tensor.
    batch_size: usize,
    /// Second dimension of the input tensor.
    channels_in: usize,
    /// First dimension of the weight tensor; also the length of the bias.
    channels_out: usize,
    /// Third dimension of the weight tensor.
    kernel_size_1: usize,
    /// Fourth dimension of the weight tensor.
    kernel_size_2: usize,
    /// First entry of the padding pair given to the convolution options.
    padding_1: usize,
    /// Second entry of the padding pair given to the convolution options.
    padding_2: usize,
    /// First entry of the stride pair given to the convolution options.
    stride_1: usize,
    /// Second entry of the stride pair given to the convolution options.
    stride_2: usize,
    /// First entry of the dilation pair given to the convolution options.
    dilation_1: usize,
    /// Second entry of the dilation pair given to the convolution options.
    dilation_2: usize,
    /// Group count; divides `channels_in` when the weight shape is built.
    groups: usize,
    /// Third dimension of the input tensor.
    height: usize,
    /// Fourth dimension of the input tensor.
    width: usize,
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<4>,
|
||||
weight: TestTensor<4>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl Conv2dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv2d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvOptions::new(
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::rel_abs(0.01, 0.01);
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,690 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
|
||||
|
||||
/// Gradient check for a 3x3x3 `conv3d` with padding 1 on a batch of two 4x4x4 inputs.
///
/// The expected gradients are hard-coded and compared by `Conv3dTestCase::assert_grads`.
#[test]
fn test_conv3d_basic() {
    let device = Default::default();

    // Both batch entries expect the same input gradient (two channels, 4x4x4 each),
    // so it is written once.
    let per_batch = [
        [
            [
                [536., 816., 816., 552.],
                [840., 1278., 1278., 864.],
                [840., 1278., 1278., 864.],
                [584., 888., 888., 600.],
            ],
            [
                [912., 1386., 1386., 936.],
                [1422., 2160., 2160., 1458.],
                [1422., 2160., 2160., 1458.],
                [984., 1494., 1494., 1008.],
            ],
            [
                [912., 1386., 1386., 936.],
                [1422., 2160., 2160., 1458.],
                [1422., 2160., 2160., 1458.],
                [984., 1494., 1494., 1008.],
            ],
            [
                [680., 1032., 1032., 696.],
                [1056., 1602., 1602., 1080.],
                [1056., 1602., 1602., 1080.],
                [728., 1104., 1104., 744.],
            ],
        ],
        [
            [
                [968., 1464., 1464., 984.],
                [1488., 2250., 2250., 1512.],
                [1488., 2250., 2250., 1512.],
                [1016., 1536., 1536., 1032.],
            ],
            [
                [1560., 2358., 2358., 1584.],
                [2394., 3618., 3618., 2430.],
                [2394., 3618., 3618., 2430.],
                [1632., 2466., 2466., 1656.],
            ],
            [
                [1560., 2358., 2358., 1584.],
                [2394., 3618., 3618., 2430.],
                [2394., 3618., 3618., 2430.],
                [1632., 2466., 2466., 1656.],
            ],
            [
                [1112., 1680., 1680., 1128.],
                [1704., 2574., 2574., 1728.],
                [1704., 2574., 2574., 1728.],
                [1160., 1752., 1752., 1176.],
            ],
        ],
    ];
    let x_grad: TestTensor<5> = TestTensor::from_floats([per_batch, per_batch], &device);

    // Both output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [
            [
                [4590., 6156., 4644.],
                [6264., 8400., 6336.],
                [4806., 6444., 4860.],
            ],
            [
                [6696., 8976., 6768.],
                [9120., 12224., 9216.],
                [6984., 9360., 7056.],
            ],
            [
                [5454., 7308., 5508.],
                [7416., 9936., 7488.],
                [5670., 7596., 5724.],
            ],
        ],
        [
            [
                [8046., 10764., 8100.],
                [10872., 14544., 10944.],
                [8262., 11052., 8316.],
            ],
            [
                [11304., 15120., 11376.],
                [15264., 20416., 15360.],
                [11592., 15504., 11664.],
            ],
            [
                [8910., 11916., 8964.],
                [12024., 16080., 12096.],
                [9126., 12204., 9180.],
            ],
        ],
    ];
    let weight_grad: TestTensor<5> =
        TestTensor::from_floats([per_out_channel, per_out_channel], &device);
    let bias_grad: TestTensor<1> = TestTensor::from_floats([128., 128.], &device);

    let case = Conv3dTestCase {
        batch_size: 2,
        channels_in: 2,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        kernel_size_3: 3,
        padding_1: 1,
        padding_2: 1,
        padding_3: 1,
        stride_1: 1,
        stride_2: 1,
        stride_3: 1,
        dilation_1: 1,
        dilation_2: 1,
        dilation_3: 1,
        groups: 1,
        depth: 4,
        height: 4,
        width: 4,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for a `conv3d` where kernel size, padding, stride and dilation all differ per axis.
///
/// The expected gradients are hard-coded and compared by `Conv3dTestCase::assert_grads`.
#[test]
fn test_conv3d_complex() {
    let device = Default::default();

    // Expected gradient of the input: one batch, two channels, 5 slices of 6 rows by 7 values.
    let x_grad: TestTensor<5> = TestTensor::from_floats(
        [[
            [
                [
                    [0., 147., 0., 0., 0., 150., 0.],
                    [0., 159., 0., 0., 0., 162., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 159., 0., 0., 0., 162., 0.],
                    [0., 171., 0., 0., 0., 174., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 330., 0., 0., 0., 336., 0.],
                    [0., 354., 0., 0., 0., 360., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 354., 0., 0., 0., 360., 0.],
                    [0., 378., 0., 0., 0., 384., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 330., 0., 0., 0., 336., 0.],
                    [0., 354., 0., 0., 0., 360., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 354., 0., 0., 0., 360., 0.],
                    [0., 378., 0., 0., 0., 384., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 330., 0., 0., 0., 336., 0.],
                    [0., 354., 0., 0., 0., 360., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 354., 0., 0., 0., 360., 0.],
                    [0., 378., 0., 0., 0., 384., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 183., 0., 0., 0., 186., 0.],
                    [0., 195., 0., 0., 0., 198., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 195., 0., 0., 0., 198., 0.],
                    [0., 207., 0., 0., 0., 210., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
            ],
            [
                [
                    [0., 219., 0., 0., 0., 222., 0.],
                    [0., 231., 0., 0., 0., 234., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 231., 0., 0., 0., 234., 0.],
                    [0., 243., 0., 0., 0., 246., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 474., 0., 0., 0., 480., 0.],
                    [0., 498., 0., 0., 0., 504., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 498., 0., 0., 0., 504., 0.],
                    [0., 522., 0., 0., 0., 528., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 474., 0., 0., 0., 480., 0.],
                    [0., 498., 0., 0., 0., 504., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 498., 0., 0., 0., 504., 0.],
                    [0., 522., 0., 0., 0., 528., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 474., 0., 0., 0., 480., 0.],
                    [0., 498., 0., 0., 0., 504., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 498., 0., 0., 0., 504., 0.],
                    [0., 522., 0., 0., 0., 528., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
                [
                    [0., 255., 0., 0., 0., 258., 0.],
                    [0., 267., 0., 0., 0., 270., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                    [0., 267., 0., 0., 0., 270., 0.],
                    [0., 279., 0., 0., 0., 282., 0.],
                    [0., 0., 0., 0., 0., 0., 0.],
                ],
            ],
        ]],
        &device,
    );

    // All three output channels expect the same weight gradient, so it is written once.
    let per_out_channel = [
        [
            [
                [0., 256., 272., 0.],
                [0., 624., 656., 0.],
                [0., 368., 384., 0.],
            ],
            [
                [0., 424., 440., 0.],
                [0., 960., 992., 0.],
                [0., 536., 552., 0.],
            ],
        ],
        [
            [
                [0., 1096., 1112., 0.],
                [0., 2304., 2336., 0.],
                [0., 1208., 1224., 0.],
            ],
            [
                [0., 1264., 1280., 0.],
                [0., 2640., 2672., 0.],
                [0., 1376., 1392., 0.],
            ],
        ],
    ];
    let weight_grad: TestTensor<5> = TestTensor::from_floats(
        [per_out_channel, per_out_channel, per_out_channel],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([10., 10., 10.], &device);

    let case = Conv3dTestCase {
        batch_size: 1,
        channels_in: 2,
        channels_out: 3,
        kernel_size_1: 2,
        kernel_size_2: 3,
        kernel_size_3: 4,
        padding_1: 1,
        padding_2: 2,
        padding_3: 3,
        stride_1: 1,
        stride_2: 2,
        stride_3: 3,
        dilation_1: 2,
        dilation_2: 3,
        dilation_3: 4,
        groups: 1,
        depth: 5,
        height: 6,
        width: 7,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Gradient check for an unpadded 3x3x3 `conv3d` with two groups and stride 2 (4 input, 2 output channels).
///
/// The expected gradients are hard-coded and compared by `Conv3dTestCase::assert_grads`.
#[test]
fn test_conv3d_groups_stride_2_no_pad() {
    let device = Default::default();

    // Expected gradient of the input: one batch, four channels, 4x4x4 each.
    let x_grad: TestTensor<5> = TestTensor::from_floats(
        [[
            [
                [
                    [0., 1., 2., 0.],
                    [3., 4., 5., 0.],
                    [6., 7., 8., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [9., 10., 11., 0.],
                    [12., 13., 14., 0.],
                    [15., 16., 17., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [18., 19., 20., 0.],
                    [21., 22., 23., 0.],
                    [24., 25., 26., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                ],
            ],
            [
                [
                    [27., 28., 29., 0.],
                    [30., 31., 32., 0.],
                    [33., 34., 35., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [36., 37., 38., 0.],
                    [39., 40., 41., 0.],
                    [42., 43., 44., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [45., 46., 47., 0.],
                    [48., 49., 50., 0.],
                    [51., 52., 53., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                ],
            ],
            [
                [
                    [54., 55., 56., 0.],
                    [57., 58., 59., 0.],
                    [60., 61., 62., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [63., 64., 65., 0.],
                    [66., 67., 68., 0.],
                    [69., 70., 71., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [72., 73., 74., 0.],
                    [75., 76., 77., 0.],
                    [78., 79., 80., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                ],
            ],
            [
                [
                    [81., 82., 83., 0.],
                    [84., 85., 86., 0.],
                    [87., 88., 89., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [90., 91., 92., 0.],
                    [93., 94., 95., 0.],
                    [96., 97., 98., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [99., 100., 101., 0.],
                    [102., 103., 104., 0.],
                    [105., 106., 107., 0.],
                    [0., 0., 0., 0.],
                ],
                [
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                ],
            ],
        ]],
        &device,
    );

    // Two output channels, each with two 3x3x3 kernels (4 input channels / 2 groups).
    let weight_grad: TestTensor<5> = TestTensor::from_floats(
        [
            [
                [
                    [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
                    [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
                    [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
                ],
                [
                    [[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
                    [[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
                    [[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
                ],
            ],
            [
                [
                    [[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
                    [[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
                    [[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
                ],
                [
                    [[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
                    [[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
                    [[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
                ],
            ],
        ],
        &device,
    );
    let bias_grad: TestTensor<1> = TestTensor::from_floats([1., 1.], &device);

    let case = Conv3dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 2,
        kernel_size_1: 3,
        kernel_size_2: 3,
        kernel_size_3: 3,
        padding_1: 0,
        padding_2: 0,
        padding_3: 0,
        stride_1: 2,
        stride_2: 2,
        stride_3: 2,
        dilation_1: 1,
        dilation_2: 1,
        dilation_3: 1,
        groups: 2,
        depth: 4,
        height: 4,
        width: 4,
    };
    case.assert_grads(Grads {
        x: x_grad,
        weight: weight_grad,
        bias: bias_grad,
    });
}
|
||||
|
||||
/// Hyper-parameters and input dimensions for one `conv3d` gradient check.
///
/// Fields with the `_1`, `_2`, `_3` suffixes are forwarded, in that order, as the
/// three elements of the kernel-shape, padding, stride and dilation arrays.
struct Conv3dTestCase {
    /// First dimension of the input tensor.
    batch_size: usize,
    /// Second dimension of the input tensor; divided by `groups` for the weight shape.
    channels_in: usize,
    /// First dimension of the weight tensor and the length of the bias.
    channels_out: usize,
    /// Third dimension of the weight tensor.
    kernel_size_1: usize,
    /// Fourth dimension of the weight tensor.
    kernel_size_2: usize,
    /// Fifth dimension of the weight tensor.
    kernel_size_3: usize,
    /// First element of the padding array.
    padding_1: usize,
    /// Second element of the padding array.
    padding_2: usize,
    /// Third element of the padding array.
    padding_3: usize,
    /// First element of the stride array.
    stride_1: usize,
    /// Second element of the stride array.
    stride_2: usize,
    /// Third element of the stride array.
    stride_3: usize,
    /// First element of the dilation array.
    dilation_1: usize,
    /// Second element of the dilation array.
    dilation_2: usize,
    /// Third element of the dilation array.
    dilation_3: usize,
    /// Group count passed to `ConvOptions::new`.
    groups: usize,
    /// Third dimension of the input tensor.
    depth: usize,
    /// Fourth dimension of the input tensor.
    height: usize,
    /// Fifth dimension of the input tensor.
    width: usize,
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<5>,
|
||||
weight: TestTensor<5>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl Conv3dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels_in,
|
||||
self.depth,
|
||||
self.height,
|
||||
self.width,
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
self.kernel_size_3,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv3d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvOptions::new(
|
||||
[self.stride_1, self.stride_2, self.stride_3],
|
||||
[self.padding_1, self.padding_2, self.padding_3],
|
||||
[self.dilation_1, self.dilation_2, self.dilation_3],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions};
|
||||
|
||||
/// `conv_transpose1d` gradients with stride 1, no padding, dilation 1 and a single group.
#[test]
fn test_conv_transpose1d_basic() {
    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 0,
        stride: 1,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    let device = Default::default();
    let expected = Grads {
        // Both batch entries carry the same per-channel values.
        x: TestTensor::from_floats([[[15.0; 4], [51.0; 4]]; 2], &device),
        weight: TestTensor::from_floats([[[44.0; 3]; 2], [[76.0; 3]; 2]], &device),
        bias: TestTensor::from_floats([12.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose1d` gradients with a padding of 2 and all other options at their basic values.
#[test]
fn test_conv_transpose1d_padding() {
    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 2,
        padding_out: 0,
        stride: 1,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    let device = Default::default();
    let expected = Grads {
        // Both batch entries carry the same per-channel values.
        x: TestTensor::from_floats(
            [[[7., 12., 8., 3.], [19., 36., 32., 15.]]; 2],
            &device,
        ),
        weight: TestTensor::from_floats(
            [[[26., 22., 18.]; 2], [[42., 38., 34.]; 2]],
            &device,
        ),
        bias: TestTensor::from_floats([4.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose1d` gradients with a stride of 2 and all other options at their basic values.
#[test]
fn test_conv_transpose1d_stride() {
    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 0,
        stride: 2,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    let device = Default::default();
    let expected = Grads {
        // Both batch entries carry the same per-channel values.
        x: TestTensor::from_floats([[[15.; 4], [51.; 4]]; 2], &device),
        weight: TestTensor::from_floats([[[44.; 3]; 2], [[76.; 3]; 2]], &device),
        bias: TestTensor::from_floats([18.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose1d` gradients with a stride of 2 combined with an output padding of 1.
#[test]
fn test_conv_transpose1d_stride_padding_out() {
    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 1,
        stride: 2,
        dilation: 1,
        groups: 1,
        size: 4,
    };
    let device = Default::default();
    let expected = Grads {
        // Both batch entries carry the same per-channel values.
        x: TestTensor::from_floats([[[15.; 4], [51.; 4]]; 2], &device),
        weight: TestTensor::from_floats([[[44.; 3]; 2], [[76.; 3]; 2]], &device),
        bias: TestTensor::from_floats([20.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose1d` gradients with a dilation of 2 and all other options at their basic values.
#[test]
fn test_conv_transpose1d_dilation() {
    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: 3,
        padding: 0,
        padding_out: 0,
        stride: 1,
        dilation: 2,
        groups: 1,
        size: 4,
    };
    let device = Default::default();
    let expected = Grads {
        // Both batch entries carry the same per-channel values.
        x: TestTensor::from_floats([[[15.; 4], [51.; 4]]; 2], &device),
        weight: TestTensor::from_floats([[[44.; 3]; 2], [[76.; 3]; 2]], &device),
        bias: TestTensor::from_floats([16.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose1d` gradients with padding, output padding, stride, dilation and two groups
/// all enabled at once.
#[test]
fn test_conv_transpose1d_complex() {
    let case = ConvTranspose1dTestCase {
        batch_size: 2,
        channels: [2, 4],
        kernel_size: 3,
        padding: 1,
        padding_out: 1,
        stride: 2,
        dilation: 2,
        groups: 2,
        size: 8,
    };
    let device = Default::default();

    // Per-channel input gradients; every batch entry repeats the same pair.
    let x_channel_0 = [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0];
    let x_channel_1 = [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0];

    let expected = Grads {
        x: TestTensor::from_floats([[x_channel_0, x_channel_1]; 2], &device),
        weight: TestTensor::from_floats(
            [[[168.0, 184.0, 184.0]; 2], [[280.0, 312.0, 312.0]; 2]],
            &device,
        ),
        bias: TestTensor::from_floats([36.0; 4], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// Hyper-parameters and input dimensions for one `conv_transpose1d` gradient check.
struct ConvTranspose1dTestCase {
    /// First dimension of the input tensor.
    batch_size: usize,
    /// `channels[0]` is the second input dimension and first weight dimension;
    /// `channels[1]` is the bias length and, divided by `groups`, the second weight dimension.
    channels: [usize; 2],
    /// Last dimension of the weight tensor.
    kernel_size: usize,
    /// Padding passed to `ConvTransposeOptions::new`.
    padding: usize,
    /// Output padding passed to `ConvTransposeOptions::new`.
    padding_out: usize,
    /// Stride passed to `ConvTransposeOptions::new`.
    stride: usize,
    /// Dilation passed to `ConvTransposeOptions::new`.
    dilation: usize,
    /// Group count passed to `ConvTransposeOptions::new`.
    groups: usize,
    /// Last dimension of the input tensor.
    size: usize,
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<3>,
|
||||
weight: TestTensor<3>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl ConvTranspose1dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv_transpose1d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvTransposeOptions::new(
|
||||
[self.stride],
|
||||
[self.padding],
|
||||
[self.padding_out],
|
||||
[self.dilation],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), Tolerance::default());
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,706 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions};
|
||||
|
||||
/// `conv_transpose2d` gradients with unit stride and dilation, no padding and a single group.
#[test]
fn test_conv_transpose2d_basic() {
    let case = ConvTranspose2dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    let device = Default::default();
    let expected = Grads {
        // Each 4x4 plane is constant and both batch entries are identical.
        x: TestTensor::from_floats([[[[153.; 4]; 4], [[477.; 4]; 4]]; 2], &device),
        // Each 3x3 kernel is constant and repeated for both entries of the second axis.
        weight: TestTensor::from_floats(
            [[[[752.; 3]; 3]; 2], [[[1264.; 3]; 3]; 2]],
            &device,
        ),
        bias: TestTensor::from_floats([72.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with padding `[1, 2]` on a single-channel, single-batch input.
#[test]
fn test_conv_transpose2d_padding() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [1, 2],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    let device = Default::default();

    let x_plane = [
        [13., 24., 20., 9.],
        [15., 27., 21., 9.],
        [15., 27., 21., 9.],
        [7., 12., 8., 3.],
    ];
    let kernel = [[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]];

    let expected = Grads {
        x: TestTensor::from_floats([[x_plane]], &device),
        weight: TestTensor::from_floats([[kernel]], &device),
        bias: TestTensor::from_floats([8.], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with stride `[2, 3]` on a single-channel, single-batch input.
#[test]
fn test_conv_transpose2d_stride() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [2, 3],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    let device = Default::default();
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4]]], &device),
        weight: TestTensor::from_floats([[[[120.; 3]; 3]]], &device),
        bias: TestTensor::from_floats([108.], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with stride `[2, 3]` combined with output padding `[1, 2]`.
#[test]
fn test_conv_transpose2d_stride_padding_out() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    let device = Default::default();
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4]]], &device),
        weight: TestTensor::from_floats([[[[120.; 3]; 3]]], &device),
        bias: TestTensor::from_floats([140.], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with dilation `[2, 3]` on a single-channel, single-batch input.
#[test]
fn test_conv_transpose2d_dilation() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [2, 3],
        groups: 1,
        size: [4, 4],
    };
    let device = Default::default();
    let expected = Grads {
        x: TestTensor::from_floats([[[[36.; 4]; 4]]], &device),
        weight: TestTensor::from_floats([[[[120.; 3]; 3]]], &device),
        bias: TestTensor::from_floats([80.], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients when `channels` is `[2, 3]` and every other option is basic.
#[test]
fn test_conv_transpose2d_channels() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [2, 3],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [4, 4],
    };
    let device = Default::default();
    let expected = Grads {
        // Each 4x4 plane is constant.
        x: TestTensor::from_floats([[[[351.; 4]; 4], [[1080.; 4]; 4]]], &device),
        // Each 3x3 kernel is constant and repeated three times along the second axis.
        weight: TestTensor::from_floats(
            [[[[120.; 3]; 3]; 3], [[[376.; 3]; 3]; 3]],
            &device,
        ),
        bias: TestTensor::from_floats([36.; 3], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with a non-square `[3, 5]` kernel on a 6x6 input.
#[test]
fn test_conv_transpose2d_kernel_size() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [1, 1],
        kernel_size: [3, 5],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 1,
        size: [6, 6],
    };
    let device = Default::default();
    let expected = Grads {
        x: TestTensor::from_floats([[[[105.; 6]; 6]]], &device),
        // Three rows of five values, matching the `[3, 5]` kernel.
        weight: TestTensor::from_floats([[[[630.; 5]; 3]]], &device),
        bias: TestTensor::from_floats([80.], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with two groups over `[2, 2]` channels.
#[test]
fn test_conv_transpose2d_groups() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [2, 2],
        kernel_size: [3, 3],
        padding: [0, 0],
        padding_out: [0, 0],
        stride: [1, 1],
        dilation: [1, 1],
        groups: 2,
        size: [4, 4],
    };
    let device = Default::default();
    let expected = Grads {
        // Each 4x4 plane is constant.
        x: TestTensor::from_floats([[[[36.; 4]; 4], [[117.; 4]; 4]]], &device),
        // With two groups the second weight axis has length 1.
        weight: TestTensor::from_floats(
            [[[[120.; 3]; 3]], [[[376.; 3]; 3]]],
            &device,
        ),
        bias: TestTensor::from_floats([36.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients with padding, output padding, stride and dilation all
/// enabled on a `[3, 5]` kernel, using a single group.
#[test]
fn test_conv_transpose2d_complex_no_groups() {
    let case = ConvTranspose2dTestCase {
        batch_size: 2,
        channels: [2, 3],
        kernel_size: [3, 5],
        padding: [1, 2],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [2, 3],
        groups: 1,
        size: [6, 8],
    };
    let device = Default::default();

    // Within each input channel only the first row differs from the other five.
    let c0_first = [600., 735., 735., 735., 735., 735., 735., 735.];
    let c0_other = [810., 990., 990., 990., 990., 990., 990., 990.];
    let c1_first = [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.];
    let c1_other = [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.];
    let x_batch = [
        [c0_first, c0_other, c0_other, c0_other, c0_other, c0_other],
        [c1_first, c1_other, c1_other, c1_other, c1_other, c1_other],
    ];

    // One 3x5 kernel gradient per first-axis entry, repeated three times along the second axis.
    let kernel_0 = [
        [5320., 6040., 6040., 6040., 6040.],
        [6048., 6864., 6864., 6864., 6864.],
        [6048., 6864., 6864., 6864., 6864.],
    ];
    let kernel_1 = [
        [8680., 9880., 9880., 9880., 9880.],
        [10080., 11472., 11472., 11472., 11472.],
        [10080., 11472., 11472., 11472., 11472.],
    ];

    let expected = Grads {
        // Both batch entries are identical.
        x: TestTensor::from_floats([x_batch; 2], &device),
        weight: TestTensor::from_floats([[kernel_0; 3], [kernel_1; 3]], &device),
        bias: TestTensor::from_floats([896.; 3], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients for `[4, 2]` channels on a 10x10 input with padding, output
/// padding, stride and dilation all enabled, using a single group.
#[test]
fn test_conv_transpose2d_complex_no_groups_2() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [4, 2],
        kernel_size: [2, 3],
        padding: [1, 2],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [1, 2],
        groups: 1,
        size: [10, 10],
    };
    let device = Default::default();

    // Within each input channel only the first row differs from the other nine.
    let c0_first = [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.];
    let c0_other = [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.];
    let c1_first = [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.];
    let c1_other = [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.];
    let c2_first = [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.];
    let c2_other = [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.];
    let c3_first = [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.];
    let c3_other = [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.];
    let x_expected = [[
        [
            c0_first, c0_other, c0_other, c0_other, c0_other, c0_other, c0_other, c0_other,
            c0_other, c0_other,
        ],
        [
            c1_first, c1_other, c1_other, c1_other, c1_other, c1_other, c1_other, c1_other,
            c1_other, c1_other,
        ],
        [
            c2_first, c2_other, c2_other, c2_other, c2_other, c2_other, c2_other, c2_other,
            c2_other, c2_other,
        ],
        [
            c3_first, c3_other, c3_other, c3_other, c3_other, c3_other, c3_other, c3_other,
            c3_other, c3_other,
        ],
    ]];

    // One 2x3 kernel gradient per first-axis entry, repeated twice along the second axis.
    let kernel_0 = [[4455., 4905., 4905.], [4500., 4950., 4950.]];
    let kernel_1 = [[12555., 13905., 13905.], [13500., 14950., 14950.]];
    let kernel_2 = [[20655., 22905., 22905.], [22500., 24950., 24950.]];
    let kernel_3 = [[28755., 31905., 31905.], [31500., 34950., 34950.]];

    let expected = Grads {
        x: TestTensor::from_floats(x_expected, &device),
        weight: TestTensor::from_floats(
            [[kernel_0; 2], [kernel_1; 2], [kernel_2; 2], [kernel_3; 2]],
            &device,
        ),
        bias: TestTensor::from_floats([570.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// `conv_transpose2d` gradients for `[4, 2]` channels on a 10x10 input with padding, output
/// padding, stride and dilation all enabled, using two groups.
#[test]
fn test_conv_transpose2d_complex_groups() {
    let case = ConvTranspose2dTestCase {
        batch_size: 1,
        channels: [4, 2],
        kernel_size: [2, 3],
        padding: [1, 2],
        padding_out: [1, 2],
        stride: [2, 3],
        dilation: [1, 2],
        groups: 2,
        size: [10, 10],
    };
    let device = Default::default();

    // Within each input channel only the first row differs from the other nine.
    let c0_first = [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.];
    let c0_other = [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.];
    let c1_first = [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.];
    let c1_other = [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.];
    let c2_first = [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.];
    let c2_other = [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.];
    let c3_first = [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.];
    let c3_other = [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.];
    let x_expected = [[
        [
            c0_first, c0_other, c0_other, c0_other, c0_other, c0_other, c0_other, c0_other,
            c0_other, c0_other,
        ],
        [
            c1_first, c1_other, c1_other, c1_other, c1_other, c1_other, c1_other, c1_other,
            c1_other, c1_other,
        ],
        [
            c2_first, c2_other, c2_other, c2_other, c2_other, c2_other, c2_other, c2_other,
            c2_other, c2_other,
        ],
        [
            c3_first, c3_other, c3_other, c3_other, c3_other, c3_other, c3_other, c3_other,
            c3_other, c3_other,
        ],
    ]];

    // With two groups the second weight axis has length 1.
    let kernel_0 = [[4455., 4905., 4905.], [4500., 4950., 4950.]];
    let kernel_1 = [[12555., 13905., 13905.], [13500., 14950., 14950.]];
    let kernel_2 = [[20655., 22905., 22905.], [22500., 24950., 24950.]];
    let kernel_3 = [[28755., 31905., 31905.], [31500., 34950., 34950.]];

    let expected = Grads {
        x: TestTensor::from_floats(x_expected, &device),
        weight: TestTensor::from_floats(
            [[kernel_0], [kernel_1], [kernel_2], [kernel_3]],
            &device,
        ),
        bias: TestTensor::from_floats([570.; 2], &device),
    };
    case.assert_grads(expected);
}
|
||||
|
||||
/// Hyper-parameters and input dimensions for one `conv_transpose2d` gradient check.
struct ConvTranspose2dTestCase {
    /// First dimension of the input tensor.
    batch_size: usize,
    /// `channels[0]` is the second input dimension and first weight dimension;
    /// `channels[1]` is the bias length and, divided by `groups`, the second weight dimension.
    channels: [usize; 2],
    /// Last two dimensions of the weight tensor.
    kernel_size: [usize; 2],
    /// Padding passed to `ConvTransposeOptions::new`.
    padding: [usize; 2],
    /// Output padding passed to `ConvTransposeOptions::new`.
    padding_out: [usize; 2],
    /// Stride passed to `ConvTransposeOptions::new`.
    stride: [usize; 2],
    /// Dilation passed to `ConvTransposeOptions::new`.
    dilation: [usize; 2],
    /// Group count passed to `ConvTransposeOptions::new`.
    groups: usize,
    /// Last two dimensions of the input tensor.
    size: [usize; 2],
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<4>,
|
||||
weight: TestTensor<4>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl ConvTranspose2dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels[0],
|
||||
self.size[0],
|
||||
self.size[1],
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv_transpose2d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvTransposeOptions::new(
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,711 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};
|
||||
|
||||
/// Gradient check for `conv_transpose3d` with a 3x3x3 kernel, unit stride and
/// dilation, no padding and a single group.
#[test]
fn test_conv_transpose3d_basic() {
    let device = Default::default();
    let case = ConvTranspose3dTestCase {
        batch_size: 2,
        channels: [2, 2],
        kernel_size: [3; 3],
        padding: [0; 3],
        padding_out: [0; 3],
        stride: [1; 3],
        dilation: [1; 3],
        groups: 1,
        size: [4; 3],
    };

    // Expected input gradient, shape [2, 2, 4, 4, 4]: one constant per channel,
    // identical for both batch entries.
    let x_channel_0 = [[[13.250001; 4]; 4]; 4];
    let x_channel_1 = [[[40.249992; 4]; 4]; 4];

    // Expected weight gradient, shape [2, 2, 3, 3, 3]: one constant per index
    // along the first dimension.
    let weight_0 = [[[[47.750000; 3]; 3]; 3]; 2];
    let weight_1 = [[[[79.750000; 3]; 3]; 3]; 2];

    let expected = Grads {
        x: TestTensor::from_floats([[x_channel_0, x_channel_1]; 2], &device),
        weight: TestTensor::from_floats([weight_0, weight_1], &device),
        bias: TestTensor::from_floats([432., 432.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
/// Gradient check for `conv_transpose3d` with two groups and non-uniform
/// kernel size, padding, output padding, stride and dilation.
#[test]
fn test_conv_transpose3d_complex_groups() {
    /// Builds an `N`-element array holding `first` at index 0 and `rest` at
    /// every other index.
    fn lead_then<T: Copy, const N: usize>(first: T, rest: T) -> [T; N] {
        let mut filled = [rest; N];
        filled[0] = first;
        filled
    }

    // One 6x6 slice of the expected input gradient. Each tuple is
    // `(first element, remaining elements)` of a row: `top` describes the first
    // row, `below` the five rows after it.
    let x_slice = |top: (f64, f64), below: (f64, f64)| -> [[f64; 6]; 6] {
        lead_then(lead_then(top.0, top.1), lead_then(below.0, below.1))
    };
    // One 6x6x6 channel of the expected input gradient: `front` is the first
    // slice and `rest` is repeated for the remaining five.
    let x_channel = |front: [[f64; 6]; 6], rest: [[f64; 6]; 6]| -> [[[f64; 6]; 6]; 6] {
        lead_then(front, rest)
    };
    // One 3x4 slice of the expected weight gradient, described like `x_slice`:
    // `top` is the first row and `below` the two rows after it.
    let w_slice = |top: (f64, f64), below: (f64, f64)| -> [[f64; 4]; 3] {
        lead_then(lead_then(top.0, top.1), lead_then(below.0, below.1))
    };

    let device = Default::default();
    let case = ConvTranspose3dTestCase {
        batch_size: 1,
        channels: [4, 2],
        kernel_size: [2, 3, 4],
        padding: [1, 2, 3],
        padding_out: [1, 2, 3],
        stride: [2, 3, 4],
        dilation: [1, 2, 3],
        groups: 2,
        size: [6, 6, 6],
    };

    let expected = Grads {
        // Shape [1, 4, 6, 6, 6].
        x: TestTensor::from_floats(
            [[
                x_channel(
                    x_slice((1.250000, 1.625000), (1.687500, 2.187500)),
                    x_slice((1.750000, 2.250000), (2.250000, 2.875000)),
                ),
                x_channel(
                    x_slice((2.750000, 3.625000), (3.937500, 5.187500)),
                    x_slice((4.750000, 6.250000), (6.750000, 8.875000)),
                ),
                x_channel(
                    x_slice((4.250000, 5.625000), (6.187500, 8.187500)),
                    x_slice((7.750000, 10.250000), (11.250000, 14.875000)),
                ),
                x_channel(
                    x_slice((5.750000, 7.625000), (8.437500, 11.187500)),
                    x_slice((10.750000, 14.250000), (15.750000, 20.875000)),
                ),
            ]],
            &device,
        ),
        // Shape [4, 1, 2, 3, 4].
        weight: TestTensor::from_floats(
            [
                [[
                    w_slice((18.663193, 22.309027), (21.875000, 26.145834)),
                    w_slice((19.270832, 23.020834), (22.500000, 26.875002)),
                ]],
                [[
                    w_slice((49.913193, 59.809029), (59.375000, 71.145836)),
                    w_slice((56.770836, 68.020836), (67.500000, 80.875000)),
                ]],
                [[
                    w_slice((81.163193, 97.309029), (96.875000, 116.145828)),
                    w_slice((94.270828, 113.020828), (112.500000, 134.875000)),
                ]],
                [[
                    w_slice((112.413200, 134.809021), (134.375000, 161.145828)),
                    w_slice((131.770844, 158.020828), (157.500000, 188.875000)),
                ]],
            ],
            &device,
        ),
        bias: TestTensor::from_floats([5346., 5346.], &device),
    };

    case.assert_grads(expected);
}
|
||||
|
||||
/// Parameters describing one `conv_transpose3d` gradient-check scenario.
///
/// The values are consumed by `assert_grads`, which derives the tensor shapes
/// and the `ConvTransposeOptions` from them.
struct ConvTranspose3dTestCase {
    /// First dimension of the input tensor.
    batch_size: usize,
    /// `channels[0]` sizes the input's channel dimension and the weight's first
    /// dimension; `channels[1]` sizes the bias and, divided by `groups`, the
    /// weight's second dimension.
    channels: [usize; 2],
    /// Last three dimensions of the weight tensor.
    kernel_size: [usize; 3],
    /// Forwarded to `ConvTransposeOptions::new` as the padding argument.
    padding: [usize; 3],
    /// Forwarded to `ConvTransposeOptions::new` as the output padding argument.
    padding_out: [usize; 3],
    /// Forwarded to `ConvTransposeOptions::new` as the stride argument.
    stride: [usize; 3],
    /// Forwarded to `ConvTransposeOptions::new` as the dilation argument.
    dilation: [usize; 3],
    /// Number of groups; also divides `channels[1]` in the weight shape.
    groups: usize,
    /// Last three dimensions of the input tensor.
    size: [usize; 3],
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<5>,
|
||||
weight: TestTensor<5>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl ConvTranspose3dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels[0],
|
||||
self.size[0],
|
||||
self.size[1],
|
||||
self.size[2],
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
self.kernel_size[2],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_weight.clone())
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.div_scalar(shape_weight.num_elements() as f32)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_x.clone())
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.div_scalar(shape_x.num_elements() as f32)
|
||||
.require_grad();
|
||||
let output = conv_transpose3d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvTransposeOptions::new(
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use burn_backend_tests::might_panic;
|
||||
|
||||
/// Backward pass through a cross product along dim 1, called directly on the
/// (non-reduced) result.
#[test]
fn backward_basic() {
    let device = Default::default();
    let lhs = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let rhs = TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
    let a = TestAutodiffTensor::<2>::from_data(lhs, &device).require_grad();
    let b = TestAutodiffTensor::<2>::from_data(rhs, &device).require_grad();

    // Simple cross product; grad is a vector of ones.
    let grads = a.clone().cross(b.clone(), 1).backward();

    let a_grad = a.grad(&grads).unwrap().to_data();
    let b_grad = b.grad(&grads).unwrap().to_data();

    let checks = [
        // For a: b×grad_out, where grad_out = [1,1,1]
        (
            a_grad,
            TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]),
        ),
        // For b: grad_out×a
        (
            b_grad,
            TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]),
        ),
    ];
    for (actual, expected) in checks {
        actual.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
    }
}
|
||||
|
||||
/// Backward pass through a cross product along dim 1 followed by a full sum.
#[test]
fn backward_after_sum() {
    let device = Default::default();
    let lhs = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let rhs = TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
    let a = TestAutodiffTensor::<2>::from_data(lhs, &device).require_grad();
    let b = TestAutodiffTensor::<2>::from_data(rhs, &device).require_grad();

    // Sum reduces to scalar, but the gradient should be the same.
    let grads = a.clone().cross(b.clone(), 1).sum().backward();

    let a_grad = a.grad(&grads).unwrap().to_data();
    let b_grad = b.grad(&grads).unwrap().to_data();

    let checks = [
        (
            a_grad,
            TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]),
        ),
        (
            b_grad,
            TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]),
        ),
    ];
    for (actual, expected) in checks {
        actual.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
    }
}
|
||||
|
||||
/// Forward cross product along dim 0, compared with a cross product computed
/// by hand from the raw arrays.
#[cfg(feature = "std")]
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
#[test]
fn different_dim() {
    // Also check when the cross is along a different dimension (e.g. dim 0).
    let device = Default::default();
    let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
    let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];

    let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);
    let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);

    // Cross along dim 0. Some backends (for example CubeCL) may not support
    // cross on non-last dimensions and will intentionally panic with a
    // message like "Cross product on non-last dimension not yet implemented".
    // In that case we treat the panic as a skipped test for that backend.
    let out = a.cross(b.clone(), 0);

    // Manually compute the cross product of each pair of column vectors.
    // Component `row` combines the other two components taken in cyclic order.
    let component = |row: usize, col: usize| {
        let (next, after) = ((row + 1) % 3, (row + 2) % 3);
        a_raw[next][col] * b_raw[after][col] - a_raw[after][col] * b_raw[next][col]
    };
    let expected = [0, 1, 2].map(|row| [0, 1, 2].map(|col| component(row, col)));

    out.to_data()
        .assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());
}
|
||||
@@ -0,0 +1,33 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Tensor, TensorData, Tolerance, loss};
|
||||
|
||||
/// Checks the gradients that `loss::cross_entropy_with_logits` propagates back
/// through a matmul to both of its operands.
#[test]
fn test_cross_entropy_loss_grad() {
    let device = Default::default();
    // Wraps raw data in a gradient-tracking 2D tensor.
    let tracked = |data: TensorData| {
        Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad()
    };

    let tensor_1 = tracked(TensorData::from([[0.0, 1.0], [3.0, 4.0]]));
    let tensor_2 = tracked(TensorData::from([[6.0, 7.0], [9.0, 10.0]]));
    let tensor_targets = tracked(TensorData::from([[0.8, 0.2], [0.9, 0.1]]));

    let logits = tensor_1.clone().matmul(tensor_2.clone());
    let grads = loss::cross_entropy_with_logits(logits, tensor_targets).backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    let tolerance = Tolerance::permissive();

    let expected_1 = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]);
    grad_1
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_1, tolerance);

    let expected_2 = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]);
    grad_2
        .to_data()
        .assert_approx_eq::<FloatElem>(&expected_2, tolerance);
}
|
||||
@@ -0,0 +1,117 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Simple check that gradients flow through `cummax` on a 1D tensor.
#[test]
fn should_diff_cummax() {
    let device = Default::default();
    let input = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 2.0]), &device)
        .require_grad();

    let grads = input.clone().cummax(0).sum().backward();
    let actual = input.grad(&grads).unwrap().to_data();

    // PyTorch reference: [1.0, 2.0, 0.0]
    actual.assert_approx_eq::<FloatElem>(&TensorData::from([1.0, 2.0, 0.0]), Tolerance::default());
}
|
||||
|
||||
/// Checks `cummax` gradients on a 2D tensor, accumulating along dim 1.
#[test]
fn should_diff_cummax_2d() {
    let device = Default::default();
    let values = TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let grads = input.clone().cummax(1).sum().backward();
    let actual = input.grad(&grads).unwrap().to_data();

    // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
    actual.assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks `cummax` gradients when the maximum value appears more than once,
/// which is a critical edge case.
#[test]
fn should_diff_cummax_duplicate_values() {
    let device = Default::default();
    let values = TensorData::from([1.0, 3.0, 3.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let grads = input.clone().cummax(0).sum().backward();
    let actual = input.grad(&grads).unwrap().to_data();

    // input:  [1.0, 3.0, 3.0, 2.0]
    // cummax: [1.0, 3.0, 3.0, 3.0]
    // PyTorch reference: [1.0, 1.0, 2.0, 0.0]
    // Position 2 gets grad from itself + position 3
    let reference = TensorData::from([1.0, 1.0, 2.0, 0.0]);
    actual.assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks `cummax` gradients when every input value is the same.
#[test]
fn should_diff_cummax_all_same() {
    let device = Default::default();
    let input = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
        .require_grad();

    let grads = input.clone().cummax(0).sum().backward();
    let actual = input.grad(&grads).unwrap().to_data();

    // PyTorch reference: [1.0, 1.0, 1.0]
    // Each position matches cummax, so each gets its own gradient
    actual.assert_approx_eq::<FloatElem>(&TensorData::from([1.0, 1.0, 1.0]), Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummax` on a strictly increasing 1-D input.
#[test]
fn should_diff_cummax_increasing() {
    let device = Default::default();
    let values = TensorData::from([1.0, 2.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cummax(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0, 1.0]
    // Each element sets a new maximum, so each keeps exactly one unit of gradient.
    let reference = TensorData::from([1.0, 1.0, 1.0, 1.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummax` along dim 1 of a 2x4 tensor whose rows repeat their maximum.
#[test]
fn should_diff_cummax_2d_duplicates() {
    let device = Default::default();
    let values = TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let loss = input.clone().cummax(1).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
@@ -0,0 +1,117 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Smoke test: the gradient of `cummin` on a small 1-D tensor matches the PyTorch reference.
#[test]
fn should_diff_cummin() {
    let device = Default::default();
    let values = TensorData::from([3.0, 2.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    // Sum the running minimum so there is a scalar to back-propagate from.
    let loss = input.clone().cummin(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 2.0, 0.0]
    let reference = TensorData::from([1.0, 2.0, 0.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummin` taken along dim 1 of a 2x3 tensor.
#[test]
fn should_diff_cummin_2d() {
    let device = Default::default();
    let values = TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let loss = input.clone().cummin(1).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummin` when the minimum value appears twice in the input.
#[test]
fn should_diff_cummin_duplicate_values() {
    let device = Default::default();
    let values = TensorData::from([3.0, 2.0, 2.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cummin(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // input:  [3.0, 2.0, 2.0, 4.0]
    // cummin: [3.0, 2.0, 2.0, 2.0]
    // PyTorch reference: [1.0, 1.0, 2.0, 0.0]
    // Index 2 receives its own gradient plus the one coming back from index 3.
    let reference = TensorData::from([1.0, 1.0, 2.0, 0.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummin` when every input element is equal.
#[test]
fn should_diff_cummin_all_same() {
    let device = Default::default();
    let values = TensorData::from([2.0, 2.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cummin(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0]
    // Every element equals the running minimum, so each keeps its own gradient.
    let reference = TensorData::from([1.0, 1.0, 1.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummin` on a strictly decreasing 1-D input.
#[test]
fn should_diff_cummin_decreasing() {
    let device = Default::default();
    let values = TensorData::from([5.0, 4.0, 3.0, 2.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cummin(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 1.0, 1.0, 1.0]
    // Each element sets a new minimum, so each keeps exactly one unit of gradient.
    let reference = TensorData::from([1.0, 1.0, 1.0, 1.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cummin` along dim 1 of a 2x4 tensor whose rows repeat their minimum.
#[test]
fn should_diff_cummin_2d_duplicates() {
    let device = Default::default();
    let values = TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let loss = input.clone().cummin(1).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
    let reference = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
@@ -0,0 +1,132 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Smoke test: the gradient of `cumprod` on a small 1-D tensor matches the PyTorch reference.
#[test]
fn should_diff_cumprod() {
    let device = Default::default();
    let values = TensorData::from([2.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    // Sum the running product so there is a scalar to back-propagate from.
    let loss = input.clone().cumprod(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [16.0, 10.0, 6.0]
    let reference = TensorData::from([16.0, 10.0, 6.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cumprod` taken along dim 1 of a 2x3 tensor.
#[test]
fn should_diff_cumprod_2d() {
    let device = Default::default();
    let values = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let loss = input.clone().cumprod(1).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]
    let reference = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
// TODO: The following tests are currently ignored due to a known limitation
|
||||
// in the cumprod gradient implementation. The current implementation uses
|
||||
// division (grad / input), which produces NaN when the input contains zeros.
|
||||
//
|
||||
// A proper fix requires implementing a zero-safe algorithm using exclusive
|
||||
// cumulative products (similar to PyTorch's cumprod_backward or JAX's
|
||||
// associative_scan approach). This is a non-trivial implementation that
|
||||
// requires careful handling of cumulative products in both forward and
|
||||
// reverse directions.
|
||||
//
|
||||
// See: https://github.com/tracel-ai/burn/issues/3864
|
||||
//
|
||||
// References:
|
||||
// - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward)
|
||||
// - JAX PR #2596: Parallel prefix scan implementation
|
||||
// - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros
|
||||
|
||||
/// Checks the gradient of `cumprod` when a zero sits in the middle of the input.
///
/// Currently ignored; see the `#[ignore]` reason.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_in_middle() {
    let device = Default::default();
    let values = TensorData::from([2.0, 0.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cumprod(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 32.0, 0.0, 0.0]
    let reference = TensorData::from([1.0, 32.0, 0.0, 0.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cumprod` when the first input element is zero.
///
/// Currently ignored; see the `#[ignore]` reason.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_start() {
    let device = Default::default();
    let values = TensorData::from([0.0, 2.0, 3.0, 4.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cumprod(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [33.0, 0.0, 0.0, 0.0]
    let reference = TensorData::from([33.0, 0.0, 0.0, 0.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cumprod` when the last input element is zero.
///
/// Currently ignored; see the `#[ignore]` reason.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_end() {
    let device = Default::default();
    let values = TensorData::from([2.0, 3.0, 4.0, 0.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cumprod(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [16.0, 10.0, 6.0, 24.0]
    let reference = TensorData::from([16.0, 10.0, 6.0, 24.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
|
||||
/// Checks the gradient of `cumprod` when the input contains more than one zero.
///
/// Currently ignored; see the `#[ignore]` reason.
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_multiple_zeros() {
    let device = Default::default();
    let values = TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &device).require_grad();

    let loss = input.clone().cumprod(0).sum();
    let gradients = loss.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0]
    let reference = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
@@ -0,0 +1,89 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Checks gradients through `matmul -> cumsum(0) -> mul -> sum` for both operands.
#[test]
fn should_diff_cumsum_dim0() {
    let device = Default::default();
    let lhs_data = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
    let rhs_data = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let running_sum = product.cumsum(0);
    let weighted = lhs.clone().mul(running_sum);
    let gradients = weighted.sum().backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Expected gradients computed with PyTorch
    let lhs_reference = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, Tolerance::default());

    let rhs_reference = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, Tolerance::default());
}
|
||||
|
||||
/// Checks gradients through `matmul -> cumsum(1) -> mul -> sum` for both operands.
#[test]
fn should_diff_cumsum_dim1() {
    let device = Default::default();
    let lhs_data = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
    let rhs_data = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let running_sum = product.cumsum(1);
    let weighted = lhs.clone().mul(running_sum);
    let gradients = weighted.sum().backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Expected gradients computed with PyTorch
    let lhs_reference = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, Tolerance::default());

    let rhs_reference = TensorData::from([[18.0, 13.0], [71.0, 58.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, Tolerance::default());
}
|
||||
|
||||
/// Checks gradients when a matmul result is multiplied by its own `cumsum(1)`.
#[test]
fn should_diff_cumsum_complex() {
    let device = Default::default();
    let lhs_data = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
    let rhs_data = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    // The product feeds the graph twice: through the cumulative sum and directly.
    let running_sum = product.clone().cumsum(1);
    let combined = running_sum.mul(product);

    let gradients = combined.sum().backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Expected gradients computed with PyTorch
    let lhs_reference = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, Tolerance::default());

    let rhs_reference = TensorData::from([[507.0, 528.0], [704.0, 733.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, Tolerance::default());
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,105 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Checks the gradients of element-wise `div` for both numerator and denominator.
#[test]
fn should_diff_div() {
    let device = Default::default();
    let numerator_data = TensorData::from([1.0, 7.0]);
    let denominator_data = TensorData::from([4.0, 7.0]);
    let numerator = TestAutodiffTensor::<1>::from_data(numerator_data, &device).require_grad();
    let denominator = TestAutodiffTensor::from_data(denominator_data, &device).require_grad();

    let quotient = numerator.clone().div(denominator.clone());
    let gradients = quotient.backward();

    let numerator_grad = numerator.grad(&gradients).unwrap();
    let denominator_grad = denominator.grad(&gradients).unwrap();

    let numerator_reference = TensorData::from([0.25, 0.14285715]);
    numerator_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&numerator_reference, Tolerance::default());

    let denominator_reference = TensorData::from([-0.0625, -0.14285715]);
    denominator_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&denominator_reference, Tolerance::default());
}
|
||||
|
||||
/// Checks that dividing by the scalar 4.0 yields a gradient of 0.25 for every element.
#[test]
fn should_diff_div_scalar() {
    let values = TensorData::from([1.0, 7.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &Default::default()).require_grad();

    let scaled = input.clone().div_scalar(4.0);
    let gradients = scaled.backward();
    let input_grad = input.grad(&gradients).unwrap();

    let reference = TensorData::from([0.25, 0.25]);
    input_grad.to_data().assert_eq(&reference, false);
}
|
||||
|
||||
/// Checks the gradients of two chained `div` operations over three 2x2 tensors.
#[test]
fn test_div_complex_1() {
    let device = Default::default();
    let a_data = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
    let b_data = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
    let c_data = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
    let a = TestAutodiffTensor::<2>::from_data(a_data, &device).require_grad();
    let b = TestAutodiffTensor::from_data(b_data, &device).require_grad();
    let c = TestAutodiffTensor::from_data(c_data, &device).require_grad();

    // (a / b) / c
    let first_quotient = a.clone().div(b.clone());
    let second_quotient = first_quotient.div(c.clone());

    let gradients = second_quotient.backward();

    let a_grad = a.grad(&gradients).unwrap();
    let b_grad = b.grad(&gradients).unwrap();
    let c_grad = c.grad(&gradients).unwrap();

    let a_reference = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]);
    a_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&a_reference, Tolerance::default());

    let b_reference = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]);
    b_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&b_reference, Tolerance::default());

    let c_reference = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]);
    c_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&c_reference, Tolerance::default());
}
|
||||
|
||||
/// Checks gradients when a matmul result is divided by one of the matmul operands.
#[test]
fn test_div_complex_2() {
    let device = Default::default();
    let lhs_data = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
    let rhs_data = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let quotient = product.div(rhs.clone());

    let gradients = quotient.backward();
    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    // Default tolerance with a custom half-precision absolute setting, shared by both checks.
    let tolerance = Tolerance::default().set_half_precision_absolute(2e-3);

    let lhs_reference = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, tolerance);

    let rhs_reference = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, tolerance);
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Checks gradients through `erf` applied to the right operand of a matmul chain.
#[test]
fn should_diff_erf() {
    let device = Default::default();
    let lhs_data = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
    let rhs_data = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let activated = rhs.clone().erf();
    let hidden = lhs.clone().matmul(activated);
    let output = hidden.matmul(rhs.clone());
    let gradients = output.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let lhs_reference = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, Tolerance::default());

    let rhs_reference = TensorData::from([[8.0, 8.0], [8.0, 8.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, Tolerance::default());
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Checks gradients through `exp` applied to the right operand of a matmul.
#[test]
fn should_diff_exp() {
    let device = Default::default();
    let lhs_data = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
    let rhs_data = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let exponentiated = rhs.clone().exp();
    let output = lhs.clone().matmul(exponentiated);
    let gradients = output.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default();

    let lhs_reference = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, tolerance);

    let rhs_reference = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, tolerance);
}
|
||||
@@ -0,0 +1,39 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
/// Checks gradients through `expand` of a 1-D tensor to shape `[4, 4]`.
#[test]
fn should_diff_expand() {
    // Python code to generate the test case values
    // import torch
    // x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
    // x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
    // y = x1.expand(4, 4)
    // z = (x2 * y).sum()
    // z.backward()
    // print("x1", x1.grad)
    // print("x2", x2.grad)

    let device = Default::default();

    let x1_data = TensorData::from([4.0, 7.0, 2.0, 3.0]);
    let x1 = TestAutodiffTensor::<1>::from_data(x1_data, &device).require_grad();

    let x2_data = TensorData::from([2.0, 4.5, 7.0, 3.0]);
    let x2 = TestAutodiffTensor::<1>::from_data(x2_data, &device).require_grad();

    let expanded = x1.clone().expand([4, 4]);

    // Unsqueeze x2 so it has the same rank as the expanded tensor before multiplying.
    let loss = x2.clone().unsqueeze().mul(expanded).sum();
    let gradients = loss.backward();

    let x1_grad = x1.grad(&gradients).unwrap();
    let x2_grad = x2.grad(&gradients).unwrap();

    let x1_reference = TensorData::from([8., 18., 28., 12.]);
    x1_grad.to_data().assert_eq(&x1_reference, false);

    let x2_reference = TensorData::from([16., 28., 8., 12.]);
    x2_grad.to_data().assert_eq(&x2_reference, false);
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
/// Checks gradients through `flip` over axes 1 and 2 followed by a batched matmul.
#[test]
fn should_diff_flip() {
    let device = Default::default();
    let lhs_data = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
    let rhs_data = TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3
    let lhs = TestAutodiffTensor::<3>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let flipped = rhs.clone().flip([1, 2]);
    let output = lhs.clone().matmul(flipped);
    let gradients = output.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let lhs_reference = TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]); // 1x2x2
    lhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, tolerance);

    let rhs_reference = TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]); // 1x2x3
    rhs_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, tolerance);
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
/// Checks that the gradient of `floor` is zero for every element of a 2x4 tensor.
#[test]
fn should_diff_floor() {
    let device = Default::default();
    let values = TensorData::from([
        [-1.9751, 0.0714, 0.0643, 0.2406],
        [-1.3172, 0.1252, -0.1119, -0.0127],
    ]);
    let input = TestAutodiffTensor::<2>::from_data(values, &device).require_grad();

    let floored = input.clone().floor();
    let gradients = floored.backward();
    let input_grad = input.grad(&gradients).unwrap();

    let reference = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
    input_grad.to_data().assert_eq(&reference, false);
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use super::*;
|
||||
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
|
||||
|
||||
/// Checks the gradient of `gather` along dim 1 when its result feeds a matmul.
#[test]
fn test_gather_grad() {
    let device = Default::default();
    let source_data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    let source = TestAutodiffTensor::from_data(source_data, &device).require_grad();
    let index_data = TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]);
    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(index_data, &device);

    // The source contributes twice: through source * source^T and through the gather.
    let gram = source.clone().matmul(source.clone().transpose());
    let gathered = source.clone().gather(1, indices);
    let output = gram.matmul(gathered);

    let gradients = output.backward();
    let source_grad = source.grad(&gradients).unwrap();

    let reference = TensorData::from([[94., 150., 187.], [242., 305., 304.]]);
    source_grad.to_data().assert_eq(&reference, false);
}
|
||||
|
||||
/// Checks the gradients of `scatter` with `IndexingUpdateOp::Add` for the target and the values.
#[test]
fn test_scatter_grad() {
    let device = Default::default();
    let target_data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    let target = TestAutodiffTensor::from_data(target_data, &device).require_grad();
    let values_data = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let values = TestAutodiffTensor::from_data(values_data, &device).require_grad();
    let index_data = TensorData::from([[2, 1, 0], [2, 0, 1]]);
    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(index_data, &device);

    // The target contributes twice: through target * target^T and through the scatter.
    let gram = target.clone().matmul(target.clone().transpose());
    let scattered = target
        .clone()
        .scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
    let output = gram.matmul(scattered);

    let gradients = output.backward();

    let target_grad = target.grad(&gradients).unwrap();
    let values_grad = values.grad(&gradients).unwrap();

    let target_reference = TensorData::from([[127., 181., 235.], [226., 316., 406.]]);
    target_grad.to_data().assert_eq(&target_reference, false);

    let values_reference = TensorData::from([[19., 19., 19.], [64., 64., 64.]]);
    values_grad.to_data().assert_eq(&values_reference, false);
}
|
||||
|
||||
/// Checks scatter-add gradients when the indices address only three of the six target columns.
#[test]
fn test_scatter_add_grad_partial_indices() {
    let device = Default::default();
    let base_data = TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]);
    let base = TestAutodiffTensor::from_data(base_data, &device).require_grad();
    let factor_data = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]);
    let factor = TestAutodiffTensor::from_data(factor_data, &device).require_grad();
    let values_data = TensorData::from([[4.0, 5.0, 6.0]]);
    let values = TestAutodiffTensor::from_data(values_data, &device).require_grad();
    let index_data = TensorData::from([[2, 1, 0]]);
    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(index_data, &device);

    let scaled = base.clone().mul(factor);
    let scattered = scaled
        .clone()
        .scatter(1, indices, values.clone(), IndexingUpdateOp::Add);

    let gradients = scattered.backward();

    let base_grad = base.grad(&gradients).unwrap();
    let values_grad = values.grad(&gradients).unwrap();

    let base_reference = TensorData::from([[1., 2., 3., 4., 5., 6.]]);
    base_grad.to_data().assert_eq(&base_reference, false);

    let values_reference = TensorData::from([[1., 1., 1.]]);
    values_grad.to_data().assert_eq(&values_reference, false);
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance, activation};
|
||||
|
||||
/// Checks gradients through `activation::gelu` inside a two-step matmul chain.
#[test]
fn should_diff_gelu() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad();
    let rhs = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad();

    let activated = activation::gelu(rhs.clone());
    let hidden = lhs.clone().matmul(activated);
    let output = lhs.clone().matmul(hidden);
    let gradients = output.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::permissive();

    let lhs_reference = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, tolerance);

    let rhs_reference = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, tolerance);
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Distribution, activation};
|
||||
|
||||
/// Checks that `grad_replace` swaps the stored gradient of a tensor for the supplied one.
#[test]
fn should_update_tensor_when_grad_replace() {
    let device = Default::default();
    let tracked =
        TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad();
    let other = TestAutodiffTensor::random([32, 32], Distribution::Default, &device);

    let output = tracked.clone().matmul(activation::gelu(other));
    let mut gradients = output.backward();

    let original_grad = tracked.grad(&gradients).unwrap();

    // Register a fresh random tensor as the new gradient of `tracked`.
    let replacement =
        TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad();
    tracked.grad_replace(&mut gradients, replacement.clone().inner());

    let current_grad = tracked.grad(&gradients).unwrap();

    // The stored gradient must differ from the old one and equal the replacement.
    assert_ne!(current_grad.to_data(), original_grad.into_data());
    assert_eq!(current_grad.into_data(), replacement.into_data());
}
|
||||
@@ -0,0 +1,30 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// Checks gradients through `log` applied to the right operand of a matmul chain.
#[test]
fn should_diff_log() {
    let device = Default::default();
    let lhs_data = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
    let rhs_data = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
    let lhs = TestAutodiffTensor::<2>::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();

    let logged = rhs.clone().log();
    let hidden = lhs.clone().matmul(logged);
    let output = hidden.matmul(rhs.clone());
    let gradients = output.backward();

    let lhs_grad = lhs.grad(&gradients).unwrap();
    let rhs_grad = rhs.grad(&gradients).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let lhs_reference = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_reference, tolerance);

    let rhs_reference = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_reference, tolerance);
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
/// Checks the gradients of `lhs.matmul(rhs.log1p()).matmul(rhs)` for both leaves.
///
/// Both gradients are compared against hard-coded reference values using a
/// tolerance whose half-precision relative bound is set to `1e-3`.
#[test]
fn should_diff_log1p() {
    let lhs = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();
    let rhs = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();

    // Forward: (lhs x log1p(rhs)) x rhs, then back-propagate from the result.
    let through_log1p = lhs.clone().matmul(rhs.clone().log1p());
    let output = through_log1p.matmul(rhs.clone());
    let grads = output.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let lhs_expected = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_expected, tolerance);

    let rhs_expected = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_expected, tolerance);
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
/// Checks the gradient of `activation::log_sigmoid` with respect to its input.
///
/// The input mixes small values with `-300.` and `200.`; the reference gradient
/// for those two entries is `1.0` and `0.0` respectively.
#[test]
fn should_diff_log_sigmoid() {
    let device = Default::default();
    let input = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[0.8762, -0.1423], [-300., 200.]]),
        &device,
    )
    .require_grad();

    let activated = activation::log_sigmoid(input.clone());
    let grads = activated.backward();
    let input_grad = input.grad(&grads).unwrap();

    let reference = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]);
    input_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&reference, Tolerance::default());
}
|
||||
@@ -0,0 +1,65 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{Bool, Tensor, TensorData};
|
||||
|
||||
/// Checks the leaf gradients of `lhs.matmul(rhs).mask_fill(mask, 2.0)`.
///
/// The gradients are compared exactly (non-strict `assert_eq`) against
/// hard-coded reference values.
#[test]
fn should_diff_mask_fill() {
    let lhs_data = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
    let rhs_data = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
    let mask_data = TensorData::from([[true, false], [false, true]]);

    let device = Default::default();
    let lhs = TestAutodiffTensor::from_data(lhs_data, &device).require_grad();
    let rhs = TestAutodiffTensor::from_data(rhs_data, &device).require_grad();
    let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask_data, &device);

    // Forward: matmul, then overwrite the masked entries with the constant 2.0.
    let product = lhs.clone().matmul(rhs.clone());
    let filled = product.mask_fill(mask, 2.0);
    let grads = filled.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let lhs_expected = TensorData::from([[7.0, 3.0], [4.0, 2.0]]);
    lhs_grad.to_data().assert_eq(&lhs_expected, false);

    let rhs_expected = TensorData::from([[2.0, 1.0], [3.0, 7.0]]);
    rhs_grad.to_data().assert_eq(&rhs_expected, false);
}
|
||||
|
||||
/// Checks the leaf gradients of `a.matmul(b).matmul(c).mask_where(mask, c)`.
///
/// `c` is used both as a matmul operand and as the replacement value of
/// `mask_where`; the gradients of all three leaves are compared with
/// hard-coded reference values.
#[test]
fn should_diff_mask_where() {
    let device = Default::default();
    let a = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad();
    let b = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad();
    let c = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad();
    let mask =
        Tensor::<TestAutodiffBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

    let ab = a.clone().matmul(b.clone());
    let abc = ab.clone().matmul(c.clone());
    let selected = abc.mask_where(mask, c.clone());
    let grads = selected.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();
    let c_grad = c.grad(&grads).unwrap();

    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);

    let a_expected = TensorData::from([[121.8, 55.0], [110.8, 50.0]]);
    a_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&a_expected, tolerance);

    let b_expected = TensorData::from([[27.4, 33.4], [95.0, 115.0]]);
    b_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&b_expected, tolerance);

    let c_expected = TensorData::from([[15., 18.], [23., 29.]]);
    c_grad
        .into_data()
        .assert_approx_eq::<FloatElem>(&c_expected, tolerance);
}
|
||||
@@ -0,0 +1,83 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
/// Checks a single `matmul`: both operand gradients and the forward result.
#[test]
fn should_diff_matmul() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<2>::from_data(TensorData::from([[1.0, 7.0], [2.0, 3.0]]), &device)
            .require_grad();
    let rhs = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let grads = product.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let lhs_expected = TensorData::from([[11.0, 5.0], [11.0, 5.0]]);
    lhs_grad.to_data().assert_eq(&lhs_expected, false);

    let rhs_expected = TensorData::from([[3.0, 3.0], [10.0, 10.0]]);
    rhs_grad.to_data().assert_eq(&rhs_expected, false);

    // The forward value is still readable after the backward pass.
    let product_expected = TensorData::from([[18.0, 28.0], [14.0, 23.0]]);
    product.to_data().assert_eq(&product_expected, false);
}
|
||||
|
||||
/// Checks the gradients of the first two leaves of `(a x b) x c`.
///
/// `c` also requires grad but its gradient is not inspected.
#[test]
fn test_matmul_complex_1() {
    let a_data = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
    let b_data = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
    let c_data = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);

    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(a_data, &device).require_grad();
    let b = TestAutodiffTensor::from_data(b_data, &device).require_grad();
    let c = TestAutodiffTensor::from_data(c_data, &device).require_grad();

    let ab = a.clone().matmul(b.clone());
    let abc = ab.matmul(c);

    let grads = abc.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    let a_expected = TensorData::from([[44.0, 20.0], [44.0, 20.0]]);
    a_grad.to_data().assert_eq(&a_expected, false);

    let b_expected = TensorData::from([[56.0, 56.0], [16.0, 16.0]]);
    b_grad.to_data().assert_eq(&b_expected, false);
}
|
||||
|
||||
/// Checks the gradients of `a` and `b` in `a x ((a x b) x c)`.
///
/// `a` appears twice in the graph, so its gradient accumulates contributions
/// from both uses.
#[test]
fn test_matmul_complex_2() {
    let a_data = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
    let b_data = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
    let c_data = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);

    let device = Default::default();
    let a = TestAutodiffTensor::<2>::from_data(a_data, &device).require_grad();
    let b = TestAutodiffTensor::from_data(b_data, &device).require_grad();
    let c = TestAutodiffTensor::from_data(c_data, &device).require_grad();

    let ab = a.clone().matmul(b.clone());
    let abc = ab.matmul(c.clone());
    let output = a.clone().matmul(abc);

    let grads = output.backward();

    let a_grad = a.grad(&grads).unwrap();
    let b_grad = b.grad(&grads).unwrap();

    let a_expected = TensorData::from([[800.0, 792.0], [360.0, 592.0]]);
    a_grad.to_data().assert_eq(&a_expected, false);

    let b_expected = TensorData::from([[264., 264.0], [344.0, 344.0]]);
    b_grad.to_data().assert_eq(&b_expected, false);
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
/// Checks the leaf gradients of `lhs * max_dim(lhs x rhs, 1)`.
///
/// The row-wise maximum is unsqueezed before the element-wise multiplication.
#[test]
fn should_diff_max_dim() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
    let rhs = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let scaled = lhs.clone().mul(product.max_dim(1).unsqueeze());
    let grads = scaled.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let lhs_expected = TensorData::from([[50.0, 34.0], [40.0, -10.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_expected, Tolerance::default());

    let rhs_expected = TensorData::from([[8.0, 10.0], [56.0, 15.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_expected, Tolerance::default());
}
|
||||
|
||||
/// Checks the leaf gradients of `lhs * min_dim(lhs x rhs, 1)`.
///
/// The row-wise minimum is unsqueezed before the element-wise multiplication.
#[test]
fn should_diff_min_dim() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
    let rhs = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();

    let product = lhs.clone().matmul(rhs.clone());
    let scaled = lhs.clone().mul(product.min_dim(1).unsqueeze());
    let grads = scaled.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let lhs_expected = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_expected, Tolerance::default());

    let rhs_expected = TensorData::from([[10.0, 8.0], [15.0, 56.0]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_expected, Tolerance::default());
}
|
||||
|
||||
/// Checks the leaf gradients of `min_dim(lhs * rhs, 1)` on rank-3 tensors.
///
/// Only the positions selected by the minimum along dimension 1 are expected
/// to carry a non-zero gradient.
#[test]
fn should_diff_min_dim_3d_dim1() {
    let device = Default::default();
    let lhs =
        TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
    let rhs =
        TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device).require_grad();

    let elementwise = lhs.clone().mul(rhs.clone());
    let reduced = elementwise.min_dim(1);

    let grads = reduced.backward();

    let lhs_grad = lhs.grad(&grads).unwrap();
    let rhs_grad = rhs.grad(&grads).unwrap();

    let lhs_expected = TensorData::from([[[0., -7.], [2., 0.]]]);
    lhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&lhs_expected, Tolerance::default());

    let rhs_expected = TensorData::from([[[0., 7.], [-2., -0.]]]);
    rhs_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&rhs_expected, Tolerance::default());
}
|
||||
@@ -0,0 +1,134 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::max_pool1d;
|
||||
|
||||
/// Back-propagates through `max_pool1d` on a length-6 signal.
///
/// Configuration: kernel 4, stride 1, padding 0, dilation 1, ceil mode off.
#[test]
fn test_max_pool1d_simple() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 0, 1);

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]],
        &device,
    )
    .require_grad();
    let expected_grad =
        TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device);

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through a dilated `max_pool1d` on a length-25 signal.
///
/// Configuration: kernel 4, stride 1, padding 0, dilation 2, ceil mode off.
#[test]
fn test_max_pool1d_with_dilation() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 0, 2);

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
            0.4610, 0.5365, 0.6880,
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<3>::from_floats(
        [[[
            0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
            0., 0., 1.,
        ]]],
        &device,
    );

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through `max_pool1d` on a length-25 signal without padding.
///
/// Configuration: kernel 4, stride 1, padding 0, dilation 1, ceil mode off.
#[test]
fn test_max_pool1d_complex() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 0, 1);

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
            0.4610, 0.5365, 0.6880,
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<3>::from_floats(
        [[[
            0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
            1., 1., 1.,
        ]]],
        &device,
    );

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through a padded `max_pool1d` on a length-25 signal.
///
/// Configuration: kernel 4, stride 1, padding 2, dilation 1, ceil mode off.
#[test]
fn test_max_pool1d_complex_with_padding() {
    let (kernel_size, stride, padding, dilation) = (4, 1, 2, 1);

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
            0.4610, 0.5365, 0.6880,
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<3>::from_floats(
        [[[
            1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
            1., 1., 3.,
        ]]],
        &device,
    );

    let pooled = max_pool1d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
@@ -0,0 +1,271 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::max_pool2d;
|
||||
|
||||
/// Back-propagates through `max_pool2d` on a 4x4 input.
///
/// Configuration: kernel 3x3, stride 1x1, padding 0x0, dilation 1x1, ceil mode off.
#[test]
fn test_max_pool2d_simple_1() {
    let kernel_size = [3, 3];
    let stride = [1, 1];
    let padding = [0, 0];
    let dilation = [1, 1];

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]]],
        &device,
    );

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through a padded `max_pool2d` on a 4x4 input.
///
/// Configuration: kernel 2x2, stride 1x1, padding 1x1, dilation 1x1, ceil mode off.
#[test]
fn test_max_pool2d_simple_2() {
    let kernel_size = [2, 2];
    let stride = [1, 1];
    let padding = [1, 1];
    let dilation = [1, 1];

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [1., 3., 0., 2.],
            [3., 0., 0., 4.],
            [1., 4., 0., 1.],
            [2., 0., 3., 1.],
        ]]],
        &device,
    );

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through a padded, dilated `max_pool2d` on a 4x4 input.
///
/// Configuration: kernel 2x2, stride 1x1, padding 1x1, dilation 2x2, ceil mode off.
#[test]
fn test_max_pool2d_with_dilation() {
    let kernel_size = [2, 2];
    let stride = [1, 1];
    let padding = [1, 1];
    let dilation = [2, 2];

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0., 0., 0., 0.],
            [1., 1., 1., 2.],
            [0., 4., 4., 0.],
            [0., 1., 2., 0.],
        ]]],
        &device,
    );

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through an asymmetric `max_pool2d` on a 5x5 input.
///
/// Configuration: kernel 4x2, stride 1x2, padding 2x1, dilation 1x1, ceil mode off.
#[test]
fn test_max_pool2d_complex() {
    let kernel_size = [4, 2];
    let stride = [1, 2];
    let padding = [2, 1];
    let dilation = [1, 1];

    let device = Default::default();
    let input = TestAutodiffTensor::from_floats(
        [[[
            [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
            [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
            [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
            [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
            [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
        ]]],
        &device,
    )
    .require_grad();
    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0., 0., 0., 3., 0.],
            [4., 0., 2., 1., 0.],
            [0., 0., 0., 0., 0.],
            [2., 4., 0., 0., 0.],
            [0., 0., 0., 0., 2.],
        ]]],
        &device,
    );

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, false);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
|
||||
/// Back-propagates through `max_pool2d` with `ceil_mode = true`.
///
/// A 1x1x6x6 input is pooled with a 3x3 kernel, stride 2x2 and no padding.
/// Floor mode would yield a 2x2 output; ceil mode yields 3x3, so the partial
/// windows on the bottom and right edges also contribute gradient.
#[test]
fn test_max_pool2d_ceil_mode() {
    let kernel_size = [3, 3];
    let stride = [2, 2];
    let padding = [0, 0];
    let dilation = [1, 1];

    let device = Default::default();
    // The input holds the values 1 through 36 in row-major order.
    let input = TestAutodiffTensor::from_floats(
        [[[
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
            [19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
            [25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
            [31.0, 32.0, 33.0, 34.0, 35.0, 36.0],
        ]]],
        &device,
    )
    .require_grad();

    // Each of the nine output cells routes a gradient of 1 to its window maximum:
    //   out (0,0) -> in (2,2) = 15    out (0,1) -> in (2,4) = 17    out (0,2) -> in (2,5) = 18
    //   out (1,0) -> in (4,2) = 27    out (1,1) -> in (4,4) = 29    out (1,2) -> in (4,5) = 30
    //   out (2,0) -> in (5,2) = 33    out (2,1) -> in (5,4) = 35    out (2,2) -> in (5,5) = 36
    let expected_grad = TestAutodiffTensor::<4>::from_floats(
        [[[
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 1., 0., 1., 1.],
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 1., 0., 1., 1.],
            [0., 0., 1., 0., 1., 1.],
        ]]],
        &device,
    );

    let pooled = max_pool2d(input.clone(), kernel_size, stride, padding, dilation, true);
    let grads = pooled.backward();

    // Compare the reference gradient with the one produced by autodiff.
    let actual_grad = input.grad(&grads).unwrap();
    expected_grad
        .to_data()
        .assert_approx_eq::<FloatElem>(&actual_grad.to_data(), Tolerance::default());
}
|
||||
@@ -0,0 +1,290 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Tensor, TensorData};
|
||||
|
||||
/// Two graphs that share no nodes: running backward on the first must not
/// prevent the second from producing gradients for all of its leaves.
#[test]
fn test_mm_independent_trees() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    // First graph: (a0 * a1) * (a2 * a3); the leaves are moved into the products.
    let a0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let a1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let a_left = a0 * a1;
    let a_right = a2 * a3;
    let a_root = a_left * a_right;

    // Second graph: same shape, but the leaves are kept alive through clones.
    let b0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let b1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let b2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let b3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let b_left = b0.clone() * b1.clone();
    let b_right = b2.clone() * b3.clone();
    let b_root = b_left * b_right;

    let _a_grads = a_root.backward();
    let b_grads = b_root.backward();

    assert!(b0.grad(&b_grads).is_some());
    assert!(b1.grad(&b_grads).is_some());
    assert!(b2.grad(&b_grads).is_some());
    assert!(b3.grad(&b_grads).is_some());
}
|
||||
|
||||
/// Two graphs that share the intermediate node `shared`; the second backward
/// call is expected to panic after the first one has run.
// NOTE(review): presumably the first backward consumes the shared sub-graph,
// making the second root unusable — confirm against the autodiff memory manager.
#[test]
#[should_panic]
fn test_mm_crossover_trees_root_unavailable() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    // First graph: shared * (a2 * a3), with shared = a0 * a1.
    let a0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let a1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let shared = a0 * a1;
    let a_right = a2 * a3;
    let first_root = shared.clone() * a_right;

    // Second graph: shared * (b0 * b1).
    let b0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let b1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let b_product = b0.clone() * b1.clone();
    let second_root = shared * b_product;

    let _first_grads = first_root.backward();
    let _second_grads = second_root.backward();
}
|
||||
|
||||
/// Two graphs that share the intermediate node `shared`; after backward on the
/// first root, backward on the second graph's own sub-tree (`b_product`) must
/// still succeed.
#[test]
fn test_mm_crossover_trees_with_referred_subtree() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    // First graph: shared * (a2 * a3), with shared = a0 * a1.
    let a0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let a1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let shared = a0 * a1;
    let a_right = a2 * a3;
    let first_root = shared.clone() * a_right;

    // Second graph: b_product is kept, and also combined with the shared node.
    let b0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let b1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let b_product = b0.clone() * b1.clone();
    let _crossover = shared * b_product.clone();

    let _first_grads = first_root.backward();
    let _subtree_grads = b_product.backward();
}
|
||||
|
||||
/// Three graphs where the middle one bridges the other two; after backward on
/// the first root, backward on the third root must still succeed.
#[test]
fn test_mm_three_crossover_trees_last_still_usable() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    // First graph: (a0 * a1) * a_right, with a_right = a2 * a3.
    let a0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let a1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let a_left = a0 * a1;
    let a_right = a2 * a3;
    let first_root = a_left * a_right.clone();

    // Third graph: (c0 * c1) * c_right, with c_right = c2 * c3.
    let c0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let c1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let c2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let c3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let c_left = c0 * c1;
    let c_right = c2 * c3;
    let third_root = c_left * c_right.clone();

    // Second graph, bridging the other two through their right-hand products.
    let _bridge = a_right * c_right;

    let _first_grads = first_root.backward();
    let _third_grads = third_root.backward();
}
|
||||
|
||||
/// Three graphs where the middle one bridges the other two; after backward on
/// the first root, backward on the bridging root is expected to panic.
// NOTE(review): presumably the bridge loses the sub-graph it shares with the
// first root once that root has been back-propagated — confirm.
#[test]
#[should_panic]
fn test_mm_three_crossover_trees_middle_one_unavailable() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    // First graph: (a0 * a1) * a_right, with a_right = a2 * a3.
    let a0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let a1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let a3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let a_left = a0 * a1;
    let a_right = a2 * a3;
    let first_root = a_left * a_right.clone();

    // Third graph: (c0 * c1) * c_right, with c_right = c2 * c3.
    let c0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let c1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let c2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let c3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let c_left = c0 * c1;
    let c_right = c2 * c3;
    let _third_root = c_left * c_right.clone();

    // Second graph, bridging the other two through their right-hand products.
    let bridge = a_right * c_right;

    let _first_grads = first_root.backward();
    let _bridge_grads = bridge.backward();
}
|
||||
|
||||
/// A single graph in which the node `inner` feeds the root twice: once
/// directly and once through `outer`. Backward on the root must succeed.
#[test]
fn test_mm_self_referencing_tree() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    let leaf_a = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
    let leaf_b = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
    let leaf_c = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();

    let inner = leaf_a * leaf_b;
    let outer = leaf_c * inner.clone();
    let root = inner * outer;

    let _grads = root.backward();
}
|
||||
|
||||
/// Multiplies a detached product by a tracked leaf and checks that the
/// tracked leaf still receives a gradient.
#[test]
fn test_mm_with_non_impacting_detach() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();
    let first =
        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
    let second =
        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
    let tracked = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();

    let product = first.clone() * second.clone();
    // Only `tracked` stays connected to the output; `product` is detached first.
    let output = product.detach() * tracked.clone();

    let grads = output.backward();
    assert!(tracked.grad(&grads).is_some());
}
|
||||
|
||||
/// Only the first leaf requires grad. After an unrelated backward pass has
/// run, backward on the real graph must yield a gradient for that leaf and
/// none for the two untracked leaves.
#[test]
fn test_mm_with_missing_require_grad_after_cleanup() {
    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
    let device = Default::default();

    let tracked =
        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
    let untracked_a = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
    let untracked_b = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);

    let product = tracked.clone() * untracked_a.clone();
    let output = product * untracked_b.clone();

    // Throwaway backward on an unrelated tensor, run only to trigger cleanup.
    Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
        .require_grad()
        .backward();

    let grads = output.backward();
    assert!(tracked.grad(&grads).is_some());
    assert!(untracked_a.grad(&grads).is_none());
    assert!(untracked_b.grad(&grads).is_none());
}
|
||||
|
||||
/// After an unrelated backward pass has run, leaves used directly must still receive gradients,
/// while a leaf that only entered the graph through `detach` must not.
#[test]
fn test_mm_with_detach_after_cleanup() {
    type Tracked = Tensor<TestAutodiffBackend, 2>;

    let device = Default::default();
    let values = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);

    let first = Tracked::from_data(values.clone(), &device).require_grad();
    let second = Tracked::from_data(values.clone(), &device).require_grad();
    let third = Tracked::from_data(values.clone(), &device).require_grad();

    // `third` participates only through a detached clone.
    let product = first.clone() * second.clone();
    let output = product * third.clone().detach();

    // Unrelated, trivial backward whose only purpose is to trigger cleanup. It is kept as a
    // single expression statement so the temporary tensor is dropped right away.
    Tracked::from_data(values, &device).require_grad().backward();

    let gradients = output.backward();
    assert!(first.grad(&gradients).is_some());
    assert!(second.grad(&gradients).is_some());
    assert!(third.grad(&gradients).is_none());
}
|
||||
|
||||
/// Expects the second `backward` call to panic.
///
/// After running backward on `product`, not only the leaf `_log_leaf` is expected to be deleted,
/// but the intermediate `exp_node` as well, so calling backward on it afterwards must fail.
#[test]
#[should_panic]
fn test_mm_deletables_propagate_well() {
    type Tracked = Tensor<TestAutodiffBackend, 2>;

    let device = Default::default();
    let values = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);

    let leaf_a = Tracked::from_data(values.clone(), &device).require_grad();
    let leaf_b = Tracked::from_data(values.clone(), &device).require_grad();

    // product -> exp_node -> _log_leaf
    let product = leaf_a * leaf_b;
    let exp_node = product.clone().exp();
    let _log_leaf = exp_node.clone().log();

    let _first_gradients = product.backward();

    // This is the call the `should_panic` attribute is about.
    let _second_gradients = exp_node.backward();
}
|
||||
|
||||
/// A node reached first through a shallow leaf must still be usable when it is reached again,
/// deeper, through another leaf.
///
/// Each run has a 50% chance of starting from the deep leaf instead of the shallow one, which is
/// not informative. Repeating the scenario many times makes it almost impossible for the test to
/// pass when it should not.
#[test]
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
    type Tracked = Tensor<TestAutodiffBackend, 2>;

    let device = Default::default();
    let values = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);

    for _attempt in 0..12 {
        let leaf_a = Tracked::from_data(values.clone(), &device).require_grad();
        let leaf_b = Tracked::from_data(values.clone(), &device).require_grad();

        let exp_b = leaf_b.clone().exp();
        let exp_a = leaf_a.exp();
        let _shallow_leaf = exp_a.clone() * exp_b.clone();
        // Four more `exp` steps on top of `exp_b` give the deep leaf.
        let deep_leaf = exp_b.exp().exp().exp().exp();

        // `exp_b` should be tagged unknown through `_shallow_leaf`, then useful through
        // `deep_leaf`. The latter should happen afterwards, because `exp_b` is deeper from
        // `deep_leaf`'s point of view and the search is breadth first.
        exp_a.backward();
        let gradients = deep_leaf.backward();

        assert!(leaf_b.grad(&gradients).is_some());
    }
}
|
||||
@@ -0,0 +1,74 @@
|
||||
#[allow(unused_imports)] // required for re-included modules
|
||||
pub use super::*;
|
||||
|
||||
mod abs;
|
||||
mod adaptive_avgpool1d;
|
||||
mod adaptive_avgpool2d;
|
||||
mod add;
|
||||
mod aggregation;
|
||||
mod avgpool1d;
|
||||
mod avgpool2d;
|
||||
mod backward;
|
||||
mod bridge;
|
||||
mod broadcast;
|
||||
mod cast;
|
||||
mod cat;
|
||||
mod ceil;
|
||||
mod checkpoint;
|
||||
mod complex;
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv3d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod cross;
|
||||
mod cross_entropy;
|
||||
mod cummax;
|
||||
mod cummin;
|
||||
mod cumprod;
|
||||
mod cumsum;
|
||||
mod deform_conv2d;
|
||||
mod div;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod expand;
|
||||
mod flip;
|
||||
mod floor;
|
||||
mod gather_scatter;
|
||||
mod gelu;
|
||||
mod gradients;
|
||||
mod log;
|
||||
mod log1p;
|
||||
mod log_sigmoid;
|
||||
mod mask;
|
||||
mod matmul;
|
||||
mod maxmin;
|
||||
mod maxpool1d;
|
||||
mod maxpool2d;
|
||||
mod memory_management;
|
||||
mod mul;
|
||||
mod multithread;
|
||||
mod nearest_interpolate;
|
||||
mod neg;
|
||||
mod nonzero;
|
||||
mod permute;
|
||||
mod pow;
|
||||
mod recip;
|
||||
mod relu;
|
||||
mod remainder;
|
||||
mod repeat_dim;
|
||||
mod reshape;
|
||||
mod round;
|
||||
mod select;
|
||||
mod sigmoid;
|
||||
mod sign;
|
||||
mod slice;
|
||||
mod slice_assign;
|
||||
mod softmax;
|
||||
mod sort;
|
||||
mod sqrt;
|
||||
mod sub;
|
||||
mod transpose;
|
||||
mod trig;
|
||||
mod unfold;
|
||||
@@ -0,0 +1,68 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
/// Checks the forward value and both input gradients of an elementwise multiplication.
///
/// For `y = a * b` (elementwise), `dy/da = b` and `dy/db = a`, so each operand's gradient must
/// equal the other operand's data.
#[test]
fn should_diff_mul() {
    let data_1 = TensorData::from([1.0, 7.0]);
    let data_2 = TensorData::from([4.0, 7.0]);

    let device = Default::default();
    let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad();
    let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();

    let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
    let grads = tensor_3.backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    grad_1.to_data().assert_eq(&data_2, false);
    // Previously this gradient was fetched but never verified.
    grad_2.to_data().assert_eq(&data_1, false);
    tensor_3
        .into_data()
        .assert_eq(&TensorData::from([4.0, 49.0]), false);
}
|
||||
|
||||
/// Checks the forward value and the input gradient of multiplying a tensor by the scalar 4.
#[test]
fn should_diff_mul_scalar() {
    let values = TensorData::from([2.0, 5.0]);
    let input = TestAutodiffTensor::<1>::from_data(values, &Default::default()).require_grad();

    let scaled = input.clone().mul_scalar(4.0);
    let gradients = scaled.backward();
    let input_grad = input.grad(&gradients).unwrap();

    // Every element is scaled by 4, and every element's gradient is 4.
    let expected_output = TensorData::from([8.0, 20.0]);
    let expected_grad = TensorData::from([4.0, 4.0]);

    scaled.into_data().assert_eq(&expected_output, false);
    input_grad.to_data().assert_eq(&expected_grad, false);
}
|
||||
|
||||
/// Checks gradients through a chain of elementwise products in which the first operand is
/// used twice: `root = a * ((a * b) * c)`.
#[test]
fn test_mul_complex_1() {
    let device = Default::default();

    let a = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 7.0], [13.0, -3.0]]),
        &device,
    )
    .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let c = TestAutodiffTensor::from_data(TensorData::from([[2.0, 2.0], [2.0, 2.0]]), &device)
        .require_grad();

    let inner = a.clone().mul(b.clone()).mul(c);
    let root = a.clone().mul(inner);

    let gradients = root.backward();

    let grad_a = a.grad(&gradients).unwrap();
    let grad_b = b.grad(&gradients).unwrap();

    // d(root)/da = 2 * a * b * c
    let expected_a = TensorData::from([[16.0, 196.0], [104.0, -36.0]]);
    // d(root)/db = a^2 * c
    let expected_b = TensorData::from([[2.0, 98.0], [338.0, 18.0]]);

    grad_a.to_data().assert_eq(&expected_a, false);
    grad_b.to_data().assert_eq(&expected_b, false);
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
/// The gradients of a matmul graph must be (approximately) the same whether its two middle
/// branches are computed on spawned threads or sequentially on the current thread.
#[test]
fn should_behave_the_same_with_multithread() {
    let values_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
    let values_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);

    // Variant that moves tensor clones into two closures executed on separate threads.
    let with_move = || {
        let device = Default::default();
        let lhs = TestAutodiffTensor::<2>::from_data(values_1.clone(), &device).require_grad();
        let rhs = TestAutodiffTensor::from_data(values_2.clone(), &device).require_grad();

        let stage_a = lhs.clone().matmul(rhs.clone());
        let stage_b = stage_a.clone().matmul(rhs.clone());
        let shared = stage_b.matmul(stage_a);

        // Task 1: (shared x rhs) x lhs
        let (lhs_first, rhs_first, shared_first) = (lhs.clone(), rhs.clone(), shared.clone());
        let first_task = move || shared_first.matmul(rhs_first).matmul(lhs_first);

        // Task 2: (shared x lhs) x rhs
        let (lhs_second, rhs_second, shared_second) = (lhs.clone(), rhs.clone(), shared);
        let second_task = move || shared_second.matmul(lhs_second).matmul(rhs_second);

        let first_handle = std::thread::spawn(first_task);
        let second_handle = std::thread::spawn(second_task);

        let branch_1 = first_handle.join().unwrap();
        let branch_2 = second_handle.join().unwrap();
        let root = branch_1.matmul(branch_2);

        let gradients = root.backward();

        (lhs.grad(&gradients).unwrap(), rhs.grad(&gradients).unwrap())
    };

    // Reference variant: the same graph built entirely on the current thread.
    let without_move = || {
        let device = Default::default();
        let lhs = TestAutodiffTensor::<2>::from_data(values_1.clone(), &device).require_grad();
        let rhs = TestAutodiffTensor::from_data(values_2.clone(), &device).require_grad();

        let stage_a = lhs.clone().matmul(rhs.clone());
        let stage_b = stage_a.clone().matmul(rhs.clone());
        let shared = stage_b.matmul(stage_a);

        // Task 1: (shared x rhs) x lhs
        let branch_1 = shared.clone().matmul(rhs.clone()).matmul(lhs.clone());

        // Task 2: (shared x lhs) x rhs
        let branch_2 = shared.matmul(lhs.clone()).matmul(rhs.clone());

        let root = branch_1.matmul(branch_2);

        let gradients = root.backward();

        (lhs.grad(&gradients).unwrap(), rhs.grad(&gradients).unwrap())
    };

    let (reference_grad_1, reference_grad_2) = without_move();
    let (threaded_grad_1, threaded_grad_2) = with_move();

    reference_grad_1
        .into_data()
        .assert_approx_eq::<FloatElem>(&threaded_grad_1.into_data(), Tolerance::default());
    reference_grad_2
        .into_data()
        .assert_approx_eq::<FloatElem>(&threaded_grad_2.into_data(), Tolerance::default());
}
|
||||
@@ -0,0 +1,97 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::interpolate;
|
||||
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
|
||||
|
||||
/// Nearest interpolation from 7x5 up to 8x7 on a batch of two single-channel inputs; checks the
/// gradient that reaches the input.
#[test]
fn test_upsample_interpolation() {
    let case = InterpolateTestCase {
        batch_size: 2,
        channels: 1,
        height: 7,
        height_out: 8,
        width: 5,
        width_out: 7,
    };

    // Expected input gradient for one image: the first row differs from the six rows below it.
    let first_row = [4., 2., 4., 2., 2.];
    let other_row = [2., 1., 2., 1., 1.];
    let image = [
        first_row, other_row, other_row, other_row, other_row, other_row, other_row,
    ];

    // Both batch entries expect the same gradient.
    case.assert_output(TestTensor::from([[image], [image]]));
}
|
||||
|
||||
/// Nearest interpolation from 8x8 down to 4x6 on a single one-channel input; checks the
/// gradient that reaches the input.
#[test]
fn test_downsample_interpolation() {
    let case = InterpolateTestCase {
        batch_size: 1,
        channels: 1,
        height: 8,
        height_out: 4,
        width: 8,
        width_out: 6,
    };

    // Expected input gradient: rows alternate between this pattern and all zeros.
    let pattern_row = [1., 1., 1., 0., 1., 1., 1., 0.];
    let zero_row = [0.; 8];
    let image = [
        pattern_row,
        zero_row,
        pattern_row,
        zero_row,
        pattern_row,
        zero_row,
        pattern_row,
        zero_row,
    ];

    case.assert_output(TestTensor::from([[image]]));
}
|
||||
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl InterpolateTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = interpolate(
|
||||
x.clone(),
|
||||
[self.height_out, self.width_out],
|
||||
InterpolateOptions::new(InterpolateMode::Nearest),
|
||||
);
|
||||
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), Tolerance::permissive());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
/// Checks gradients through two negations wrapped around a matmul: `-(a x (-b))`.
#[test]
fn should_diff_neg() {
    let device = Default::default();

    let a = TestAutodiffTensor::<2>::from_data(TensorData::from([[1.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();
    let b = TestAutodiffTensor::from_data(TensorData::from([[4.0, 7.0], [2.0, 3.0]]), &device)
        .require_grad();

    // The right operand is negated before the matmul, and the result is negated again.
    let output = a.clone().matmul(b.clone().neg()).neg();
    let gradients = output.backward();

    let grad_a = a.grad(&gradients).unwrap();
    let grad_b = b.grad(&gradients).unwrap();

    let expected_a = TensorData::from([[11.0, 5.0], [11.0, 5.0]]);
    let expected_b = TensorData::from([[3.0, 3.0], [10.0, 10.0]]);

    grad_a.to_data().assert_eq(&expected_a, false);
    grad_b.to_data().assert_eq(&expected_b, false);
}
|
||||
@@ -0,0 +1,41 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Bool, Tensor, TensorData};
|
||||
|
||||
/// Checks that gradients flow through a `select` driven by the indices returned from
/// `nonzero` on a boolean mask.
#[test]
fn should_diff_nonzero() {
    let device = Default::default();

    let matrix = TestAutodiffTensor::<2>::from_data(
        TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
        &device,
    )
    .require_grad();
    let weights =
        TestAutodiffTensor::<1>::from_data(TensorData::from([-1.0, 1.0]), &device).require_grad();

    // Multi-dimensional tensor indexing isn't really supported yet, so the easiest approach is
    // to flatten both the mask and the tensor to get proper indexing. The returned tensor would
    // have dimensions different from the input anyway, so this is somewhat equivalent.
    let mask_data = TensorData::from([[false, true], [true, false]]);
    let flat_mask =
        Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask_data, &device).flatten::<1>(0, 1);
    let selected_indices = flat_mask.nonzero();
    let picked = matrix
        .clone()
        .flatten::<1>(0, 1)
        .select(0, selected_indices[0].clone());

    // Vector dot product is not supported (only 2D matmuls), so unsqueeze for test purposes.
    let row = weights.clone().unsqueeze_dim::<2>(0);
    let output = row.matmul(picked.unsqueeze_dim(1));

    let gradients = output.backward();

    let matrix_grad = matrix.grad(&gradients).unwrap();
    let weights_grad = weights.grad(&gradients).unwrap();

    let expected_matrix_grad = TensorData::from([[0.0, -1.0], [1.0, 0.0]]);
    let expected_weights_grad = TensorData::from([2.0, 3.0]);

    matrix_grad.to_data().assert_eq(&expected_matrix_grad, false);
    weights_grad.to_data().assert_eq(&expected_weights_grad, false);
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user