feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,10 @@
[alias]
test-cpu = "test --release --no-default-features --features cpu,std"
test-cuda = "test --release --no-default-features --features cuda,std"
test-ndarray = "test --release --no-default-features --features ndarray,std"
test-rocm = "test --release --no-default-features --features rocm,std"
test-router = "test --release --no-default-features --features router,std"
test-tch = "test --release --no-default-features --features tch,std"
test-wgpu = "test --release --no-default-features --features wgpu,std"
test-vulkan = "test --release --no-default-features --features vulkan,std"
test-metal = "test --release --no-default-features --features metal,std"

View File

@@ -0,0 +1,120 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Tensor tests for Burn backends"
documentation = "https://docs.rs/burn-backend-tests"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-backend-tests"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-backend-tests"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"burn-tensor/default",
"burn-autodiff/default",
# Backends (default not enabled for CubeCL backends as it activates fusion)
"burn-cpu?/default",
"burn-ndarray?/default",
"burn-tch?/default",
# Default
"ndarray",
"std",
]
std = [
"burn-tensor/std",
"burn-autodiff/std",
# Backends
"burn-cpu?/std",
"burn-ndarray?/std",
"burn-wgpu?/std",
"burn-router?/std",
"burn-cuda?/std",
"burn-rocm?/std",
]
tracing = [
"cubecl?/tracing",
"burn-tensor/tracing",
"burn-autodiff/tracing",
# Backends
"burn-cpu?/tracing",
"burn-ndarray?/tracing",
"burn-wgpu?/tracing",
"burn-router?/tracing",
"burn-cuda?/tracing",
"burn-rocm?/tracing",
]
# Backends
cuda = ["burn-cuda", "quantization", "cube"]
rocm = ["burn-rocm", "quantization", "cube"]
ndarray = ["burn-ndarray", "quantization"]
tch = ["burn-tch"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]
metal = ["wgpu", "burn-wgpu/metal"]
wgpu = ["burn-wgpu", "quantization", "cube"]
cpu = ["burn-cpu", "cube"]
router = ["burn-router", "ndarray", "burn-wgpu"]
autotune = [
"burn-wgpu?/autotune",
"burn-cuda?/autotune",
"burn-rocm?/autotune",
"burn-cpu?/autotune",
]
autotune-checks = [
"burn-wgpu?/autotune-checks",
"burn-cuda?/autotune-checks",
"burn-rocm?/autotune-checks",
"burn-cpu?/autotune-checks",
]
# CubeCL backends
cube = [
"cubecl",
"cubek",
"autotune",
"burn-fusion",
"burn-cubecl",
"burn-ndarray",
]
# Test configs
quantization = []
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2" }
# Backends
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", default-features = false, features = [
"export_tests",
] }
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false, features = [
"export_tests",
] }
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
# To wrap `Fusion<CubeBackend>
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true, features = [
"fusion",
] }
num-traits = { workspace = true }
serial_test = { workspace = true }
cubecl = { workspace = true, optional = true }
cubek = { workspace = true, features = ["random"], optional = true }

View File

@@ -0,0 +1,111 @@
# Burn Backend Tests
This crate provides a comprehensive suite of tests for Burn backends, covering:
- Tensor operations: [tests/tensor/](./tests/tensor/)
- Autodiff: [tests/autodiff/](./tests/autodiff/)
- (Optional) CubeCL kernels correctness: [tests/cubecl/](./tests/cubecl/)
## Running Tests
The `TestBackend` is selected via feature flags. Use the provided shorthand commands for
convenience:
```sh
# Cpu
cargo test-cpu
# Cuda
cargo test-cuda
# Rocm
cargo test-rocm
# Wgpu / WebGpu
cargo test-wgpu
# Vulkan
cargo test-vulkan
# Metal
cargo test-metal
# Router
cargo test-router
# NdArray
cargo test-ndarray
# LibTorch
cargo test-tch
```
By default, `cargo test` fail-fast across integration test binaries. When one integration test
binary fails, Cargo does not run the remaining test binaries. If you want to run all test binaries
regardless of failures, pass `--no-fail-fast`, for example:
```sh
cargo test-cuda --no-fail-fast
```
## Structure
- `tests/tensor.rs`: Tensor tests
- `tests/autodiff.rs`: Autodiff tests
- `tests/fusion.rs`: Fusion backend tests wrapping tensor and autodiff tests
- `tests/cubecl.rs`: CubeCL kernel tests
Each test module assumes exactly one `FloatElemType`, `IntElemType`, and `TestBackend` in scope.
### Common Modules
- `common/backend.rs`: Backend type definitions
- `common/tensor.rs`: Reusable tensor test suite, split across float, int and bool tensor kinds
- `common/autodiff.rs`: Reusable autodiff test suite, with and without checkpointing
### Test Reusability
This crate uses a pattern of parameterized test modules to run the same tests with different
configurations (backends, dtypes, etc.):
1. **Type aliases define the configuration**: Each test scope declares `FloatElemType`,
`IntElemType`, and `TestBackend`
1. **`#[path = "..."]` references shared modules**: Points to test files outside the normal module
hierarchy, e.g. `"common/tensor.rs"`
1. **`include!()` imports test code**: Test modules are included multiple times with different type
configurations
1. **`use super::*;`** propagates types down the module tree: Each level re-exports parent types so
deeply nested tests have access to the configured types
For example, `common/tensor.rs` can be included with `FloatElemType = f32` for base tests, then
included again with `FloatElemType = f16` for half-precision tests, running the same test suite
twice with different dtypes.
## Adding New Tests
Add test modules under `tests/tensor/`, `tests/autodiff/`, or `tests/cubecl` respectively. They will
automatically run for all required configurations.
For tensor tests, make sure to add the test to each relevant tensor kind:
- `tensor/bool`: boolean tensor tests
- `tensor/float`: float tensor tests
- `tensor/int`: integer tensor tests
**Guidelines:**
Import types with `use super::*;` at the top of each module and use the types defined in
`common/backend.rs`:
```rust
/// Collection of types used across tests
pub use burn_autodiff::Autodiff;
pub use burn_tensor::Tensor;
pub type TestBackend = ...;
pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;
pub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;
pub type IntElem = burn_tensor::ops::IntElem<TestBackend>;
pub type TestAutodiffBackend = Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;
```
Tests will automatically run with default dtypes and any variants (f16, bf16, etc.) based on the
backend configuration.

View File

@@ -0,0 +1,22 @@
extern crate alloc;
#[cfg(feature = "std")]
pub use burn_tensor_testgen::might_panic;
/// Generate a test module with custom floating element types.
#[macro_export]
macro_rules! test_float_elem_variant {
($modname:ident, $float:ty, $module:literal, [$($feat:literal),* $(,)?]) => {
#[cfg(all(test, any($(feature = $feat),*)))]
mod $modname {
pub type FloatElemType = $float;
#[allow(unused)]
pub use super::IntElemType;
mod ty {
include!("backend.rs");
include!($module);
}
}
};
}

View File

@@ -0,0 +1,20 @@
//! Burn autodiff tests.
#![allow(
clippy::single_range_in_vec_init,
clippy::duplicate_mod,
reason = "false positive"
)]
extern crate alloc;
pub type FloatElemType = f32;
#[allow(unused)]
pub type IntElemType = i32;
#[path = "common/backend.rs"]
mod backend;
pub use backend::*;
#[allow(clippy::module_inception)]
#[path = "common/autodiff.rs"]
mod autodiff;

View File

@@ -0,0 +1,58 @@
use super::*;
use burn_tensor::{TensorData, Tolerance, cast::ToElement};
#[test]
fn should_diff_abs() {
let data_1 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_abs_no_nans() {
let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let contains_nan = grad_2.contains_nan();
assert!(!contains_nan.into_scalar().to_bool());
}

View File

@@ -0,0 +1,50 @@
use super::*;
use burn_tensor::module::adaptive_avg_pool1d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool1d_simple() {
let test = AdaptiveAvgPool1dTestCase {
batch_size: 1,
channels: 2,
length: 5,
output_size: 3,
};
test.assert_output(TestTensor::from_floats(
[[
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
]],
&Default::default(),
));
}
struct AdaptiveAvgPool1dTestCase {
batch_size: usize,
channels: usize,
length: usize,
output_size: usize,
}
impl AdaptiveAvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = adaptive_avg_pool1d(x.clone(), self.output_size);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad.to_data().assert_approx_eq::<FloatElem>(
&x_grad_actual.into_data(),
Tolerance::default().set_half_precision_relative(1e-3),
);
}
}

View File

@@ -0,0 +1,96 @@
use super::*;
use burn_tensor::module::adaptive_avg_pool2d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool2d_simple() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 2,
height: 5,
width: 3,
output_size_1: 3,
output_size_2: 2,
};
test.assert_output(TestTensor::from_floats(
[[
[
[0.2500, 0.5000, 0.2500],
[0.41667, 0.83333, 0.41667],
[0.16667, 0.33333, 0.16667],
[0.41667, 0.83333, 0.41667],
[0.2500, 0.5000, 0.2500],
],
[
[0.2500, 0.5000, 0.2500],
[0.41667, 0.83333, 0.41667],
[0.16667, 0.33333, 0.16667],
[0.41667, 0.83333, 0.41667],
[0.2500, 0.5000, 0.2500],
],
]],
&Default::default(),
));
}
#[test]
fn test_avg_pool2d_output_1() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 1,
height: 4,
width: 8,
output_size_1: 1,
output_size_2: 1,
};
test.assert_output(TestTensor::from_floats(
[[[
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
]]],
&Default::default(),
));
}
struct AdaptiveAvgPool2dTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
output_size_1: usize,
output_size_2: usize,
}
impl AdaptiveAvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad.to_data().assert_approx_eq::<FloatElem>(
&x_grad_actual.into_data(),
Tolerance::default().set_half_precision_relative(1e-3),
);
}
}

View File

@@ -0,0 +1,74 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_add() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad();
let tensor_3 = tensor_1.clone() + tensor_2.clone();
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
tensor_3
.to_data()
.assert_eq(&TensorData::from([6.0, 6.0]), false);
}
#[test]
fn should_diff_add_scalar() {
let data = TensorData::from([2.0, 10.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().add_scalar(5.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
tensor_out
.into_data()
.assert_eq(&TensorData::from([7.0, 15.0]), false);
}
#[test]
fn test_add_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().add(tensor_2.clone());
let tensor_5 = tensor_4
.add(tensor_3)
.add_scalar(5.0)
.add(tensor_1.clone())
.add(tensor_2.clone());
let tensor_6 = tensor_1.clone().add(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
}

View File

@@ -0,0 +1,138 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_mean() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[3.5, 9.5], [3.5, 9.5]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-0.75, -0.75], [3.0, 3.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sum_1() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[14.0, 38.0], [14.0, 38.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-3.0, -3.0], [12.0, 12.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sum_2() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().sum_dim(1);
let tensor_5 = tensor_4.mul(tensor_3);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[690.0, 690.0], [958.0, 958.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_mean_dim() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[4.0, 36.0], [3.0, -17.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[9.0, 9.0], [35.5, 35.5]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sum_dim() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[8.0, 72.0], [6.0, -34.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[18.0, 18.0], [71.0, 71.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,102 @@
use super::*;
use burn_tensor::module::avg_pool1d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool1d_simple() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 1,
kernel_size: 3,
padding: 0,
stride: 1,
length: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]],
&Default::default(),
));
}
#[test]
fn test_avg_pool1d_complex() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 2,
kernel_size: 3,
padding: 1,
stride: 2,
length: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
]],
&Default::default(),
));
}
#[test]
fn test_avg_pool1d_complex_dont_count_pad() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 2,
kernel_size: 3,
padding: 1,
stride: 2,
length: 6,
count_include_pad: false,
};
test.assert_output(TestTensor::from_floats(
[[
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
]],
&Default::default(),
));
}
struct AvgPool1dTestCase {
batch_size: usize,
channels: usize,
kernel_size: usize,
padding: usize,
stride: usize,
length: usize,
count_include_pad: bool,
}
impl AvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = avg_pool1d(
x.clone(),
self.kernel_size,
self.stride,
self.padding,
self.count_include_pad,
false,
);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
x_grad
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), tolerance);
}
}

View File

@@ -0,0 +1,129 @@
use super::*;
use burn_tensor::module::avg_pool2d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool2d_simple() {
let test = AvgPool2dTestCase {
batch_size: 1,
channels: 1,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
height: 6,
width: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[[
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
]]],
&Default::default(),
));
}
#[test]
fn test_avg_pool2d_complex() {
let test = AvgPool2dTestCase {
batch_size: 1,
channels: 1,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
height: 4,
width: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[[
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
]]],
&Default::default(),
));
}
#[test]
fn test_avg_pool2d_complex_dont_include_pad() {
let test = AvgPool2dTestCase {
batch_size: 1,
channels: 1,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
height: 4,
width: 6,
count_include_pad: false,
};
test.assert_output(TestTensor::from_floats(
[[[
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
]]],
&Default::default(),
));
}
struct AvgPool2dTestCase {
batch_size: usize,
channels: usize,
kernel_size_1: usize,
kernel_size_2: usize,
padding_1: usize,
padding_2: usize,
stride_1: usize,
stride_2: usize,
height: usize,
width: usize,
count_include_pad: bool,
}
impl AvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = avg_pool2d(
x.clone(),
[self.kernel_size_1, self.kernel_size_2],
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
self.count_include_pad,
false,
);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad.to_data().assert_approx_eq::<FloatElem>(
&x_grad_actual.into_data(),
Tolerance::default().set_half_precision_relative(1e-3),
);
}
}

View File

@@ -0,0 +1,24 @@
use super::*;
use burn_tensor::{Int, Tensor, TensorData, module::embedding};
#[test]
fn test_embedding_backward() {
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indices = TensorData::from([[0, 1], [1, 1]]);
let x = TensorData::from([
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
]);
let device = Default::default();
let weights = Tensor::<TestAutodiffBackend, 2>::from_data(weights, &device).require_grad();
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(indices, &device);
let x = Tensor::<TestAutodiffBackend, 3>::from_data(x, &device).require_grad();
let output = embedding(weights.clone(), indices);
let output = output.matmul(x);
let grads = output.backward();
let grad = weights.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false);
}

View File

@@ -0,0 +1,27 @@
use super::*;
use burn_tensor::{DType, Distribution, Tensor};
#[test]
fn test_full_precision() {
let device = Default::default();
let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
.require_grad();
let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
.require_grad();
let dtype = x1.dtype();
let x3 = x1.clone().cast(DType::F32);
let x4 = x2.clone().cast(DType::F32);
let x5 = x3.matmul(x4);
let x6 = x5.cast(dtype);
let x7 = x6 * x1.clone() / x2.clone();
let grads = x7.backward();
let x1_grad = x1.grad(&grads);
let x2_grad = x2.grad(&grads);
assert!(x1_grad.is_some());
assert!(x2_grad.is_some());
}

View File

@@ -0,0 +1,56 @@
use super::*;
#[test]
fn mul_broadcast() {
test_ops_broadcast_backward(|x, y| x * y);
}
#[test]
fn div_broadcast() {
test_ops_broadcast_backward(|x, y| x / y);
}
#[test]
fn sub_broadcast() {
test_ops_broadcast_backward(|x, y| x - y);
}
#[test]
fn add_broadcast() {
test_ops_broadcast_backward(|x, y| x + y);
}
#[test]
fn matmul_broadcast() {
test_ops_broadcast_backward(|x, y| x.matmul(y));
}
#[test]
fn mask_where_broadcast() {
test_ops_broadcast_backward(|x, y| {
let cond = y.clone().equal_elem(4);
x.mask_where(cond, y)
});
}
fn test_ops_broadcast_backward<F>(func: F)
where
F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
{
let device = Default::default();
let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();
let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();
// Slice isn't a broadcastable operation, so it will fail when the previous backward pass
// of an operation that support broadcast doesn't support it during the backward pass.
let y = func(w.clone().slice([0..1]), x.clone());
// Will panic if broadcast isn't supported!
let grads = y.backward();
let w_grad = w.grad(&grads).unwrap();
let x_grad = x.grad(&grads).unwrap();
assert_eq!(w_grad.shape(), w.shape());
assert_eq!(x_grad.shape(), x.shape());
}

View File

@@ -0,0 +1,28 @@
// Skip on metal - F64 not supported
#![cfg(all(feature = "std", not(feature = "metal")))]
use super::*;
use burn_backend_tests::might_panic;
use burn_tensor::{DType, Tensor, TensorData};
#[might_panic(reason = "Unsupported precision for fusion")]
#[test]
fn cast_keeps_gradient_flow() {
let device = Default::default();
let x = Tensor::<TestAutodiffBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
)
.require_grad();
let y = x.clone().cast(DType::F64);
let z = y.sum();
let grads = z.backward();
let grad_x = x.grad(&grads).unwrap();
grad_x
.to_data()
.assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);
}

View File

@@ -0,0 +1,110 @@
use super::*;
use burn_tensor::Tolerance;
#[test]
fn should_diff_cat() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let mut tensor_1_list = Vec::new();
let mut tensor_2_list = Vec::new();
for i in 0..2 {
tensor_1_list.push(tensor_1.clone().slice([i..i + 1]));
tensor_2_list.push(tensor_2.clone().slice([i..i + 1]));
}
let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0);
let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0);
let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
let grads = tensor_3_cat.backward();
let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]);
let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]);
let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]);
let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]);
grad_1
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_1_slice_1.to_data(), Tolerance::default());
grad_1
.slice([1..2])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_1_slice_2.to_data(), Tolerance::default());
grad_2
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_2_slice_1.to_data(), Tolerance::default());
grad_2
.slice([1..2])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_2_slice_2.to_data(), Tolerance::default());
}
#[test]
fn should_diff_cat_more_than_1_dim() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device)
.require_grad();
// Concat a tensor [2, 2] with another tensor [3, 2] along dim 0.
// The resulting tensor should be [5, 2]
let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0);
assert_eq!(tensor_3.dims(), [5, 2]);
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(tensor_1.dims(), grad_1.dims());
assert_eq!(tensor_2.dims(), grad_2.dims());
}
#[test]
fn should_slice_grads_correctly_when_some_inputs_not_tracked() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad(); // tracked
let tensor_2 = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device); // not tracked
let tensor_3 =
TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad(); // tracked
let cat = TestAutodiffTensor::cat(
vec![tensor_1.clone(), tensor_2.clone(), tensor_3.clone()],
1,
);
// Make gradient per column unique so wrong slicing shows up.
let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device);
let loss = (cat * weights).sum();
let grads = loss.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&burn_tensor::TensorData::from([[1.0]]), false);
grad_3
.to_data()
.assert_eq(&burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]), false);
}

View File

@@ -0,0 +1,21 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_ceil() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = tensor_1.clone().ceil();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}

View File

@@ -0,0 +1,215 @@
use super::*;
use burn_tensor::{Bool, Tensor, TensorData};
#[test]
fn test_autodiff_checkpoint_complicated_computation() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = compute_bound_eager(tensor_0, tensor_1);
let tensor_6 = compute_bound_lazy(tensor_2, tensor_3.clone());
let tensor_7 = memory_bound_eager(tensor_3, tensor_4);
let tensor_8 = compute_bound_lazy(tensor_6, tensor_7.clone());
let tensor_9 = memory_bound_eager_scalar(tensor_7, 11.);
let tensor_10 = memory_bound_lazy(tensor_5, tensor_8.clone());
let tensor_11 = memory_bound_lazy(tensor_8, tensor_9);
let tensor_12 = compute_bound_lazy(tensor_10, tensor_11);
assert_checkpoint(tensor_12);
}
#[test]
fn test_autodiff_checkpoint_with_missing_requirement() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); // does not require_grad
let tensor_2 = memory_bound_eager(tensor_0, tensor_1);
let tensor_3 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
let tensor_4 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
let tensor_5 = compute_bound_lazy(tensor_3, tensor_4);
let tensor_6 = compute_bound_eager_scalar(tensor_5.clone(), 11.);
let tensor_7 = memory_bound_eager(tensor_5, tensor_2);
let tensor_8 = memory_bound_eager(tensor_6, tensor_7);
assert_checkpoint(tensor_8);
}
#[test]
fn test_autodiff_checkpoint_with_many_duplicates() {
let data_0 = TensorData::from([[4.0, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = memory_bound_eager(tensor_0.clone(), tensor_0.clone());
let tensor_2 = compute_bound_eager(tensor_0.clone(), tensor_0.clone());
let tensor_3 = memory_bound_lazy(tensor_0.clone(), tensor_0.clone());
let tensor_4 = compute_bound_lazy(tensor_0.clone(), tensor_0.clone());
let tensor_5 = memory_bound_eager(tensor_1.clone(), tensor_0.clone());
let tensor_6 = memory_bound_eager(tensor_0.clone(), tensor_5.clone());
let tensor_7 = compute_bound_lazy(tensor_3.clone(), tensor_5.clone());
let tensor_8 = compute_bound_eager(tensor_4.clone(), tensor_2.clone());
let tensor_9 = memory_bound_lazy(tensor_6, tensor_7);
let tensor_10 = memory_bound_eager(tensor_0, tensor_9);
let tensor_11 = memory_bound_eager_scalar(tensor_10, 9.);
let tensor_12 = compute_bound_lazy(tensor_8, tensor_11);
assert_checkpoint(tensor_12);
}
#[test]
fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = memory_bound_eager(tensor_0, tensor_1.clone());
let tensor_6 = memory_bound_eager(tensor_5, tensor_2);
let tensor_7 = memory_bound_eager(tensor_6, tensor_3);
let tensor_8 = memory_bound_eager(tensor_7, tensor_4);
let tensor_9 = memory_bound_lazy(tensor_8, tensor_1);
assert_checkpoint(tensor_9)
}
#[test]
fn test_autodiff_checkpoint_half_sub_graph_not_tracked() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let data_5 = TensorData::from([[0.5, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device);
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device);
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = TestAutodiffTensor::from_data(data_5, &device).require_grad();
let tensor_6 = memory_bound_lazy(tensor_0, tensor_1);
let tensor_7 = compute_bound_eager(tensor_6, tensor_2);
let tensor_8 = memory_bound_eager(tensor_3, tensor_4);
let tensor_9 = compute_bound_lazy(tensor_8, tensor_5);
let tensor_10 = compute_bound_lazy(tensor_7, tensor_9);
assert_checkpoint(tensor_10);
}
#[test]
fn test_autodiff_checkpoint_very_complex() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = memory_bound_eager_scalar(tensor_0, 8.);
let tensor_6 = memory_bound_lazy(tensor_5.clone(), tensor_1.clone());
let tensor_7 = compute_bound_lazy(tensor_6.clone(), tensor_6);
let tensor_8 = memory_bound_lazy(tensor_1.clone(), tensor_5.clone());
let tensor_9 = memory_bound_eager_scalar(tensor_7.clone(), 7.);
let tensor_10 = compute_bound_eager(tensor_5, tensor_8);
let tensor_11 = memory_bound_eager(tensor_2.clone(), tensor_9);
let tensor_12 = memory_bound_lazy(tensor_2.clone(), tensor_2);
let tensor_13 = compute_bound_eager(tensor_10.clone(), tensor_11);
let tensor_14 = compute_bound_eager_scalar(tensor_3, 8.);
let tensor_15 = compute_bound_lazy(tensor_4, tensor_12);
let tensor_16 = memory_bound_lazy(tensor_10, tensor_7);
let tensor_17 = compute_bound_lazy(tensor_13, tensor_1);
let tensor_18 = memory_bound_eager(tensor_15, tensor_16);
let tensor_19 = compute_bound_eager(tensor_14, tensor_17);
let tensor_20 = memory_bound_lazy(tensor_18, tensor_19);
let tensor_21 = memory_bound_eager_scalar(tensor_20, 8.);
assert_checkpoint(tensor_21)
}
fn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {
// Assert is not explicit here, but the test can fail
// - when a tensor is actually required more than n_required, it won't be found and will panic
// - when a tensor is actually required less than n_required, the backward states map won't be
// empty and will fail the assertion within the backward code, same for retro_forwards
tensor.backward();
}
// Does not save its state and does not need its parents
fn memory_bound_eager<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.add(tensor_b)
}
fn memory_bound_eager_scalar<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
b: f32,
) -> TestAutodiffTensor<D> {
tensor_a.add_scalar(b)
}
// Saves its own state and does not need its parents
fn compute_bound_eager<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
tensor_a.mask_where(mask, tensor_b)
}
fn compute_bound_eager_scalar<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
b: f32,
) -> TestAutodiffTensor<D> {
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
tensor_a.mask_fill(mask, b)
}
// Does not save its state and needs its parents
fn memory_bound_lazy<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.mul(tensor_b)
}
// Saves its own state and needs its parents
fn compute_bound_lazy<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.matmul(tensor_b)
}

View File

@@ -0,0 +1,81 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_full_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
let tensor_5 = tensor_4.mul(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[593., 463.0], [487.0, 539.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[734.0, 294.0], [1414.0, 242.0]]), false);
}
#[test]
fn should_diff_full_complex_2() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[166.0, 110.0], [212.0, 156.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[113.0, 141.0], [33.0, 41.0]]), false);
}
#[test]
fn should_diff_full_complex_3() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
let tensor_5 = tensor_4.clone().sub(tensor_2.clone());
let tensor_6 = tensor_5.add(tensor_4);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[332.0, 220.0], [424.0, 312.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[223.0, 279.0], [63.0, 79.0]]), false);
}

View File

@@ -0,0 +1,277 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};
#[test]
fn test_conv1d_basic() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
],
&device,
),
bias: TestTensor::from_floats([8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_different_channels() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 3,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
],
&device,
),
bias: TestTensor::from_floats([8., 8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_with_padding() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 2,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [76., 76., 76.]],
[[44., 44., 44.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([12., 12.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_with_stride() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 2,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[10., 20., 24.], [18., 36., 40.]],
[[10., 20., 24.], [18., 36., 40.]],
],
&device,
),
bias: TestTensor::from_floats([4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_dilation() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 2,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[8., 22., 14.], [16., 38., 22.]],
[[8., 22., 14.], [16., 38., 22.]],
],
&device,
),
bias: TestTensor::from_floats([4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_groups() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 2,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
],
&device,
),
weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device),
bias: TestTensor::from_floats([8., 8.], &device),
};
test.assert_grads(grads);
}
struct Conv1dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size: usize,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
length: usize,
}
struct Grads {
x: TestTensor<3>,
weight: TestTensor<3>,
bias: TestTensor<1>,
}
impl Conv1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<3, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,962 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions};
#[test]
fn test_conv2d_basic() {
let test = Conv2dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[88., 138., 138., 96.],
[150., 234., 234., 162.],
[150., 234., 234., 162.],
[112., 174., 174., 120.],
],
[
[160., 246., 246., 168.],
[258., 396., 396., 270.],
[258., 396., 396., 270.],
[184., 282., 282., 192.],
],
],
[
[
[88., 138., 138., 96.],
[150., 234., 234., 162.],
[150., 234., 234., 162.],
[112., 174., 174., 120.],
],
[
[160., 246., 246., 168.],
[258., 396., 396., 270.],
[258., 396., 396., 270.],
[184., 282., 282., 192.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
],
&device,
),
bias: TestTensor::from_floats([32., 32.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_channels() {
let test = Conv2dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[240., 369., 369., 252.],
[387., 594., 594., 405.],
[387., 594., 594., 405.],
[276., 423., 423., 288.],
],
[
[348., 531., 531., 360.],
[549., 837., 837., 567.],
[549., 837., 837., 567.],
[384., 585., 585., 396.],
],
],
[
[
[240., 369., 369., 252.],
[387., 594., 594., 405.],
[387., 594., 594., 405.],
[276., 423., 423., 288.],
],
[
[348., 531., 531., 360.],
[549., 837., 837., 567.],
[549., 837., 837., 567.],
[384., 585., 585., 396.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
],
&device,
),
bias: TestTensor::from_floats([32., 32., 32.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_kernel_size() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[116., 180., 192., 132.],
[198., 306., 324., 222.],
[198., 306., 324., 222.],
[148., 228., 240., 164.],
],
[
[212., 324., 336., 228.],
[342., 522., 540., 366.],
[342., 522., 540., 366.],
[244., 372., 384., 260.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[27., 45., 54., 39.],
[52., 84., 96., 68.],
[51., 81., 90., 63.],
],
[
[123., 189., 198., 135.],
[180., 276., 288., 196.],
[147., 225., 234., 159.],
],
],
[
[
[27., 45., 54., 39.],
[52., 84., 96., 68.],
[51., 81., 90., 63.],
],
[
[123., 189., 198., 135.],
[180., 276., 288., 196.],
[147., 225., 234., 159.],
],
],
],
&device,
),
bias: TestTensor::from_floats([12., 12.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_padding() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[138., 138., 138., 138.],
[234., 234., 234., 234.],
[234., 234., 234., 234.],
[174., 174., 174., 174.],
],
[
[246., 246., 246., 246.],
[396., 396., 396., 396.],
[396., 396., 396., 396.],
[282., 282., 282., 282.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
],
[
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
],
],
&device,
),
bias: TestTensor::from_floats([24., 24.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_width() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 5,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[88., 138., 138., 138., 96.],
[150., 234., 234., 234., 162.],
[150., 234., 234., 234., 162.],
[112., 174., 174., 174., 120.],
],
[
[160., 246., 246., 246., 168.],
[258., 396., 396., 396., 270.],
[258., 396., 396., 396., 270.],
[184., 282., 282., 282., 192.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
],
[
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
],
],
&device,
),
bias: TestTensor::from_floats([20., 20.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_stride_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 6,
width: 6,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[26., 52., 26., 52., 26., 28.],
[52., 104., 52., 104., 52., 56.],
[26., 52., 26., 52., 26., 28.],
[52., 104., 52., 104., 52., 56.],
[26., 52., 26., 52., 26., 28.],
[32., 64., 32., 64., 32., 34.],
],
[
[44., 88., 44., 88., 44., 46.],
[88., 176., 88., 176., 88., 92.],
[44., 88., 44., 88., 44., 46.],
[88., 176., 88., 176., 88., 92.],
[44., 88., 44., 88., 44., 46.],
[50., 100., 50., 100., 50., 52.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
],
[
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
],
],
&device,
),
bias: TestTensor::from_floats([9., 9.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_stride() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 3,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 8,
width: 8,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
[38., 60., 60., 60., 60., 60., 60., 42.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
[38., 60., 60., 60., 60., 60., 60., 42.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
],
[
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
[74., 114., 114., 114., 114., 114., 114., 78.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
[74., 114., 114., 114., 114., 114., 114., 78.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
[
[1330., 1528., 1344.],
[1911., 2196., 1932.],
[2079., 2388., 2100.],
],
],
[
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
[
[1330., 1528., 1344.],
[1911., 2196., 1932.],
[2079., 2388., 2100.],
],
],
],
&device,
),
bias: TestTensor::from_floats([24., 24.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_dilation_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 2,
groups: 1,
height: 6,
width: 6,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[18., 38., 38., 42., 42., 22.],
[42., 88., 88., 96., 96., 50.],
[42., 88., 88., 96., 96., 50.],
[54., 112., 112., 120., 120., 62.],
[54., 112., 112., 120., 120., 62.],
[30., 62., 62., 66., 66., 34.],
],
[
[36., 74., 74., 78., 78., 40.],
[78., 160., 160., 168., 168., 86.],
[78., 160., 160., 168., 168., 86.],
[90., 184., 184., 192., 192., 98.],
[90., 184., 184., 192., 192., 98.],
[48., 98., 98., 102., 102., 52.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
],
[
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
],
],
&device,
),
bias: TestTensor::from_floats([16., 16.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_dilation() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 3,
groups: 1,
height: 6,
width: 6,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[18., 0., 20., 20., 0., 22.],
[42., 0., 46., 46., 0., 50.],
[42., 0., 46., 46., 0., 50.],
[54., 0., 58., 58., 0., 62.],
[54., 0., 58., 58., 0., 62.],
[30., 0., 32., 32., 0., 34.],
],
[
[36., 0., 38., 38., 0., 40.],
[78., 0., 82., 82., 0., 86.],
[78., 0., 82., 82., 0., 86.],
[90., 0., 94., 94., 0., 98.],
[90., 0., 94., 94., 0., 98.],
[48., 0., 50., 50., 0., 52.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
],
[
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
],
],
&device,
),
bias: TestTensor::from_floats([8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 5,
width: 5,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[0., 1., 3., 3., 2.],
[3., 8., 15., 12., 7.],
[9., 21., 36., 27., 15.],
[9., 20., 33., 24., 13.],
[6., 13., 21., 15., 8.],
],
[
[9., 19., 30., 21., 11.],
[21., 44., 69., 48., 25.],
[36., 75., 117., 81., 42.],
[27., 56., 87., 60., 31.],
[15., 31., 48., 33., 17.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],
[[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],
],
&device,
),
bias: TestTensor::from_floats([9., 9.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_stride_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 4,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[4., 8., 4., 5.],
[8., 16., 8., 10.],
[4., 8., 4., 5.],
[7., 14., 7., 8.],
],
[
[13., 26., 13., 14.],
[26., 52., 26., 28.],
[13., 26., 13., 14.],
[16., 32., 16., 17.],
],
[
[22., 44., 22., 23.],
[44., 88., 44., 46.],
[22., 44., 22., 23.],
[25., 50., 25., 26.],
],
[
[31., 62., 31., 32.],
[62., 124., 62., 64.],
[31., 62., 31., 32.],
[34., 68., 34., 35.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
[[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
[[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
[[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
],
&device,
),
bias: TestTensor::from_floats([4., 4., 4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_different_channels() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 3,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 3,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[9., 20., 24., 13.],
[24., 52., 60., 32.],
[36., 76., 84., 44.],
[21., 44., 48., 25.],
],
[
[45., 92., 96., 49.],
[96., 196., 204., 104.],
[108., 220., 228., 116.],
[57., 116., 120., 61.],
],
[
[81., 164., 168., 85.],
[168., 340., 348., 176.],
[180., 364., 372., 188.],
[93., 188., 192., 97.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
],
&device,
),
bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_complex() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 2,
kernel_size_2: 3,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
dilation_1: 2,
dilation_2: 3,
groups: 1,
height: 4,
width: 5,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[36., 39., 0., 39., 42.],
[81., 87., 0., 87., 93.],
[81., 87., 0., 87., 93.],
[45., 48., 0., 48., 51.],
],
[
[54., 57., 0., 57., 60.],
[117., 123., 0., 123., 129.],
[117., 123., 0., 123., 129.],
[63., 66., 0., 66., 69.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
],
&device,
),
bias: TestTensor::from_floats([8., 8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_stride_2_no_pad() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[0., 1., 2., 0.],
[3., 4., 5., 0.],
[6., 7., 8., 0.],
[0., 0., 0., 0.],
],
[
[9., 10., 11., 0.],
[12., 13., 14., 0.],
[15., 16., 17., 0.],
[0., 0., 0., 0.],
],
[
[18., 19., 20., 0.],
[21., 22., 23., 0.],
[24., 25., 26., 0.],
[0., 0., 0., 0.],
],
[
[27., 28., 29., 0.],
[30., 31., 32., 0.],
[33., 34., 35., 0.],
[0., 0., 0., 0.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
],
[
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
[[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]],
],
],
&device,
),
bias: TestTensor::from_floats([1., 1.], &device),
};
test.assert_grads(grads);
}
struct Conv2dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size_1: usize,
kernel_size_2: usize,
padding_1: usize,
padding_2: usize,
stride_1: usize,
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}
struct Grads {
x: TestTensor<4>,
weight: TestTensor<4>,
bias: TestTensor<1>,
}
impl Conv2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::rel_abs(0.01, 0.01);
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,690 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
#[test]
fn test_conv3d_basic() {
let test = Conv3dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
kernel_size_3: 3,
padding_1: 1,
padding_2: 1,
padding_3: 1,
stride_1: 1,
stride_2: 1,
stride_3: 1,
dilation_1: 1,
dilation_2: 1,
dilation_3: 1,
groups: 1,
depth: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[
[536., 816., 816., 552.],
[840., 1278., 1278., 864.],
[840., 1278., 1278., 864.],
[584., 888., 888., 600.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[680., 1032., 1032., 696.],
[1056., 1602., 1602., 1080.],
[1056., 1602., 1602., 1080.],
[728., 1104., 1104., 744.],
],
],
[
[
[968., 1464., 1464., 984.],
[1488., 2250., 2250., 1512.],
[1488., 2250., 2250., 1512.],
[1016., 1536., 1536., 1032.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1112., 1680., 1680., 1128.],
[1704., 2574., 2574., 1728.],
[1704., 2574., 2574., 1728.],
[1160., 1752., 1752., 1176.],
],
],
],
[
[
[
[536., 816., 816., 552.],
[840., 1278., 1278., 864.],
[840., 1278., 1278., 864.],
[584., 888., 888., 600.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[680., 1032., 1032., 696.],
[1056., 1602., 1602., 1080.],
[1056., 1602., 1602., 1080.],
[728., 1104., 1104., 744.],
],
],
[
[
[968., 1464., 1464., 984.],
[1488., 2250., 2250., 1512.],
[1488., 2250., 2250., 1512.],
[1016., 1536., 1536., 1032.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1112., 1680., 1680., 1128.],
[1704., 2574., 2574., 1728.],
[1704., 2574., 2574., 1728.],
[1160., 1752., 1752., 1176.],
],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[4590., 6156., 4644.],
[6264., 8400., 6336.],
[4806., 6444., 4860.],
],
[
[6696., 8976., 6768.],
[9120., 12224., 9216.],
[6984., 9360., 7056.],
],
[
[5454., 7308., 5508.],
[7416., 9936., 7488.],
[5670., 7596., 5724.],
],
],
[
[
[8046., 10764., 8100.],
[10872., 14544., 10944.],
[8262., 11052., 8316.],
],
[
[11304., 15120., 11376.],
[15264., 20416., 15360.],
[11592., 15504., 11664.],
],
[
[8910., 11916., 8964.],
[12024., 16080., 12096.],
[9126., 12204., 9180.],
],
],
],
[
[
[
[4590., 6156., 4644.],
[6264., 8400., 6336.],
[4806., 6444., 4860.],
],
[
[6696., 8976., 6768.],
[9120., 12224., 9216.],
[6984., 9360., 7056.],
],
[
[5454., 7308., 5508.],
[7416., 9936., 7488.],
[5670., 7596., 5724.],
],
],
[
[
[8046., 10764., 8100.],
[10872., 14544., 10944.],
[8262., 11052., 8316.],
],
[
[11304., 15120., 11376.],
[15264., 20416., 15360.],
[11592., 15504., 11664.],
],
[
[8910., 11916., 8964.],
[12024., 16080., 12096.],
[9126., 12204., 9180.],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([128., 128.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv3d_complex() {
let test = Conv3dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 2,
kernel_size_2: 3,
kernel_size_3: 4,
padding_1: 1,
padding_2: 2,
padding_3: 3,
stride_1: 1,
stride_2: 2,
stride_3: 3,
dilation_1: 2,
dilation_2: 3,
dilation_3: 4,
groups: 1,
depth: 5,
height: 6,
width: 7,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[0., 147., 0., 0., 0., 150., 0.],
[0., 159., 0., 0., 0., 162., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 159., 0., 0., 0., 162., 0.],
[0., 171., 0., 0., 0., 174., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 183., 0., 0., 0., 186., 0.],
[0., 195., 0., 0., 0., 198., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 195., 0., 0., 0., 198., 0.],
[0., 207., 0., 0., 0., 210., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
],
[
[
[0., 219., 0., 0., 0., 222., 0.],
[0., 231., 0., 0., 0., 234., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 231., 0., 0., 0., 234., 0.],
[0., 243., 0., 0., 0., 246., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 255., 0., 0., 0., 258., 0.],
[0., 267., 0., 0., 0., 270., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 267., 0., 0., 0., 270., 0.],
[0., 279., 0., 0., 0., 282., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([10., 10., 10.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv3d_groups_stride_2_no_pad() {
let test = Conv3dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
kernel_size_3: 3,
padding_1: 0,
padding_2: 0,
padding_3: 0,
stride_1: 2,
stride_2: 2,
stride_3: 2,
dilation_1: 1,
dilation_2: 1,
dilation_3: 1,
groups: 2,
depth: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[0., 1., 2., 0.],
[3., 4., 5., 0.],
[6., 7., 8., 0.],
[0., 0., 0., 0.],
],
[
[9., 10., 11., 0.],
[12., 13., 14., 0.],
[15., 16., 17., 0.],
[0., 0., 0., 0.],
],
[
[18., 19., 20., 0.],
[21., 22., 23., 0.],
[24., 25., 26., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[27., 28., 29., 0.],
[30., 31., 32., 0.],
[33., 34., 35., 0.],
[0., 0., 0., 0.],
],
[
[36., 37., 38., 0.],
[39., 40., 41., 0.],
[42., 43., 44., 0.],
[0., 0., 0., 0.],
],
[
[45., 46., 47., 0.],
[48., 49., 50., 0.],
[51., 52., 53., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[54., 55., 56., 0.],
[57., 58., 59., 0.],
[60., 61., 62., 0.],
[0., 0., 0., 0.],
],
[
[63., 64., 65., 0.],
[66., 67., 68., 0.],
[69., 70., 71., 0.],
[0., 0., 0., 0.],
],
[
[72., 73., 74., 0.],
[75., 76., 77., 0.],
[78., 79., 80., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[81., 82., 83., 0.],
[84., 85., 86., 0.],
[87., 88., 89., 0.],
[0., 0., 0., 0.],
],
[
[90., 91., 92., 0.],
[93., 94., 95., 0.],
[96., 97., 98., 0.],
[0., 0., 0., 0.],
],
[
[99., 100., 101., 0.],
[102., 103., 104., 0.],
[105., 106., 107., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
],
[
[[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
[[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
[[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
],
],
[
[
[[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
[[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
[[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
],
[
[[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
[[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
[[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
],
],
],
&device,
),
bias: TestTensor::from_floats([1., 1.], &device),
};
test.assert_grads(grads);
}
struct Conv3dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size_1: usize,
kernel_size_2: usize,
kernel_size_3: usize,
padding_1: usize,
padding_2: usize,
padding_3: usize,
stride_1: usize,
stride_2: usize,
stride_3: usize,
dilation_1: usize,
dilation_2: usize,
dilation_3: usize,
groups: usize,
depth: usize,
height: usize,
width: usize,
}
struct Grads {
x: TestTensor<5>,
weight: TestTensor<5>,
bias: TestTensor<1>,
}
impl Conv3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels_in,
self.depth,
self.height,
self.width,
]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
self.kernel_size_3,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2, self.stride_3],
[self.padding_1, self.padding_2, self.padding_3],
[self.dilation_1, self.dilation_2, self.dilation_3],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,292 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose1d_basic() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 0,
stride: 1,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],
[[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]],
[[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]],
],
&device,
),
bias: TestTensor::from_floats([12., 12.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_padding() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 2,
padding_out: 0,
stride: 1,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[7., 12., 8., 3.], [19., 36., 32., 15.]],
[[7., 12., 8., 3.], [19., 36., 32., 15.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[26., 22., 18.], [26., 22., 18.]],
[[42., 38., 34.], [42., 38., 34.]],
],
&device,
),
bias: TestTensor::from_floats([4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_stride() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 0,
stride: 2,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [44., 44., 44.]],
[[76., 76., 76.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([18., 18.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_stride_padding_out() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 1,
stride: 2,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [44., 44., 44.]],
[[76., 76., 76.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([20., 20.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_dilation() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 0,
stride: 1,
dilation: 2,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [44., 44., 44.]],
[[76., 76., 76.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([16., 16.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_complex() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 4],
kernel_size: 3,
padding: 1,
padding_out: 1,
stride: 2,
dilation: 2,
groups: 2,
size: 8,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
[36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
],
[
[12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
[36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]],
[[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]],
],
&device,
),
bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0], &device),
};
test.assert_grads(grads);
}
struct ConvTranspose1dTestCase {
batch_size: usize,
channels: [usize; 2],
kernel_size: usize,
padding: usize,
padding_out: usize,
stride: usize,
dilation: usize,
groups: usize,
size: usize,
}
struct Grads {
x: TestTensor<3>,
weight: TestTensor<3>,
bias: TestTensor<1>,
}
impl ConvTranspose1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<3, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv_transpose1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
[self.stride],
[self.padding],
[self.padding_out],
[self.dilation],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), Tolerance::default());
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,706 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose2d_basic() {
let test = ConvTranspose2dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
],
[
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
],
],
[
[
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
],
[
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
[[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
],
[
[
[1264., 1264., 1264.],
[1264., 1264., 1264.],
[1264., 1264., 1264.],
],
[
[1264., 1264., 1264.],
[1264., 1264., 1264.],
[1264., 1264., 1264.],
],
],
],
&device,
),
bias: TestTensor::from_floats([72., 72.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_padding() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [1, 2],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[13., 24., 20., 9.],
[15., 27., 21., 9.],
[15., 27., 21., 9.],
[7., 12., 8., 3.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]],
&device,
),
bias: TestTensor::from_floats([8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_stride() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [2, 3],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
&device,
),
bias: TestTensor::from_floats([108.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_stride_padding_out() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [1, 2],
stride: [2, 3],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
&device,
),
bias: TestTensor::from_floats([140.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_dilation() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [2, 3],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
&device,
),
bias: TestTensor::from_floats([80.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_channels() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [2, 3],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[351., 351., 351., 351.],
[351., 351., 351., 351.],
[351., 351., 351., 351.],
[351., 351., 351., 351.],
],
[
[1080., 1080., 1080., 1080.],
[1080., 1080., 1080., 1080.],
[1080., 1080., 1080., 1080.],
[1080., 1080., 1080., 1080.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
],
[
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
],
],
&device,
),
bias: TestTensor::from_floats([36., 36., 36.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_kernel_size() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 5],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [6, 6],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[
[630., 630., 630., 630., 630.],
[630., 630., 630., 630., 630.],
[630., 630., 630., 630., 630.],
]]],
&device,
),
bias: TestTensor::from_floats([80.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_groups() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [2, 2],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 2,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
],
[
[117., 117., 117., 117.],
[117., 117., 117., 117.],
[117., 117., 117., 117.],
[117., 117., 117., 117.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]],
[[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]],
],
&device,
),
bias: TestTensor::from_floats([36., 36.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_complex_no_groups() {
let test = ConvTranspose2dTestCase {
batch_size: 2,
channels: [2, 3],
kernel_size: [3, 5],
padding: [1, 2],
padding_out: [1, 2],
stride: [2, 3],
dilation: [2, 3],
groups: 1,
size: [6, 8],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[600., 735., 735., 735., 735., 735., 735., 735.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
],
[
[1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
],
],
[
[
[600., 735., 735., 735., 735., 735., 735., 735.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
],
[
[1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[5320., 6040., 6040., 6040., 6040.],
[6048., 6864., 6864., 6864., 6864.],
[6048., 6864., 6864., 6864., 6864.],
],
[
[5320., 6040., 6040., 6040., 6040.],
[6048., 6864., 6864., 6864., 6864.],
[6048., 6864., 6864., 6864., 6864.],
],
[
[5320., 6040., 6040., 6040., 6040.],
[6048., 6864., 6864., 6864., 6864.],
[6048., 6864., 6864., 6864., 6864.],
],
],
[
[
[8680., 9880., 9880., 9880., 9880.],
[10080., 11472., 11472., 11472., 11472.],
[10080., 11472., 11472., 11472., 11472.],
],
[
[8680., 9880., 9880., 9880., 9880.],
[10080., 11472., 11472., 11472., 11472.],
[10080., 11472., 11472., 11472., 11472.],
],
[
[8680., 9880., 9880., 9880., 9880.],
[10080., 11472., 11472., 11472., 11472.],
[10080., 11472., 11472., 11472., 11472.],
],
],
],
&device,
),
bias: TestTensor::from_floats([896., 896., 896.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_complex_no_groups_2() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [4, 2],
kernel_size: [2, 3],
padding: [1, 2],
padding_out: [1, 2],
stride: [2, 3],
dilation: [1, 2],
groups: 1,
size: [10, 10],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[30., 42., 42., 42., 42., 42., 42., 42., 42., 42.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
],
[
[78., 114., 114., 114., 114., 114., 114., 114., 114., 114.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
],
[
[126., 186., 186., 186., 186., 186., 186., 186., 186., 186.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
],
[
[174., 258., 258., 258., 258., 258., 258., 258., 258., 258.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[4455., 4905., 4905.], [4500., 4950., 4950.]],
[[4455., 4905., 4905.], [4500., 4950., 4950.]],
],
[
[[12555., 13905., 13905.], [13500., 14950., 14950.]],
[[12555., 13905., 13905.], [13500., 14950., 14950.]],
],
[
[[20655., 22905., 22905.], [22500., 24950., 24950.]],
[[20655., 22905., 22905.], [22500., 24950., 24950.]],
],
[
[[28755., 31905., 31905.], [31500., 34950., 34950.]],
[[28755., 31905., 31905.], [31500., 34950., 34950.]],
],
],
&device,
),
bias: TestTensor::from_floats([570., 570.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_complex_groups() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [4, 2],
kernel_size: [2, 3],
padding: [1, 2],
padding_out: [1, 2],
stride: [2, 3],
dilation: [1, 2],
groups: 2,
size: [10, 10],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[9., 12., 12., 12., 12., 12., 12., 12., 12., 12.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
],
[
[21., 30., 30., 30., 30., 30., 30., 30., 30., 30.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
],
[
[33., 48., 48., 48., 48., 48., 48., 48., 48., 48.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
],
[
[45., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[4455., 4905., 4905.], [4500., 4950., 4950.]]],
[[[12555., 13905., 13905.], [13500., 14950., 14950.]]],
[[[20655., 22905., 22905.], [22500., 24950., 24950.]]],
[[[28755., 31905., 31905.], [31500., 34950., 34950.]]],
],
&device,
),
bias: TestTensor::from_floats([570., 570.], &device),
};
test.assert_grads(grads);
}
struct ConvTranspose2dTestCase {
batch_size: usize,
channels: [usize; 2],
kernel_size: [usize; 2],
padding: [usize; 2],
padding_out: [usize; 2],
stride: [usize; 2],
dilation: [usize; 2],
groups: usize,
size: [usize; 2],
}
struct Grads {
x: TestTensor<4>,
weight: TestTensor<4>,
bias: TestTensor<1>,
}
impl ConvTranspose2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels[0],
self.size[0],
self.size[1],
]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv_transpose2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,711 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose3d_basic() {
let test = ConvTranspose3dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: [3, 3, 3],
padding: [0, 0, 0],
padding_out: [0, 0, 0],
stride: [1, 1, 1],
dilation: [1, 1, 1],
groups: 1,
size: [4, 4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
],
[
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
],
],
[
[
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
],
[
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
],
[
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
],
],
[
[
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
],
[
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([432., 432.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose3d_complex_groups() {
let test = ConvTranspose3dTestCase {
batch_size: 1,
channels: [4, 2],
kernel_size: [2, 3, 4],
padding: [1, 2, 3],
padding_out: [1, 2, 3],
stride: [2, 3, 4],
dilation: [1, 2, 3],
groups: 2,
size: [6, 6, 6],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
],
[
[
[2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
],
[
[
[4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
],
[
[
[5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[
[
[18.663193, 22.309027, 22.309027, 22.309027],
[21.875000, 26.145834, 26.145834, 26.145834],
[21.875000, 26.145834, 26.145834, 26.145834],
],
[
[19.270832, 23.020834, 23.020834, 23.020834],
[22.500000, 26.875002, 26.875002, 26.875002],
[22.500000, 26.875002, 26.875002, 26.875002],
],
]],
[[
[
[49.913193, 59.809029, 59.809029, 59.809029],
[59.375000, 71.145836, 71.145836, 71.145836],
[59.375000, 71.145836, 71.145836, 71.145836],
],
[
[56.770836, 68.020836, 68.020836, 68.020836],
[67.500000, 80.875000, 80.875000, 80.875000],
[67.500000, 80.875000, 80.875000, 80.875000],
],
]],
[[
[
[81.163193, 97.309029, 97.309029, 97.309029],
[96.875000, 116.145828, 116.145828, 116.145828],
[96.875000, 116.145828, 116.145828, 116.145828],
],
[
[94.270828, 113.020828, 113.020828, 113.020828],
[112.500000, 134.875000, 134.875000, 134.875000],
[112.500000, 134.875000, 134.875000, 134.875000],
],
]],
[[
[
[112.413200, 134.809021, 134.809021, 134.809021],
[134.375000, 161.145828, 161.145828, 161.145828],
[134.375000, 161.145828, 161.145828, 161.145828],
],
[
[131.770844, 158.020828, 158.020828, 158.020828],
[157.500000, 188.875000, 188.875000, 188.875000],
[157.500000, 188.875000, 188.875000, 188.875000],
],
]],
],
&device,
),
bias: TestTensor::from_floats([5346., 5346.], &device),
};
test.assert_grads(grads);
}
struct ConvTranspose3dTestCase {
batch_size: usize,
channels: [usize; 2],
kernel_size: [usize; 3],
padding: [usize; 3],
padding_out: [usize; 3],
stride: [usize; 3],
dilation: [usize; 3],
groups: usize,
size: [usize; 3],
}
struct Grads {
x: TestTensor<5>,
weight: TestTensor<5>,
bias: TestTensor<1>,
}
impl ConvTranspose3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels[0],
self.size[0],
self.size[1],
self.size[2],
]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
self.kernel_size[2],
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight.clone())
.into_data(),
&device,
)
.div_scalar(shape_weight.num_elements() as f32)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x.clone())
.into_data(),
&device,
)
.div_scalar(shape_x.num_elements() as f32)
.require_grad();
let output = conv_transpose3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,103 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[cfg(feature = "std")]
use burn_backend_tests::might_panic;
#[test]
fn backward_basic() {
let device = Default::default();
let a = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let b = TestAutodiffTensor::<2>::from_data(
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
&device,
)
.require_grad();
// Simple cross product; grad is a vector of ones.
let c = a.clone().cross(b.clone(), 1);
let grads = c.backward();
let a_grad = a.grad(&grads).unwrap().to_data();
let b_grad = b.grad(&grads).unwrap().to_data();
// For a: b×grad_out, where grad_out = [1,1,1]
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
// For b: grad_out×a
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
}
#[test]
fn backward_after_sum() {
let device = Default::default();
let a = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let b = TestAutodiffTensor::<2>::from_data(
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
&device,
)
.require_grad();
// Sum reduces to scalar, but the gradient should be the same.
let c = a.clone().cross(b.clone(), 1).sum();
let grads = c.backward();
let a_grad = a.grad(&grads).unwrap().to_data();
let b_grad = b.grad(&grads).unwrap().to_data();
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
}
#[cfg(feature = "std")]
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
#[test]
fn different_dim() {
// Also check when the cross is along a different dimension (e.g. dim 0).
let device = Default::default();
let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];
let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);
let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);
// Cross along dim 0. Some backends (for example CubeCL) may not support
// cross on non-last dimensions and will intentionally panic with a
// message like "Cross product on non-last dimension not yet implemented".
// In that case we treat the panic as a skipped test for that backend.
let out = a.cross(b.clone(), 0);
// Manually compute cross of each column vector using raw arrays
let expected = [
[
a_raw[1][0] * b_raw[2][0] - a_raw[2][0] * b_raw[1][0],
a_raw[1][1] * b_raw[2][1] - a_raw[2][1] * b_raw[1][1],
a_raw[1][2] * b_raw[2][2] - a_raw[2][2] * b_raw[1][2],
],
[
a_raw[2][0] * b_raw[0][0] - a_raw[0][0] * b_raw[2][0],
a_raw[2][1] * b_raw[0][1] - a_raw[0][1] * b_raw[2][1],
a_raw[2][2] * b_raw[0][2] - a_raw[0][2] * b_raw[2][2],
],
[
a_raw[0][0] * b_raw[1][0] - a_raw[1][0] * b_raw[0][0],
a_raw[0][1] * b_raw[1][1] - a_raw[1][1] * b_raw[0][1],
a_raw[0][2] * b_raw[1][2] - a_raw[1][2] * b_raw[0][2],
],
];
out.to_data()
.assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());
}

View File

@@ -0,0 +1,33 @@
use super::*;
use burn_tensor::{Tensor, TensorData, Tolerance, loss};
#[test]
fn test_cross_entropy_loss_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let data_targets = TensorData::from([[0.8, 0.2], [0.9, 0.1]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_targets =
Tensor::<TestAutodiffBackend, 2>::from_data(data_targets, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
let expected = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,117 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cummax() {
// Simple test to verify cummax gradients work
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 2.0, 0.0]
let expected = TensorData::from([1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_2d() {
// Test 2D cummax gradients
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummax(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_duplicate_values() {
// Test with duplicate maximum values - critical edge case
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 3.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// input: [1.0, 3.0, 3.0, 2.0]
// cummax: [1.0, 3.0, 3.0, 3.0]
// PyTorch reference: [1.0, 1.0, 2.0, 0.0]
// Position 2 gets grad from itself + position 3
let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_all_same() {
// Test with all same values
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0]
// Each position matches cummax, so each gets its own gradient
let expected = TensorData::from([1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_increasing() {
// Test with increasing sequence
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0, 1.0]
// Each position is a new maximum
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_2d_duplicates() {
// Test 2D with duplicate values
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummax(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,117 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cummin() {
// Simple test to verify cummin gradients work
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 2.0, 0.0]
let expected = TensorData::from([1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_2d() {
// Test 2D cummin gradients
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummin(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_duplicate_values() {
// Test with duplicate minimum values - critical edge case
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 2.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// input: [3.0, 2.0, 2.0, 4.0]
// cummin: [3.0, 2.0, 2.0, 2.0]
// PyTorch reference: [1.0, 1.0, 2.0, 0.0]
// Position 2 gets grad from itself + position 3
let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_all_same() {
// Test with all same values
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0]
// Each position matches cummin, so each gets its own gradient
let expected = TensorData::from([1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_decreasing() {
// Test with decreasing sequence
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([5.0, 4.0, 3.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0, 1.0]
// Each position is a new minimum
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_2d_duplicates() {
// Test 2D with duplicate values
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummin(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,132 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cumprod() {
// Simple test to verify cumprod gradients work
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [16.0, 10.0, 6.0]
let expected = TensorData::from([16.0, 10.0, 6.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cumprod_2d() {
// Test 2D cumprod gradients
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cumprod(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]
let expected = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
// TODO: The following tests are currently ignored due to a known limitation
// in the cumprod gradient implementation. The current implementation uses
// division (grad / input), which produces NaN when the input contains zeros.
//
// A proper fix requires implementing a zero-safe algorithm using exclusive
// cumulative products (similar to PyTorch's cumprod_backward or JAX's
// associative_scan approach). This is a non-trivial implementation that
// requires careful handling of cumulative products in both forward and
// reverse directions.
//
// See: https://github.com/tracel-ai/burn/issues/3864
//
// References:
// - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward)
// - JAX PR #2596: Parallel prefix scan implementation
// - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_in_middle() {
// Test cumprod with zero in the middle - edge case for division
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 32.0, 0.0, 0.0]
let expected = TensorData::from([1.0, 32.0, 0.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_start() {
// Test cumprod with zero at the beginning
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([0.0, 2.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [33.0, 0.0, 0.0, 0.0]
let expected = TensorData::from([33.0, 0.0, 0.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_end() {
// Test cumprod with zero at the end
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0, 0.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [16.0, 10.0, 6.0, 24.0]
let expected = TensorData::from([16.0, 10.0, 6.0, 24.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_multiple_zeros() {
// Test cumprod with multiple zeros
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0]
let expected = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,89 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cumsum_dim0() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.cumsum(0);
let tensor_5 = tensor_1.clone().mul(tensor_4);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Expected gradients computed with PyTorch
let expected = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cumsum_dim1() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.cumsum(1);
let tensor_5 = tensor_1.clone().mul(tensor_4);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Expected gradients computed with PyTorch
let expected = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[18.0, 13.0], [71.0, 58.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cumsum_complex() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().cumsum(1);
let tensor_5 = tensor_4.mul(tensor_3);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Expected gradients computed with PyTorch
let expected = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[507.0, 528.0], [704.0, 733.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,105 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_div() {
let data_1 = TensorData::from([1.0, 7.0]);
let data_2 = TensorData::from([4.0, 7.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().div(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([0.25, 0.14285715]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([-0.0625, -0.14285715]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_div_scalar() {
let data = TensorData::from([1.0, 7.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().div_scalar(4.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([0.25, 0.25]), false);
}
#[test]
fn test_div_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().div(tensor_2.clone());
let tensor_5 = tensor_4.div(tensor_3.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
let expected = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]);
grad_3
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn test_div_complex_2() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.div(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_absolute(2e-3);
let expected = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_erf() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[8.0, 8.0], [8.0, 8.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_exp() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default();
let expected = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,39 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_expand() {
// Python code to generate the test case values
// import torch
// x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
// x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
// y = x1.expand(4, 4)
// z = (x2 * y).sum()
// z.backward()
// print("x1", x1.grad)
// print("x2", x2.grad)
let device = Default::default();
let data_1 = TensorData::from([4.0, 7.0, 2.0, 3.0]);
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
let data_2 = TensorData::from([2.0, 4.5, 7.0, 3.0]);
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().expand([4, 4]);
// Use unsqueeze to make tensor_2 have the same shape as tensor_3
let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([8., 18., 28., 12.]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([16., 28., 8., 12.]), false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_flip() {
let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2 = TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<3>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().flip([1, 2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2
grad_2.into_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]),
tolerance,
); // 1x2x3
}

View File

@@ -0,0 +1,21 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_floor() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = tensor_1.clone().floor();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}

View File

@@ -0,0 +1,99 @@
use super::*;
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
#[test]
fn test_gather_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]),
&device,
);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().gather(1, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[94., 150., 187.], [242., 305., 304.]]),
false,
);
}
#[test]
fn test_scatter_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let values = TestAutodiffTensor::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
TensorData::from([[2, 1, 0], [2, 0, 1]]),
&device,
);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1
.clone()
.scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[127., 181., 235.], [226., 316., 406.]]),
false,
);
grad_2
.to_data()
.assert_eq(&TensorData::from([[19., 19., 19.], [64., 64., 64.]]), false);
}
#[test]
fn test_scatter_add_grad_partial_indices() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::from_data(TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]), &device)
.require_grad();
let tensor_2 =
TestAutodiffTensor::from_data(TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), &device)
.require_grad();
let values =
TestAutodiffTensor::from_data(TensorData::from([[4.0, 5.0, 6.0]]), &device).require_grad();
let indices =
Tensor::<TestAutodiffBackend, 2, Int>::from_data(TensorData::from([[2, 1, 0]]), &device);
let tensor_3 = tensor_1.clone().mul(tensor_2);
let tensor_4 = tensor_3
.clone()
.scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[1., 2., 3., 4., 5., 6.]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[1., 1., 1.]]), false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance, activation};
#[test]
fn should_diff_gelu() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad();
let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));
let x = tensor_1.clone().matmul(x);
let grads = x.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
let expected = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,24 @@
use super::*;
use burn_tensor::{Distribution, activation};
#[test]
fn should_update_tensor_when_grad_replace() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad();
let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default, &device);
let x = tensor_1.clone().matmul(activation::gelu(tensor_2));
let mut grads = x.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_1_updated =
TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad();
tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner());
let grad_1_new = tensor_1.grad(&grads).unwrap();
assert_ne!(grad_1_new.to_data(), grad_1.into_data());
assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data());
}

View File

@@ -0,0 +1,30 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_log() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,28 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_log1p() {
let tensor_1 = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();
let tensor_2 = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,19 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{TensorData, activation};
#[test]
fn should_diff_log_sigmoid() {
let data = TensorData::from([[0.8762, -0.1423], [-300., 200.]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = activation::log_sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,65 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Bool, Tensor, TensorData};
#[test]
fn should_diff_mask_fill() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let mask = TensorData::from([[true, false], [false, true]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device);
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.mask_fill(mask, 2.0);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[7.0, 3.0], [4.0, 2.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[2.0, 1.0], [3.0, 7.0]]), false);
}
#[test]
fn should_diff_mask_where() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 =
TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad();
let mask =
Tensor::<TestAutodiffBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.clone().matmul(tensor_3.clone());
let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone());
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[121.8, 55.0], [110.8, 50.0]]);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[27.4, 33.4], [95.0, 115.0]]);
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[15., 18.], [23., 29.]]);
grad_3
.into_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,83 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_matmul() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);
tensor_3
.to_data()
.assert_eq(&TensorData::from([[18.0, 28.0], [14.0, 23.0]]), false);
}
#[test]
fn test_matmul_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[44.0, 20.0], [44.0, 20.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[56.0, 56.0], [16.0, 16.0]]), false);
}
#[test]
fn test_matmul_complex_2() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3.clone());
let tensor_6 = tensor_1.clone().matmul(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[800.0, 792.0], [360.0, 592.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[264., 264.0], [344.0, 344.0]]), false);
}

View File

@@ -0,0 +1,82 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_max_dim() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[50.0, 34.0], [40.0, -10.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[8.0, 10.0], [56.0, 15.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_min_dim() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[10.0, 8.0], [15.0, 56.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_min_dim_3d_dim1() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device).require_grad();
let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
let tensor_4 = tensor_3.min_dim(1);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[[0., -7.], [2., 0.]]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[[0., 7.], [-2., -0.]]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,134 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::module::max_pool1d;
#[test]
fn test_max_pool1d_simple() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]],
&device,
)
.require_grad();
let x_grad_expected =
TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool1d_with_dilation() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 2;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
[[[
0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
0., 0., 1.,
]]],
&device,
);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool1d_complex() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
[[[
0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
1., 1., 1.,
]]],
&device,
);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool1d_complex_with_padding() {
let kernel_size = 4;
let padding = 2;
let stride = 1;
let dilation = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
[[[
1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
1., 1., 3.,
]]],
&device,
);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}

View File

@@ -0,0 +1,271 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::module::max_pool2d;
#[test]
fn test_max_pool2d_simple_1() {
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let padding_1 = 0;
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 2.0],
[0.0, 2.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_simple_2() {
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[1., 3., 0., 2.],
[3., 0., 0., 4.],
[1., 4., 0., 1.],
[2., 0., 3., 1.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_with_dilation() {
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 2;
let dilation_2 = 2;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0., 0., 0., 0.],
[1., 1., 1., 2.],
[0., 4., 4., 0.],
[0., 1., 2., 0.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_complex() {
let kernel_size_1 = 4;
let kernel_size_2 = 2;
let padding_1 = 2;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
[0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
[0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0., 0., 0., 3., 0.],
[4., 0., 2., 1., 0.],
[0., 0., 0., 0., 0.],
[2., 4., 0., 0., 0.],
[0., 0., 0., 0., 2.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_ceil_mode() {
// Test ceil_mode=true with gradient computation
// Using 1x1x6x6 input with kernel 3x3, stride 2x2, padding 0
// Floor mode: output 2x2
// Ceil mode: output 3x3
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let padding_1 = 0;
let padding_2 = 0;
let stride_1 = 2;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
// Input (values 1-36):
let x = TestAutodiffTensor::from_floats(
[[[
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
[19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
[25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
[31.0, 32.0, 33.0, 34.0, 35.0, 36.0],
]]],
&device,
)
.require_grad();
// Expected gradients for ceil_mode output 3x3:
// Output positions and their max value positions:
// (0,0): max at (2,2)=15 -> grad[2,2] += 1
// (0,1): max at (2,4)=17 -> grad[2,4] += 1
// (0,2): max at (2,5)=18 -> grad[2,5] += 1
// (1,0): max at (4,2)=27 -> grad[4,2] += 1
// (1,1): max at (4,4)=29 -> grad[4,4] += 1
// (1,2): max at (4,5)=30 -> grad[4,5] += 1
// (2,0): max at (5,2)=33 -> grad[5,2] += 1
// (2,1): max at (5,4)=35 -> grad[5,4] += 1
// (2,2): max at (5,5)=36 -> grad[5,5] += 1
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 1., 1.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 1., 1.],
[0., 0., 1., 0., 1., 1.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
true,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}

View File

@@ -0,0 +1,290 @@
use super::*;
use burn_tensor::{Tensor, TensorData};
#[test]
fn test_mm_independent_trees() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7.clone() * tensor_8.clone();
let tensor_12 = tensor_9.clone() * tensor_10.clone();
let tensor_13 = tensor_11 * tensor_12;
let _grads = tensor_6.backward();
let grads = tensor_13.backward();
assert!(tensor_7.grad(&grads).is_some());
assert!(tensor_8.grad(&grads).is_some());
assert!(tensor_9.grad(&grads).is_some());
assert!(tensor_10.grad(&grads).is_some());
}
#[test]
#[should_panic]
fn test_mm_crossover_trees_root_unavailable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4.clone() * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = tensor_7.clone() * tensor_8.clone();
let tensor_10 = tensor_4 * tensor_9;
let _grads = tensor_6.backward();
let _grads = tensor_10.backward();
}
#[test]
fn test_mm_crossover_trees_with_referred_subtree() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4.clone() * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = tensor_7.clone() * tensor_8.clone();
let _tensor_10 = tensor_4 * tensor_9.clone();
let _grads = tensor_6.backward();
let _grads = tensor_9.backward();
}
#[test]
fn test_mm_three_crossover_trees_last_still_usable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5.clone();
// Third tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7 * tensor_8;
let tensor_12 = tensor_9 * tensor_10;
let tensor_13 = tensor_11 * tensor_12.clone();
// Second tree (in between)
let _tensor_14 = tensor_5 * tensor_12;
let _grads = tensor_6.backward();
let _grads = tensor_13.backward();
}
#[test]
#[should_panic]
fn test_mm_three_crossover_trees_middle_one_unavailable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5.clone();
// Third tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7 * tensor_8;
let tensor_12 = tensor_9 * tensor_10;
let _tensor_13 = tensor_11 * tensor_12.clone();
// Second tree (in between)
let tensor_14 = tensor_5 * tensor_12;
let _grads = tensor_6.backward();
let _grads = tensor_14.backward();
}
#[test]
fn test_mm_self_referencing_tree() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3.clone();
let tensor_6 = tensor_3 * tensor_5;
let _grads = tensor_6.backward();
}
#[test]
fn test_mm_with_non_impacting_detach() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4.detach() * tensor_3.clone();
let grads = tensor_5.backward();
assert!(tensor_3.grad(&grads).is_some());
}
#[test]
fn test_mm_with_missing_require_grad_after_cleanup() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_none());
assert!(tensor_3.grad(&grads).is_none());
}
#[test]
fn test_mm_with_detach_after_cleanup() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone().detach();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_some());
assert!(tensor_3.grad(&grads).is_none());
}
#[test]
#[should_panic]
fn test_mm_deletables_propagate_well() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_0 * tensor_1;
let tensor_3 = tensor_2.clone().exp();
let _tensor_4 = tensor_3.clone().log();
let _grads = tensor_2.backward();
// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but
// the intermediate tensor_3 as well
let _grads = tensor_3.backward();
}
#[test]
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative
// By repeating it many times it becomes almost impossible that it passes if it shouldn't
for _ in 0..12 {
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().exp();
let tensor_3 = tensor_0.exp();
let _tensor_4 = tensor_3.clone() * tensor_2.clone();
let tensor_5 = tensor_2.exp();
let tensor_6 = tensor_5.exp();
let tensor_7 = tensor_6.exp();
let tensor_8 = tensor_7.exp();
// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8
// which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search
tensor_3.backward();
let grads = tensor_8.backward();
assert!(tensor_1.grad(&grads).is_some());
}
}

View File

@@ -0,0 +1,74 @@
#[allow(unused_imports)] // required for re-included modules
pub use super::*;
mod abs;
mod adaptive_avgpool1d;
mod adaptive_avgpool2d;
mod add;
mod aggregation;
mod avgpool1d;
mod avgpool2d;
mod backward;
mod bridge;
mod broadcast;
mod cast;
mod cat;
mod ceil;
mod checkpoint;
mod complex;
mod conv1d;
mod conv2d;
mod conv3d;
mod conv_transpose1d;
mod conv_transpose2d;
mod conv_transpose3d;
mod cross;
mod cross_entropy;
mod cummax;
mod cummin;
mod cumprod;
mod cumsum;
mod deform_conv2d;
mod div;
mod erf;
mod exp;
mod expand;
mod flip;
mod floor;
mod gather_scatter;
mod gelu;
mod gradients;
mod log;
mod log1p;
mod log_sigmoid;
mod mask;
mod matmul;
mod maxmin;
mod maxpool1d;
mod maxpool2d;
mod memory_management;
mod mul;
mod multithread;
mod nearest_interpolate;
mod neg;
mod nonzero;
mod permute;
mod pow;
mod recip;
mod relu;
mod remainder;
mod repeat_dim;
mod reshape;
mod round;
mod select;
mod sigmoid;
mod sign;
mod slice;
mod slice_assign;
mod softmax;
mod sort;
mod sqrt;
mod sub;
mod transpose;
mod trig;
mod unfold;

View File

@@ -0,0 +1,68 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_mul() {
let data_1 = TensorData::from([1.0, 7.0]);
let data_2 = TensorData::from([4.0, 7.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let _grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_eq(&data_2, false);
tensor_3
.into_data()
.assert_eq(&TensorData::from([4.0, 49.0]), false);
}
#[test]
fn should_diff_mul_scalar() {
let data = TensorData::from([2.0, 5.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().mul_scalar(4.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
tensor_out
.into_data()
.assert_eq(&TensorData::from([8.0, 20.0]), false);
grad.to_data()
.assert_eq(&TensorData::from([4.0, 4.0]), false);
}
#[test]
fn test_mul_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().mul(tensor_2.clone());
let tensor_5 = tensor_4.mul(tensor_3);
let tensor_6 = tensor_1.clone().mul(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[16.0, 196.0], [104.0, -36.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[2.0, 98.0], [338.0, 18.0]]), false);
}

View File

@@ -0,0 +1,88 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_behave_the_same_with_multithread() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let with_move = || {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
// Task 1
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5.clone();
let first_call = move || {
let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned);
tensor_6_1.matmul(tensor_1_cloned)
};
// Task 2
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5;
let second_call = move || {
let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned);
tensor_6_2.matmul(tensor_2_cloned)
};
let tensor_7_1_handle = std::thread::spawn(first_call);
let tensor_7_2_handle = std::thread::spawn(second_call);
let tensor_7_1 = tensor_7_1_handle.join().unwrap();
let tensor_7_2 = tensor_7_2_handle.join().unwrap();
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
(grad_1, grad_2)
};
let without_move = || {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
// Task 1
let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone());
let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone());
// Task 2
let tensor_6_2 = tensor_5.matmul(tensor_1.clone());
let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone());
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
(grad_1, grad_2)
};
let (grad_1, grad_2) = without_move();
let (grad_1_moved, grad_2_moved) = with_move();
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&grad_1_moved.into_data(), Tolerance::default());
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&grad_2_moved.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,97 @@
use super::*;
use burn_tensor::Shape;
use burn_tensor::Tolerance;
use burn_tensor::module::interpolate;
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
#[test]
fn test_upsample_interpolation() {
let test = InterpolateTestCase {
batch_size: 2,
channels: 1,
height: 7,
width: 5,
height_out: 8,
width_out: 7,
};
test.assert_output(TestTensor::from([
[[
[4., 2., 4., 2., 2.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
]],
[[
[4., 2., 4., 2., 2.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
]],
]));
}
#[test]
fn test_downsample_interpolation() {
let test = InterpolateTestCase {
batch_size: 1,
channels: 1,
height: 8,
width: 8,
height_out: 4,
width_out: 6,
};
test.assert_output(TestTensor::from([[[
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
]]]));
}
struct InterpolateTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
height_out: usize,
width_out: usize,
}
impl InterpolateTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device())
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = interpolate(
x.clone(),
[self.height_out, self.width_out],
InterpolateOptions::new(InterpolateMode::Nearest),
);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), Tolerance::permissive());
}
}

View File

@@ -0,0 +1,26 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_neg() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg());
let tensor_4 = tensor_3.neg();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);
}

View File

@@ -0,0 +1,41 @@
use super::*;
use burn_tensor::{Bool, Tensor, TensorData};
#[test]
fn should_diff_nonzero() {
let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let data_2 = TensorData::from([-1.0, 1.0]);
let mask = TensorData::from([[false, true], [true, false]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
// Multi-dimensional tensor indexing isn't really supported yet so the easiest way to do
// this is to flatten the mask and tensor to get proper indexing. Anyway the returned tensor would
// have dimensions different from the input, so this is somewhat equivalent.
let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device).flatten::<1>(0, 1);
let indices = mask.nonzero();
let tensor_3 = tensor_1
.clone()
.flatten::<1>(0, 1)
.select(0, indices[0].clone());
// Vector dot product not supported (only 2D matmuls) so unsqueeze for test purposes
let tensor_4 = tensor_2
.clone()
.unsqueeze_dim::<2>(0)
.matmul(tensor_3.unsqueeze_dim(1));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[0.0, -1.0], [1.0, 0.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([2.0, 3.0]), false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_permute() {
let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2 = TensorData::from([[[1.0, 7.0], [3.2, 2.0], [3.0, 3.0]]]); // 1x3x2
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().permute([0, 2, 1]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2
grad_2.into_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]),
tolerance,
); // 1x3x2
}

View File

@@ -0,0 +1,93 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_powf_scalar() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf_scalar(0.4));
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(2e-3);
let expected = TensorData::from([[68.0, 79.0328], [68.0, 79.0328]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[23.5081, 25.2779], [26.0502, 28.6383]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn should_diff_powf() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([32.0, 14.0]);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([11.09035, 95.34960]);
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([16.0, 49.0]);
tensor_3
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_powf_with_untracked_lhs() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device);
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
let grads = tensor_3.backward();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([11.09035, 95.34960]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_powf_with_untracked_rhs() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device);
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([32.0, 14.0]);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,22 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_recip() {
let data = TensorData::from([2.0, 5.0, 0.4]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().recip();
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
tensor_out
.into_data()
.assert_eq(&TensorData::from([0.5, 0.2, 2.5]), false);
grad.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([-0.25, -0.04, -6.25]),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,27 @@
use super::*;
use burn_tensor::{TensorData, activation};
#[test]
fn should_diff_relu() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::relu(tensor_3);
let tensor_5 = tensor_4.matmul(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[-47.0, 9.0], [-35.0, 15.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[15.0, 13.0], [-2.0, 39.0]]), false);
}

View File

@@ -0,0 +1,41 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_remainder() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(
TensorData::from([
0.9742, 0.3676, 0.0905, 0.8066, 0.7072, 0.7883, 0.6987, 0.1560, 0.7179, 0.7874, 0.9032,
0.1845,
]),
&device,
)
.require_grad();
let tensor_2 = TestAutodiffTensor::<1>::from_data(
TensorData::from([
0.3357, 0.0285, 0.4115, 0.5511, 0.8637, 0.3593, 0.3885, 0.2569, 0.0936, 0.7172, 0.4792,
0.4898,
]),
&device,
)
.require_grad();
let tensor_3 = tensor_1.clone().remainder(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([
-2.0, -12.0, -0.0, -1.0, -0.0, -2.0, -1.0, -0.0, -7.0, -1.0, -1.0, -0.0,
]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,44 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_repeat() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0], [2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_2
.to_data()
.assert_eq(&TensorData::from([[-3.0], [12.0]]), false);
}
#[test]
fn should_diff_repeat_multi_dim() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 2.0], [2.0, 4.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_2
.to_data()
.assert_eq(&TensorData::from([[-3.0, -3.0], [12.0, 12.0]]), false);
}

View File

@@ -0,0 +1,26 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_reshape() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([4.0, 7.0, 2.0, 3.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().reshape([2, 2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([3.0, 3.0, 10.0, 10.0]), false);
}

View File

@@ -0,0 +1,20 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_round() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().round();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}

View File

@@ -0,0 +1,89 @@
use super::*;
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
#[test]
fn test_select_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let indices =
Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().select(0, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.into_data().assert_eq(
&TensorData::from([[109., 148., 187.], [37., 58., 79.]]),
false,
);
}
#[test]
fn test_select_add_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let values = TestAutodiffTensor::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let indices =
Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 =
tensor_1
.clone()
.select_assign(0, indices, values.clone(), IndexingUpdateOp::Add);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
grad_1.into_data().assert_eq(
&TensorData::from([[127., 199., 271.], [172., 244., 316.]]),
false,
);
grad_2
.into_data()
.assert_eq(&TensorData::from([[64., 64., 64.], [19., 19., 19.]]), false);
}
#[test]
fn test_select_add_grad_different_shapes() {
let device = Default::default();
let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);
let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();
let y = Tensor::ones([2, 1], &device).require_grad();
let w = y
.clone()
.select_assign(0, indices, x.clone(), IndexingUpdateOp::Add);
let w = w.matmul(y.clone().transpose());
let grads = w.backward();
let x_grad = x.grad(&grads).unwrap();
let y_grad = y.grad(&grads).unwrap();
x_grad
.into_data()
.assert_eq(&TensorData::from([[2.0]]), false);
y_grad
.into_data()
.assert_eq(&TensorData::from([[5.0], [5.0]]), false);
}

View File

@@ -0,0 +1,35 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{TensorData, activation};
#[test]
fn should_diff_sigmoid() {
let data = TensorData::from([0.8762]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
let tensor_2 = activation::sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([0.207549]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn small_neg_val_should_not_cause_grad_overflow() {
let data = TensorData::from([-90.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
let tensor_2 = activation::sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,42 @@
use super::*;
use burn_tensor::TensorData;
/// Example using the sign function with PyTorch:
// >>> import torch
// >>> # Create a tensor with requires_grad=True
// >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
// >>> # Forward pass: Apply the sign function
// >>> y = torch.sign(x)
// >>> print("Forward pass:")
// Forward pass:
// >>> print("x:", x)
// x: tensor([-2., -1., 0., 1., 2.], requires_grad=True)
// >>> print("y:", y)
// y: tensor([-1., -1., 0., 1., 1.], grad_fn=<SignBackward0>)
// >>> # Compute the loss (just an example)
// >>> loss = y.sum()
// >>> # Backward pass: Compute the gradients
// >>> loss.backward()
// >>> print("\nBackward pass:")
// Backward pass:
// >>> print("x.grad:", x.grad)
// x.grad: tensor([0., 0., 0., 0., 0.])
#[test]
fn should_diff_sign() {
let data = TensorData::from([-2.0, -1.0, 0.0, 1.0, 2.0]);
let device = Default::default();
let x = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
let y = x.clone().sign();
let loss = y.clone().sum();
let grads = loss.backward();
let grad = x.grad(&grads).unwrap();
y.to_data()
.assert_eq(&TensorData::from([-1., -1., 0., 1., 1.]), false);
grad.to_data()
.assert_eq(&TensorData::from([0., 0., 0., 0., 0.]), false);
}

View File

@@ -0,0 +1,67 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_matmul_with_slice() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().slice([0..2, 0..2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2.to_data().assert_eq(
&TensorData::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]),
false,
);
}
#[test]
fn should_diff_matmul_with_slice_stepped() {
use burn_tensor::s;
let data_1 = TensorData::from([[1.0, 7.0], [100.0, 100.0], [2.0, 3.0], [100.0, 100.0]]);
let data_2 = TensorData::from([[4.0, 100.0, 7.0, 100.0], [2.0, 100.0, 3.0, 15.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().slice(s![0..;2, 0..2]); // [[1., 7.], [2., 3.]]
let tensor_4 = tensor_2.clone().slice(s![0..2, 0..;2]); // [[4., 7.], [2., 3.]]
let tensor_5 = tensor_3.clone().matmul(tensor_4);
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[11., 5.], [0., 0.], [11., 5.], [0., 0.]]),
false,
);
grad_2.to_data().assert_eq(
&TensorData::from([[3., 0., 3., 0.], [10., 0., 10., 0.]]),
false,
);
}
#[test]
fn should_panic_on_slice_with_step() {
use burn_tensor::s;
let data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
// This should panic because step is 2
let _sliced = tensor.slice(s![.., 0..4; 2]);
}

View File

@@ -0,0 +1,163 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_matmul_with_slice_assign() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_assigned = TensorData::from([[9.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_assigned = TestAutodiffTensor::from_data(data_assigned, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned);
let tensor_5 = tensor_4.matmul(tensor_1.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[58.0, 38.0], [118.0, 82.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[16.0, 15.0], [24.0, 50.0]]), false);
}
#[test]
fn should_diff_matmul_with_slice_assign_complex() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[9.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_2.clone().slice([0..1, 0..1]);
let tensor_6 = tensor_5.mul(tensor_3.clone());
let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6);
let tensor_8 = tensor_7.matmul(tensor_1.clone());
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
grad_3
.to_data()
.assert_eq(&TensorData::from([[32.0]]), false);
grad_1
.to_data()
.assert_eq(&TensorData::from([[85.0, 65.0], [118.0, 82.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[88.0, 15.0], [24.0, 50.0]]), false);
}
#[test]
fn slice_assign_diff_should_give_same_results_as_cat() {
let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[5.0, 6.0], [7.0, 8.0]]);
let data_3 = TensorData::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());
let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
let slice_assign_output = slice_assign_output / tensor_3.clone();
let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
let cat_output = cat_output / tensor_3;
slice_assign_output
.to_data()
.assert_approx_eq::<FloatElem>(&cat_output.to_data(), Tolerance::default());
let slice_assign_grads = slice_assign_output.backward();
let cat_grads = cat_output.backward();
let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
slice_assign_grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&cat_grad_1.to_data(), Tolerance::default());
slice_assign_grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&cat_grad_2.to_data(), Tolerance::default());
}
#[test]
fn should_diff_slice_assign_with_step() {
use burn_tensor::s;
let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let value_data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();
// Assign with step=2
let result = tensor.clone().slice_assign(s![.., 0..4; 2], value.clone());
let result = result * 2.0; // Scale to create gradients
let grads = result.backward();
let grad_tensor = tensor.grad(&grads).unwrap();
let grad_value = value.grad(&grads).unwrap();
// The gradient for tensor should be 2.0 everywhere except the assigned positions
grad_tensor.to_data().assert_eq(
&TensorData::from([[0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0]]),
false,
);
// The gradient for value should be 2.0 at all positions
grad_value
.to_data()
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
}
#[test]
fn should_diff_slice_assign_with_negative_step() {
use burn_tensor::s;
let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let value_data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();
// Assign with step=-1 (reverse order, all elements)
let result = tensor.clone().slice_assign(s![.., ..;-1], value.clone());
let result = result * 2.0; // Scale to create gradients
let grads = result.backward();
let grad_tensor = tensor.grad(&grads).unwrap();
let grad_value = value.grad(&grads).unwrap();
// The gradient for tensor should be 0 since all values were replaced
grad_tensor.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
// The gradient for value should be 2.0 at all positions
grad_value.to_data().assert_eq(
&TensorData::from([[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]),
false,
);
}

View File

@@ -0,0 +1,90 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Tensor, TensorData, activation};
#[test]
fn test_softmax_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.5));
let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.05));
}
#[test]
fn test_log_softmax_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]);
// f16 gradients from log-softmax + matmul amplify error, so we increase the tolerance
// to account for limited precision and large representable step sizes in this range.
let tolerance = Tolerance::permissive().set_half_precision_relative(6e-2);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[30.5984, -47.2267], [55.9631, -56.5914]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn test_quiet_softmax_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);
// Precision is quite bad yet on softmax grad especially with half precision.
let tolerance = Tolerance::rel_abs(0.5, 0.2);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,82 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_sort() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sort_with_indices() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let (values, _indices) = tensor_3.sort_with_indices(1);
let tensor_4 = tensor_1.clone().mul(values);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sort_3d_dim1() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[[4.0, -7.0], [2.0, 3.0]]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[[-1., -8.], [-27., 37.]]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[[-4., -17.], [-17., -42.]]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,31 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_sqrt() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[82.112640, 99.083275], [82.112640, 99.083275]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[30.309311, 33.120457], [34.581974, 38.769463]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,73 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_sub() {
let data_1 = TensorData::from([2.0, 5.0]);
let data_2 = TensorData::from([4.0, 1.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().sub(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([-1.0, -1.0]), false);
tensor_3
.into_data()
.assert_eq(&TensorData::from([-2.0, 4.0]), false);
}
#[test]
fn should_diff_sub_scalar() {
let data = TensorData::from([2.0, 10.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().sub_scalar(5.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
tensor_out
.into_data()
.assert_eq(&TensorData::from([-3.0, 5.0]), false);
}
#[test]
fn test_sub_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().sub(tensor_2.clone());
let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0);
let tensor_6 = tensor_1.clone().sub(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[0.0, 0.0], [0.0, 0.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[1.0, 1.0], [1.0, 1.0]]), false);
}

View File

@@ -0,0 +1,60 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_transpose() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose());
let tensor_4 = tensor_3.transpose();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[6.0, 10.0], [6.0, 10.0]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[3.0, 10.0], [3.0, 10.0]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_swap_dims() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<3>::from_floats(
[[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]],
&device,
)
.require_grad();
let tensor_2 = TestAutodiffTensor::from_floats(
[[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]],
&device,
)
.require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2));
let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,371 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_cos() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Metal has less precise trigonometric functions
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[26.8063, -27.7870], [26.8063, -27.7870]]),
tolerance,
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),
tolerance,
);
}
#[test]
fn should_diff_sin() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Metal has less precise trigonometric functions
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
let expected = TensorData::from([[8.8500, -4.9790], [8.8500, -4.9790]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[38.668987, 44.194775], [-59.97261, -80.46094]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn should_diff_tanh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(8e-3);
let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[8.00092, 8.000153], [8.000003, 7.999995]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn should_diff_cosh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cosh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[7.092221, 16.696301], [7.092221, 16.696301]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[17.489855, 27.484539], [39.409813, 86.910278]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_sinh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sinh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[4.894847, 15.887931], [4.894847, 15.887931]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[17.284000, 28.412029], [39.302979, 87.498329]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_tan() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [0.3, 0.8]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tan());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[2.532602, 1.596607], [2.532602, 1.596607]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[9.028598, 14.489801], [18.038082, 21.151270]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_asin() {
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asin());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.435841, 0.969651], [0.435841, 0.969651]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.475300, 0.668141], [0.701834, 1.100658]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_acos() {
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acos());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[2.077433, 1.543624], [2.077433, 1.543624]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.781337, 0.588496], [0.554804, 0.155979]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_atan() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atan());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[3.444365, 5.349211], [3.444365, 5.349211]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[9.904911, 11.554912], [10.199631, 11.391938]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_asinh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asinh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[3.806625, 6.844869], [3.806625, 6.844869]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[11.442373, 14.842072], [14.022551, 17.688538]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_acosh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[1.5, 2.0], [2.5, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acosh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[10.611752, 15.178907], [10.611752, 15.178907]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[20.112753, 20.247547], [20.402235, 22.487328]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_atanh() {
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atanh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.441838, 1.037115], [0.441838, 1.037115]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.491723, 0.698110], [0.772763, 1.298805]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_atan2() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let data_3 = TensorData::from([[1.0, 0.5], [2.0, 1.5]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1
.clone()
.matmul(tensor_2.clone().atan2(tensor_3.clone()));
let tensor_5 = tensor_4.matmul(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[4.570492, 4.210785], [4.570492, 4.210785]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[8.208448, 8.808449], [10.357923, 12.157923]]),
Tolerance::default(),
);
grad_3.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[-1.8, -8.4], [-1.8, -5.6]]),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,18 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn unfold_backward_accumulates_overlaps() {
let device = Default::default();
let x = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device).require_grad();
let y = x.clone().unfold::<3, _>(1, 2, 1);
let loss = y.sum();
let grads = loss.backward();
let grad_x = x.grad(&grads).unwrap();
grad_x
.to_data()
.assert_eq(&TensorData::from([[1., 2., 2., 1.]]), false);
}

View File

@@ -0,0 +1,35 @@
// Burn autodiff tests, reusable with element types.
pub use super::*;
#[path = "../autodiff/mod.rs"]
mod base;
mod checkpointing {
pub use super::*;
use burn_autodiff::checkpoint::strategy::BalancedCheckpointing;
// Override type def
pub type TestAutodiffBackend = Autodiff<TestBackend, BalancedCheckpointing>;
pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;
include!("../autodiff/mod.rs");
}
use burn_backend_tests::test_float_elem_variant;
// NOTE: this currently doesn't test checkpointing with different dtypes
test_float_elem_variant!(
f16,
burn_tensor::f16,
"../autodiff/mod.rs",
["vulkan", "cuda", "rocm", "metal"]
);
// TODO: bf16 not yet supported on any backend for full test suite
// test_float_elem_variant!(
// bf16,
// burn_tensor::bf16,
// "../autodiff/mod.rs",
// [] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul, metal/wgpu doesn't support bf16
// );

View File

@@ -0,0 +1,47 @@
// Re-export
use super::FloatElemType;
// Default
#[cfg(feature = "ndarray")]
pub type TestBackend = burn_ndarray::NdArray<FloatElemType>;
#[cfg(feature = "tch")]
pub type TestBackend = burn_tch::LibTorch<FloatElemType>;
#[cfg(feature = "cuda")]
pub type TestBackend = burn_cuda::Cuda<FloatElemType, super::IntElemType>;
#[cfg(feature = "rocm")]
pub type TestBackend = burn_rocm::Rocm<FloatElemType, super::IntElemType>;
#[cfg(feature = "wgpu")]
pub type TestBackend = burn_wgpu::Wgpu<FloatElemType, super::IntElemType>;
#[cfg(feature = "cpu")]
pub type TestBackend = burn_cpu::Cpu<FloatElemType, super::IntElemType>;
#[cfg(feature = "router")]
pub type TestBackend = burn_router::BackendRouter<
burn_router::DirectByteChannel<(burn_ndarray::NdArray, burn_wgpu::Wgpu)>,
>;
/// Collection of types used across tests
#[allow(unused)]
pub mod prelude {
pub use burn_autodiff::Autodiff;
pub use burn_tensor::Tensor;
use super::*;
pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;
pub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;
pub type IntElem = burn_tensor::ops::IntElem<TestBackend>;
pub type TestAutodiffBackend = Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;
}
#[allow(unused)]
pub use prelude::*;

View File

@@ -0,0 +1,39 @@
// Burn backend tensor tests, reusable with element types.
pub use super::*;
#[path = "../tensor/clone_invariance.rs"]
mod clone_invariance;
#[cfg(feature = "std")]
#[path = "../tensor/multi_threads.rs"]
mod multi_threads;
// Default float dtype
#[path = "../tensor/float/mod.rs"]
mod float;
// Default integer dtype
#[path = "../tensor/int/mod.rs"]
mod int;
// Default bool dtype
#[path = "../tensor/bool/mod.rs"]
mod bool;
use burn_backend_tests::test_float_elem_variant;
test_float_elem_variant!(
f16,
burn_tensor::f16,
"../tensor/float/mod.rs",
["vulkan", "cuda", "rocm", "metal"]
);
// TODO: bf16 not yet supported on any backend for full test suite
// test_float_elem_variant!(
// bf16,
// burn_tensor::bf16,
// "../tensor/float/mod.rs",
// [] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul, metal/wgpu doesn't support bf16
// );

View File

@@ -0,0 +1,17 @@
//! CubeCL kernel tests.
#[cfg(feature = "cube")]
#[path = "."]
mod cube {
type FloatElemType = f32;
type IntElemType = i32;
mod backend {
include!("common/backend.rs");
pub type ReferenceBackend = burn_ndarray::NdArray<FloatElemType>;
}
pub use backend::*;
#[path = "cubecl/mod.rs"]
mod kernel;
}

View File

@@ -0,0 +1,96 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{
Distribution, Tensor, TensorPrimitive, backend::Backend, module, ops::ModuleOps,
};
#[test]
fn avg_pool2d_should_match_reference_backend() {
let tensor = Tensor::<TestBackend, 4>::random(
[32, 32, 32, 32],
Distribution::Default,
&Default::default(),
);
let tensor_ref =
Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());
let kernel_size = [3, 4];
let stride = [1, 2];
let padding = [1, 2];
let count_include_pad = true;
let pooled = module::avg_pool2d(
tensor,
kernel_size,
stride,
padding,
count_include_pad,
false,
);
let pooled_ref = module::avg_pool2d(
tensor_ref,
kernel_size,
stride,
padding,
count_include_pad,
false,
);
pooled
.into_data()
.assert_approx_eq::<FloatElem>(&pooled_ref.into_data(), Tolerance::default());
}
#[test]
fn avg_pool2d_backward_should_match_reference_backend() {
let device = Default::default();
TestBackend::seed(&device, 0);
ReferenceBackend::seed(&Default::default(), 0);
let tensor = Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &device);
let tensor_ref =
Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());
let kernel_size = [3, 3];
let stride = [1, 1];
let padding = [1, 1];
let count_include_pad = true;
let shape_out = module::avg_pool2d(
tensor.clone(),
kernel_size,
stride,
padding,
count_include_pad,
false,
)
.shape();
let grad_output =
Tensor::<TestBackend, 4>::random(shape_out, Distribution::Default, &Default::default());
let grad_output_ref =
Tensor::<ReferenceBackend, 4>::from_data(grad_output.to_data(), &Default::default());
let grad: Tensor<TestBackend, 4> =
Tensor::from_primitive(TensorPrimitive::Float(TestBackend::avg_pool2d_backward(
tensor.into_primitive().tensor(),
grad_output.into_primitive().tensor(),
kernel_size,
stride,
padding,
count_include_pad,
false,
)));
let grad_ref: Tensor<ReferenceBackend, 4> = Tensor::from_primitive(TensorPrimitive::Float(
ReferenceBackend::avg_pool2d_backward(
tensor_ref.into_primitive().tensor(),
grad_output_ref.into_primitive().tensor(),
kernel_size,
stride,
padding,
count_include_pad,
false,
),
));
grad.into_data()
.assert_approx_eq::<FloatElem>(&grad_ref.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,48 @@
use super::*;
use serial_test::serial;
use core::f32;
use burn_tensor::{Distribution, Shape, Tensor, backend::Backend};
use cubek::random::{assert_number_of_1_proportional_to_prob, assert_wald_wolfowitz_runs_test};
#[test]
#[serial]
fn number_of_1_proportional_to_prob() {
let device = Default::default();
TestBackend::seed(&device, 0);
let shape: Shape = [40, 40].into();
let prob = 0.7;
let tensor =
Tensor::<TestBackend, 2>::random(shape.clone(), Distribution::Bernoulli(prob), &device)
.into_data();
let numbers = tensor
.as_slice::<<TestBackend as Backend>::FloatElem>()
.unwrap();
assert_number_of_1_proportional_to_prob(numbers, prob as f32);
}
#[test]
#[serial]
fn wald_wolfowitz_runs_test() {
let device = Default::default();
TestBackend::seed(&device, 0);
let shape = Shape::new([512, 512]);
let device = Default::default();
let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Bernoulli(0.5), &device);
let data = tensor.into_data();
let numbers = data
.as_slice::<<TestBackend as Backend>::FloatElem>()
.unwrap();
// High bound slightly over 1 so 1.0 is included in second bin
assert_wald_wolfowitz_runs_test(numbers, 0., 1.1);
}

View File

@@ -0,0 +1,45 @@
use super::*;
use burn_tensor::{Int, Tensor, TensorData};
#[test]
fn should_cast_int_to_float() {
const START: usize = 0;
const END: usize = 100;
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(START as i64..END as i64, &device);
let data_int = tensor.to_data();
let data_int = data_int.as_slice::<i32>().unwrap();
let data_float = tensor.float().into_data();
let data_float = data_float.as_slice::<f32>().unwrap();
for i in START..END {
assert_eq!(data_int[i], i as i32);
assert_eq!(data_float[i], i as f32);
}
}
#[test]
fn should_cast_bool_to_int() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2>::from_floats([[1., 0., 3.], [0., 0., 900.]], &device);
let tensor_2: Tensor<TestBackend, 2, Int> = tensor_1.clone().greater_elem(0.0).int();
tensor_2
.to_data()
.assert_eq(&TensorData::from([[1, 0, 1], [0, 0, 1]]), false);
}
#[test]
fn should_cast_bool_to_float() {
let device = Default::default();
let tensor_1 = Tensor::<TestBackend, 2>::from_floats([[1., 0., 3.], [0., 0., 900.]], &device);
let tensor_2: Tensor<TestBackend, 2> = tensor_1.clone().greater_elem(0.0).float();
tensor_2
.to_data()
.assert_eq(&TensorData::from([[1., 0., 1.], [0., 0., 1.]]), false);
}

View File

@@ -0,0 +1,42 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, backend::Backend};
#[test]
fn cat_should_match_reference_backend_dim0() {
test_same_as_reference([6, 256], 2, 0);
}
#[test]
fn cat_should_match_reference_backend_dim1() {
test_same_as_reference([6, 256], 2, 1);
}
#[test]
fn cat_should_support_uneven_launch() {
test_same_as_reference([1, 137], 2, 0);
}
fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) {
let device = Default::default();
TestBackend::seed(&device, 0);
let tensors = (0..num_tensors)
.map(|_| {
Tensor::<TestBackend, 2>::random(shape, Distribution::Default, &Default::default())
})
.collect::<Vec<_>>();
let tensors_ref = tensors
.iter()
.map(|tensor| {
Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default())
})
.collect::<Vec<_>>();
let tensor = Tensor::<TestBackend, 2>::cat(tensors, dim);
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(tensors_ref, dim);
tensor
.into_data()
.assert_approx_eq::<FloatElem>(&tensor_ref.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,20 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor};
#[test]
fn clamp_should_match_reference() {
let input = Tensor::<TestBackend, 4>::random(
[1, 5, 32, 32],
Distribution::Default,
&Default::default(),
);
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &Default::default());
let output = input.clamp(0.3, 0.7);
output.into_data().assert_approx_eq::<FloatElem>(
&input_ref.clamp(0.3, 0.7).into_data(),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,40 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Int, Tensor};
#[test]
pub fn into_contiguous_match_reference_backend_1() {
for shape in [
[4, 4, 4, 4],
[32, 42, 24, 48],
[8, 3, 7, 4],
[1, 4, 1, 1],
[1, 32, 256, 128],
] {
let num_elems = shape.iter().product::<usize>() as i64;
let tensor: Tensor<TestBackend, 4> =
Tensor::<TestBackend, 1, Int>::arange(0..num_elems, &Default::default())
.reshape(shape)
.float();
let tensor_ref =
Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());
for (i, j) in get_combinations(shape.len()) {
let view = tensor.clone().swap_dims(i, j);
let view_ref = tensor_ref.clone().swap_dims(i, j);
let data = view.into_data();
let data_ref = view_ref.into_data();
data_ref.assert_approx_eq::<FloatElem>(&data, Tolerance::default());
}
}
}
fn get_combinations(n: usize) -> impl Iterator<Item = (usize, usize)> {
// Iterate from 0 up to n
(0..n).flat_map(move |i| {
// For each i, iterate from i + 1 up to n
// This ensures no repeats (i == j) and no duplicates (j, i)
(i + 1..n).map(move |j| (i, j))
})
}

View File

@@ -0,0 +1,78 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, module};
#[test]
fn conv2d_should_match_reference_backend() {
let test_device = Default::default();
let input =
Tensor::<TestBackend, 4>::random([6, 16, 32, 32], Distribution::Default, &test_device);
let weight =
Tensor::<TestBackend, 4>::random([12, 8, 3, 3], Distribution::Default, &test_device);
let bias = Tensor::<TestBackend, 1>::random([12], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let options = burn_tensor::ops::ConvOptions::new([2, 3], [2, 3], [2, 3], 2);
let output = module::conv2d(input, weight, Some(bias), options.clone());
let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options);
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::default());
}
#[test]
fn conv2d_should_match_reference_backend_implicit() {
let test_device = Default::default();
let input =
Tensor::<TestBackend, 4>::random([4, 16, 6, 6], Distribution::Default, &test_device);
let weight =
Tensor::<TestBackend, 4>::random([16, 16, 3, 3], Distribution::Default, &test_device);
let bias = Tensor::<TestBackend, 1>::random([16], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let options = burn_tensor::ops::ConvOptions::new([1, 1], [2, 2], [1, 1], 1);
let output = module::conv2d(input, weight, Some(bias), options.clone());
let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options);
let tolerance = Tolerance::default();
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.into_data(), tolerance);
}
/// Regression test for bias loader in new implicit GEMM
#[test]
fn conv2d_should_match_reference_backend_bias_regression() {
let test_device = Default::default();
let input = Tensor::<TestBackend, 4>::random([1, 1, 1, 1], Distribution::Default, &test_device);
let weight =
Tensor::<TestBackend, 4>::random([32, 1, 3, 3], Distribution::Default, &test_device);
let bias = Tensor::<TestBackend, 1>::random([32], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let options = burn_tensor::ops::ConvOptions::new([1, 1], [1, 1], [1, 1], 1);
let output = module::conv2d(input, weight, Some(bias), options.clone()).permute([0, 2, 3, 1]);
let output_ref =
module::conv2d(input_ref, weight_ref, Some(bias_ref), options).permute([0, 2, 3, 1]);
let tolerance = Tolerance::default();
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.into_data(), tolerance);
}

View File

@@ -0,0 +1,27 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, module};
#[test]
fn conv3d_should_match_reference_backend() {
let test_device = Default::default();
let input =
Tensor::<TestBackend, 5>::random([6, 16, 32, 32, 32], Distribution::Default, &test_device);
let weight =
Tensor::<TestBackend, 5>::random([12, 8, 3, 3, 3], Distribution::Default, &test_device);
let bias = Tensor::<TestBackend, 1>::random([12], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 5>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 5>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let options = burn_tensor::ops::ConvOptions::new([2, 3, 4], [2, 3, 4], [2, 3, 4], 2);
let output = module::conv3d(input, weight, Some(bias), options.clone());
let output_ref = module::conv3d(input_ref, weight_ref, Some(bias_ref), options);
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,48 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, backend::Backend, module};
#[test]
fn conv_transpose2d_should_match_reference_backend() {
let device = Default::default();
TestBackend::seed(&device, 0);
let height = 8;
let width = 8;
let in_channels = 8;
let out_channels = 8;
let batch_size = 32;
let kernel_size_0 = 3;
let kernel_size_1 = 3;
let options = burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1);
let test_device = Default::default();
let input = Tensor::<TestBackend, 4>::random(
[batch_size, in_channels, height, width],
Distribution::Default,
&test_device,
);
let weight = Tensor::<TestBackend, 4>::random(
[
in_channels,
out_channels / options.groups,
kernel_size_0,
kernel_size_1,
],
Distribution::Default,
&test_device,
);
let bias =
Tensor::<TestBackend, 1>::random([out_channels], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let output = module::conv_transpose2d(input, weight, Some(bias), options.clone());
let output_ref = module::conv_transpose2d(input_ref, weight_ref, Some(bias_ref), options);
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::rel_abs(0.01, 0.02));
}

View File

@@ -0,0 +1,51 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, backend::Backend, module};
#[test]
fn conv_transpose3d_should_match_reference_backend() {
let test_device = Default::default();
TestBackend::seed(&test_device, 0);
let depth = 8;
let height = 8;
let width = 8;
let in_channels = 8;
let out_channels = 8;
let batch_size = 32;
let kernel_size_0 = 3;
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let options =
burn_tensor::ops::ConvTransposeOptions::new([1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1], 1);
let input = Tensor::<TestBackend, 5>::random(
[batch_size, in_channels, depth, height, width],
Distribution::Default,
&test_device,
);
let weight = Tensor::<TestBackend, 5>::random(
[
in_channels,
out_channels / options.groups,
kernel_size_0,
kernel_size_1,
kernel_size_2,
],
Distribution::Default,
&test_device,
);
let bias =
Tensor::<TestBackend, 1>::random([out_channels], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 5>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 5>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let output = module::conv_transpose3d(input, weight, Some(bias), options.clone());
let output_ref = module::conv_transpose3d(input_ref, weight_ref, Some(bias_ref), options);
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,159 @@
use super::*;
use burn_tensor::Tensor;
use burn_tensor::Tolerance;
#[test]
fn test_cross_product() {
let device = Default::default();
// Test with well-known orthogonal vectors for clearer validation
let a = Tensor::<TestBackend, 2>::from_data([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], &device);
let b = Tensor::<TestBackend, 2>::from_data([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], &device);
let result = a.cross(b, 1);
// For orthogonal unit vectors:
// i × j = k
// j × k = i
let expected = Tensor::<TestBackend, 2>::from_data([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], &device);
// Use Tolerance for floating-point comparisons
let tolerance = Tolerance::<FloatElem>::default();
result
.to_data()
.assert_approx_eq(&expected.to_data(), tolerance);
}
#[test]
fn test_cross_product_zeros() {
let device = Default::default();
// Test cross product with zero vector - should always give zero vector
let a = Tensor::<TestBackend, 2>::from_data([[2.0, 3.0, 4.0]], &device);
let b = Tensor::<TestBackend, 2>::zeros([1, 3], &device);
let result = a.cross(b, 1);
let expected = Tensor::<TestBackend, 2>::zeros([1, 3], &device);
// For zeros, we can use exact equality or a very tight tolerance
let tolerance = Tolerance::<FloatElem>::default();
result
.to_data()
.assert_approx_eq(&expected.to_data(), tolerance);
}
#[test]
fn test_cross_product_batch() {
let device = Default::default();
// Test typical cross product computations in batch
let a = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let b = Tensor::<TestBackend, 2>::from_data([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device);
let result = a.cross(b, 1);
// Cross products:
// [1,2,3] × [4,5,6] = [-3,6,-3]
// [4,5,6] × [7,8,9] = [-3,6,-3]
let expected =
Tensor::<TestBackend, 2>::from_data([[-3.0, 6.0, -3.0], [-3.0, 6.0, -3.0]], &device);
let tolerance = Tolerance::<FloatElem>::default();
result
.to_data()
.assert_approx_eq(&expected.to_data(), tolerance);
}
#[test]
#[should_panic]
fn test_cross_product_invalid_dimension() {
let device = Default::default();
let a = Tensor::<TestBackend, 2>::zeros([1, 4], &device);
let b = Tensor::<TestBackend, 2>::zeros([1, 4], &device);
let _ = a.cross(b, 1);
}
#[test]
fn test_cross_product_parallel_vectors() {
let device = Default::default();
// Test cross product of parallel vectors (should be zero)
let a = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0]], &device);
let b = Tensor::<TestBackend, 2>::from_data([[2.0, 4.0, 6.0]], &device); // b = 2 * a
let result = a.cross(b, 1);
let expected = Tensor::<TestBackend, 2>::zeros([1, 3], &device);
let tolerance = Tolerance::<FloatElem>::default();
result
.to_data()
.assert_approx_eq(&expected.to_data(), tolerance);
}
#[test]
fn test_cross_product_3d_tensor() {
let device = Default::default();
// Test with 3D tensor (batch of matrices)
let a = Tensor::<TestBackend, 3>::from_data(
[
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
],
&device,
);
let b = Tensor::<TestBackend, 3>::from_data(
[
[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
],
&device,
);
let result = a.cross(b, 2); // Cross on last dimension
let expected = Tensor::<TestBackend, 3>::from_data(
[
[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
[[-3.0, 6.0, -3.0], [-3.0, 6.0, -3.0]],
],
&device,
);
let tolerance = Tolerance::<FloatElem>::default();
result
.to_data()
.assert_approx_eq(&expected.to_data(), tolerance);
}
// Test to verify that padding doesn't affect results
#[test]
fn test_cross_product_with_padding_awareness() {
let device = Default::default();
// Create tensors that would span multiple 4-element blocks
// This tests that the padding doesn't corrupt adjacent data
let a = Tensor::<TestBackend, 2>::from_data(
[
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], // Two vectors: [1,2,3] and [4,5,6]
],
&device,
);
let b = Tensor::<TestBackend, 2>::from_data(
[
[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], // Two vectors: [7,8,9] and [10,11,12]
],
&device,
);
// Reshape to have proper 3-element vectors in last dimension
let a_reshaped = a.reshape([2, 3]);
let b_reshaped = b.reshape([2, 3]);
let result = a_reshaped.cross(b_reshaped, 1);
// Expected cross products:
// [1,2,3] × [7,8,9] = [-6,12,-6]
// [4,5,6] × [10,11,12] = [-6,12,-6]
let expected =
Tensor::<TestBackend, 2>::from_data([[-6.0, 12.0, -6.0], [-6.0, 12.0, -6.0]], &device);
let tolerance = Tolerance::<FloatElem>::default();
result
.to_data()
.assert_approx_eq(&expected.to_data(), tolerance);
}

View File

@@ -0,0 +1,44 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Int, Shape, Tensor, backend::Backend};
#[test]
fn gather_should_work_with_multiple_workgroups_dim0() {
test_same_as_ref([6, 256], 0);
}
#[test]
fn gather_should_work_with_multiple_workgroups_dim1() {
test_same_as_ref([6, 256], 1);
}
fn test_same_as_ref<const D: usize>(shape: [usize; D], dim: usize) {
let device = Default::default();
TestBackend::seed(&device, 0);
let max = shape[dim];
let shape = Shape::new(shape);
let tensor =
Tensor::<TestBackend, D>::random(shape.clone(), Distribution::Default, &Default::default());
let indices = Tensor::<TestBackend, 1, Int>::from_data(
Tensor::<TestBackend, 1>::random(
[shape.num_elements()],
Distribution::Uniform(0., max as f64),
&Default::default(),
)
.into_data(),
&Default::default(),
)
.reshape(shape);
let tensor_ref =
Tensor::<ReferenceBackend, D>::from_data(tensor.to_data(), &Default::default());
let indices_ref =
Tensor::<ReferenceBackend, D, Int>::from_data(indices.to_data(), &Default::default());
let actual = tensor.gather(dim, indices);
let expected = tensor_ref.gather(dim, indices_ref);
expected
.into_data()
.assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,64 @@
use super::*;
use burn_cubecl::kernel::{MaskFillStrategy, mask_fill};
use burn_tensor::Tolerance;
use burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend};
use cubecl::prelude::InputScalar;
#[test]
fn mask_fill_should_match_reference_backend() {
let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();
let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();
let dtype_ft = <FloatElem as Element>::dtype();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill(
tensor.into_primitive().tensor(),
mask.into_primitive(),
InputScalar::new(4.0, dtype_ft),
MaskFillStrategy::Readonly,
dtype_bool,
)));
let expected = tensor_ref.mask_fill(mask_ref, 4.0);
expected
.into_data()
.assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());
}
#[test]
fn mask_fill_inplace_should_match_reference_backend() {
let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();
let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();
let dtype_ft = <FloatElem as Element>::dtype();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill::<_>(
tensor.into_primitive().tensor(),
mask.into_primitive(),
InputScalar::new(4.0, dtype_ft),
MaskFillStrategy::Inplace,
dtype_bool,
)));
let expected = tensor_ref.mask_fill(mask_ref, 4.0);
expected
.into_data()
.assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());
}
#[allow(clippy::type_complexity)]
fn inputs_mask_fill() -> (
Tensor<TestBackend, 3>,
Tensor<TestBackend, 3, Bool>,
Tensor<ReferenceBackend, 3>,
Tensor<ReferenceBackend, 3, Bool>,
) {
let test_device = Default::default();
let tensor = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Default, &test_device);
let mask =
Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0., 1.), &test_device)
.lower_equal_elem(0.5);
let ref_device = Default::default();
let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &ref_device);
let mask_ref = Tensor::<ReferenceBackend, 3, Bool>::from_data(mask.to_data(), &ref_device);
(tensor, mask, tensor_ref, mask_ref)
}

View File

@@ -0,0 +1,80 @@
use super::*;
use burn_cubecl::kernel::{MaskWhereStrategy, mask_where};
use burn_tensor::Tolerance;
use burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend};
#[test]
fn mask_where_should_match_reference_backend() {
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
let actual = tensor.mask_where(mask, value);
let expected = tensor_ref.mask_where(mask_ref, value_ref);
expected
.into_data()
.assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());
}
#[test]
fn mask_where_inplace_lhs_should_match_reference_backend() {
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where::<_>(
tensor.into_primitive().tensor(),
mask.into_primitive(),
value.into_primitive().tensor(),
MaskWhereStrategy::InplaceLhs,
dtype_bool,
)));
let expected = tensor_ref.mask_where(mask_ref, value_ref);
expected
.into_data()
.assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());
}
#[test]
fn mask_where_inplace_rhs_should_match_reference_backend() {
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where::<_>(
tensor.into_primitive().tensor(),
mask.into_primitive(),
value.into_primitive().tensor(),
MaskWhereStrategy::InplaceRhs,
dtype_bool,
)));
let expected = tensor_ref.mask_where(mask_ref, value_ref);
expected
.into_data()
.assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());
}
#[allow(clippy::type_complexity)]
fn inputs_mask_where() -> (
Tensor<TestBackend, 3>,
Tensor<TestBackend, 3>,
Tensor<TestBackend, 3, Bool>,
Tensor<ReferenceBackend, 3>,
Tensor<ReferenceBackend, 3>,
Tensor<ReferenceBackend, 3, Bool>,
) {
let device = Default::default();
TestBackend::seed(&device, 0);
let tensor = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Default, &device);
let value = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Default, &device);
let mask =
Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0., 1.), &device)
.lower_equal_elem(0.5);
let device_ref = Default::default();
let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &device_ref);
let value_ref = Tensor::<ReferenceBackend, 3>::from_data(value.to_data(), &device_ref);
let mask_ref = Tensor::<ReferenceBackend, 3, Bool>::from_data(mask.to_data(), &device_ref);
mask.to_data().assert_eq(&mask_ref.to_data(), false);
(tensor, value, mask, tensor_ref, value_ref, mask_ref)
}

View File

@@ -0,0 +1,52 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, module};
#[test]
pub fn max_pool2d_should_match_reference_backends() {
let tensor = Tensor::<TestBackend, 4>::random(
[32, 32, 32, 32],
Distribution::Default,
&Default::default(),
);
let tensor_ref =
Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation, false);
let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation, false);
pooled
.into_data()
.assert_approx_eq::<FloatElem>(&pooled_ref.into_data(), Tolerance::default());
}
#[test]
pub fn max_pool2d_with_indices_should_match_reference_backend() {
let tensor = Tensor::<TestBackend, 4>::random(
[32, 32, 32, 32],
Distribution::Default,
&Default::default(),
);
let tensor_ref =
Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let (pooled, indices) =
module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation, false);
let (pooled_ref, indices_ref) =
module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation, false);
pooled
.into_data()
.assert_approx_eq::<FloatElem>(&pooled_ref.into_data(), Tolerance::default());
indices
.into_data()
.assert_eq(&indices_ref.into_data(), false);
}

View File

@@ -0,0 +1,67 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Tensor, TensorPrimitive, module, ops::ModuleOps};
#[test]
pub fn max_pool2d_with_indices_backward_should_match_reference_backend() {
let test_device = Default::default();
let tensor =
Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &test_device);
let grad_output =
Tensor::<TestBackend, 4>::random([32, 32, 16, 16], Distribution::Default, &test_device);
let ref_device = Default::default();
let tensor_ref = Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &ref_device);
let grad_output_ref =
Tensor::<ReferenceBackend, 4>::from_data(grad_output.to_data(), &ref_device);
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let (_, indices) = module::max_pool2d_with_indices(
tensor.clone(),
kernel_size,
stride,
padding,
dilation,
false,
);
let (_, indices_ref) = module::max_pool2d_with_indices(
tensor_ref.clone(),
kernel_size,
stride,
padding,
dilation,
false,
);
let grad = TestBackend::max_pool2d_with_indices_backward(
tensor.into_primitive().tensor(),
kernel_size,
stride,
padding,
dilation,
false,
grad_output.into_primitive().tensor(),
indices.into_primitive(),
)
.x_grad;
let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward(
tensor_ref.into_primitive().tensor(),
kernel_size,
stride,
padding,
dilation,
false,
grad_output_ref.into_primitive().tensor(),
indices_ref.into_primitive(),
)
.x_grad;
Tensor::<TestBackend, 4>::from_primitive(TensorPrimitive::Float(grad))
.into_data()
.assert_approx_eq::<FloatElem>(
&Tensor::<ReferenceBackend, 4>::from_primitive(TensorPrimitive::Float(grad_ref))
.into_data(),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,30 @@
// #[allow(unused_imports)] // required for re-included modules
pub use super::*;
mod avg_pool2d;
mod bernoulli;
mod cast;
mod cat;
mod clamp;
mod contiguous;
mod conv2d;
mod conv3d;
mod conv_transpose2d;
mod conv_transpose3d;
mod cross;
mod gather;
mod mask_fill;
mod mask_where;
mod max_pool2d;
mod max_pool2d_backward;
mod normal;
mod quantization;
mod reduce;
mod repeat_dim;
mod scatter;
mod select;
mod select_assign;
mod slice;
mod slice_assign;
mod unary;
mod uniform;

View File

@@ -0,0 +1,36 @@
use super::*;
use burn_tensor::{Distribution, Shape, Tensor, backend::Backend};
use cubek::random::{assert_mean_approx_equal, assert_normal_respects_68_95_99_rule};
use serial_test::serial;
#[test]
#[serial]
fn empirical_mean_close_to_expectation() {
let device = Default::default();
TestBackend::seed(&device, 0);
let shape = [100, 100];
let mean = 10.;
let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Normal(mean, 2.), &device)
.into_data();
let numbers = tensor.as_slice::<FloatElem>().unwrap();
assert_mean_approx_equal(numbers, mean as f32);
}
#[test]
#[serial]
fn normal_respects_68_95_99_rule() {
// https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule
let shape: Shape = [1000, 1000].into();
let device = Default::default();
let mu = 0.;
let s = 1.;
let tensor =
Tensor::<TestBackend, 2>::random(shape.clone(), Distribution::Normal(mu, s), &device)
.into_data();
let numbers = tensor.as_slice::<FloatElem>().unwrap();
assert_normal_respects_68_95_99_rule(numbers, mu as f32, s as f32);
}

View File

@@ -0,0 +1,240 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{
Shape, Tensor,
backend::Backend,
quantization::{QuantLevel, QuantScheme, QuantStore, QuantValue},
};
fn should_quantize_dequantize_symmetric_arange<S: Into<Shape>>(
value: QuantValue,
store: QuantStore,
shape: S,
) {
let shape = shape.into();
assert_eq!(shape.rank(), 2); // 2D tests
let scheme = QuantScheme::default().with_value(value).with_store(store);
let scheme_ref = scheme.clone().with_store(QuantStore::Native);
let input: Tensor<TestBackend, 2> =
Tensor::arange(0..shape.num_elements() as i64, &Default::default())
.float()
.reshape(shape);
let input_ref = Tensor::<ReferenceBackend, 2>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme_ref);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.to_data(), Tolerance::default());
}
fn should_quantize_dequantize_symmetric_per_block_arange<S: Into<Shape>>(
value: QuantValue,
block_size: usize,
store: QuantStore,
shape: S,
) {
let scheme = QuantScheme::default()
.with_value(value)
.with_level(QuantLevel::block([block_size as u8]))
.with_store(store);
let scheme_ref = scheme.clone().with_store(QuantStore::Native);
let shape = shape.into();
let input: Tensor<TestBackend, 2> =
Tensor::arange(0..shape.num_elements() as i64, &Default::default())
.float()
.reshape(shape);
let input_ref = Tensor::<ReferenceBackend, 2>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme_ref);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.to_data(), Tolerance::default());
}
fn should_quantize_dequantize_symmetric_per_block(
value: QuantValue,
block_size: usize,
store: QuantStore,
) {
let scheme = QuantScheme::default()
.with_value(value)
.with_level(QuantLevel::block([block_size as u8]))
.with_store(store);
let scheme_ref = scheme.clone().with_store(QuantStore::Native);
let input = Tensor::<TestBackend, 2>::from_floats(
[
[
-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5, 0.01, 0.025, 0.03, 0.04, 0.01, 0.025,
0.03, 0.04,
],
[
1.8, 1.0, 0.0, -0.5, 1.8, 1.0, 0.0, -0.5, -0.01, -0.025, -0.03, -0.04, -0.01,
-0.025, -0.03, -0.04,
],
],
&Default::default(),
);
let input_ref = Tensor::<ReferenceBackend, 2>::from_data(input.to_data(), &Default::default());
let output = input.quantize_dynamic(&scheme);
let output_ref = input_ref.quantize_dynamic(&scheme_ref);
output.to_data().assert_eq(&output_ref.to_data(), false);
let output = output.dequantize();
let output_ref = output_ref.dequantize();
output
.into_data()
.assert_approx_eq::<FloatElem>(&output_ref.to_data(), Tolerance::default());
}
fn supports_native() -> bool {
let name = <TestBackend as Backend>::name(&Default::default());
// TODO: Proper checks for i8 support.
name.contains("cuda")
|| name.contains("rocm")
|| name.contains("hip")
|| name.contains("vulkan")
|| name.contains("spirv")
|| name.contains("metal")
|| name.contains("msl")
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q8s_packed() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q8S, QuantStore::PackedU32(0), [8, 16])
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q8f_packed() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q8F, QuantStore::PackedU32(0), [8, 16])
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q4s_packed() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q4S, QuantStore::PackedU32(0), [8, 16])
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q4f_packed() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q4F, QuantStore::PackedU32(0), [8, 16])
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q2s_packed() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q2S, QuantStore::PackedU32(0), [8, 16])
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q2f_packed() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q2F, QuantStore::PackedU32(0), [8, 16])
}
#[test]
fn should_quantize_dequantize_symmetric_per_block_q8s_packed() {
should_quantize_dequantize_symmetric_per_block(QuantValue::Q8S, 8, QuantStore::PackedU32(0))
}
#[test]
fn should_quantize_dequantize_symmetric_per_block_q4s_packed() {
should_quantize_dequantize_symmetric_per_block(QuantValue::Q4S, 8, QuantStore::PackedU32(0))
}
#[test]
#[should_panic = "Block size must be divisible by 16"]
fn should_panic_when_block_size_cannot_store_num_quants() {
// num_quants in u32 = 32 bits / 2 bits = 16
should_quantize_dequantize_symmetric_per_block(QuantValue::Q2S, 8, QuantStore::PackedU32(0))
}
#[test]
fn should_quantize_dequantize_symmetric_per_block_q2s_packed() {
should_quantize_dequantize_symmetric_per_block(QuantValue::Q2S, 16, QuantStore::PackedU32(0))
}
#[test]
fn should_quantize_dequantize_symmetric_arange_q8s_native() {
if supports_native() {
should_quantize_dequantize_symmetric_arange(QuantValue::Q8S, QuantStore::Native, [32, 32])
}
}
#[test]
fn should_quantize_dequantize_symmetric_per_block_q8s_native() {
if supports_native() {
should_quantize_dequantize_symmetric_per_block(QuantValue::Q8S, 8, QuantStore::Native)
}
}
#[test]
fn should_quantize_dequantize_symmetric_per_block_arange_q8s_packed() {
should_quantize_dequantize_symmetric_per_block_arange(
QuantValue::Q8S,
32,
QuantStore::PackedU32(0),
[32, 32],
)
}
#[test]
fn should_quantize_dequantize_symmetric_per_block_arange_q8s_native() {
if supports_native() {
should_quantize_dequantize_symmetric_per_block_arange(
QuantValue::Q8S,
32,
QuantStore::Native,
[32, 32],
)
}
}
#[test]
fn should_quantize_dequantize_symmetric_arange_128x256_q8s_native() {
if supports_native() {
should_quantize_dequantize_symmetric_per_block_arange(
QuantValue::Q8S,
32,
QuantStore::Native,
[128, 256],
)
}
}
#[test]
fn should_quantize_dequantize_symmetric_arange_128x256_q8s_packed() {
should_quantize_dequantize_symmetric_per_block_arange(
QuantValue::Q8S,
32,
QuantStore::PackedU32(0),
[128, 256],
)
}
#[test]
#[should_panic = "Can't store in u32"]
fn should_panic_when_shape_cannot_store_quants() {
let device = Default::default();
let scheme = QuantScheme::default();
let _tensor_1 =
Tensor::<TestBackend, 2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device)
.quantize_dynamic(&scheme);
}

Some files were not shown because too many files have changed in this diff Show More