feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,58 @@
use super::*;
use burn_tensor::{TensorData, Tolerance, cast::ToElement};
#[test]
fn should_diff_abs() {
let data_1 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_abs_no_nans() {
let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let contains_nan = grad_2.contains_nan();
assert!(!contains_nan.into_scalar().to_bool());
}

View File

@@ -0,0 +1,50 @@
use super::*;
use burn_tensor::module::adaptive_avg_pool1d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool1d_simple() {
let test = AdaptiveAvgPool1dTestCase {
batch_size: 1,
channels: 2,
length: 5,
output_size: 3,
};
test.assert_output(TestTensor::from_floats(
[[
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
]],
&Default::default(),
));
}
struct AdaptiveAvgPool1dTestCase {
batch_size: usize,
channels: usize,
length: usize,
output_size: usize,
}
impl AdaptiveAvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = adaptive_avg_pool1d(x.clone(), self.output_size);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad.to_data().assert_approx_eq::<FloatElem>(
&x_grad_actual.into_data(),
Tolerance::default().set_half_precision_relative(1e-3),
);
}
}

View File

@@ -0,0 +1,96 @@
use super::*;
use burn_tensor::module::adaptive_avg_pool2d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool2d_simple() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 2,
height: 5,
width: 3,
output_size_1: 3,
output_size_2: 2,
};
test.assert_output(TestTensor::from_floats(
[[
[
[0.2500, 0.5000, 0.2500],
[0.41667, 0.83333, 0.41667],
[0.16667, 0.33333, 0.16667],
[0.41667, 0.83333, 0.41667],
[0.2500, 0.5000, 0.2500],
],
[
[0.2500, 0.5000, 0.2500],
[0.41667, 0.83333, 0.41667],
[0.16667, 0.33333, 0.16667],
[0.41667, 0.83333, 0.41667],
[0.2500, 0.5000, 0.2500],
],
]],
&Default::default(),
));
}
#[test]
fn test_avg_pool2d_output_1() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 1,
height: 4,
width: 8,
output_size_1: 1,
output_size_2: 1,
};
test.assert_output(TestTensor::from_floats(
[[[
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
[
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
],
]]],
&Default::default(),
));
}
struct AdaptiveAvgPool2dTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
output_size_1: usize,
output_size_2: usize,
}
impl AdaptiveAvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad.to_data().assert_approx_eq::<FloatElem>(
&x_grad_actual.into_data(),
Tolerance::default().set_half_precision_relative(1e-3),
);
}
}

View File

@@ -0,0 +1,74 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_add() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad();
let tensor_3 = tensor_1.clone() + tensor_2.clone();
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
tensor_3
.to_data()
.assert_eq(&TensorData::from([6.0, 6.0]), false);
}
#[test]
fn should_diff_add_scalar() {
let data = TensorData::from([2.0, 10.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().add_scalar(5.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
tensor_out
.into_data()
.assert_eq(&TensorData::from([7.0, 15.0]), false);
}
#[test]
fn test_add_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().add(tensor_2.clone());
let tensor_5 = tensor_4
.add(tensor_3)
.add_scalar(5.0)
.add(tensor_1.clone())
.add(tensor_2.clone());
let tensor_6 = tensor_1.clone().add(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
}

View File

@@ -0,0 +1,138 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_mean() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[3.5, 9.5], [3.5, 9.5]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-0.75, -0.75], [3.0, 3.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sum_1() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[14.0, 38.0], [14.0, 38.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-3.0, -3.0], [12.0, 12.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sum_2() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().sum_dim(1);
let tensor_5 = tensor_4.mul(tensor_3);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[690.0, 690.0], [958.0, 958.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_mean_dim() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[4.0, 36.0], [3.0, -17.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[9.0, 9.0], [35.5, 35.5]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sum_dim() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[8.0, 72.0], [6.0, -34.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[18.0, 18.0], [71.0, 71.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,102 @@
use super::*;
use burn_tensor::module::avg_pool1d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool1d_simple() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 1,
kernel_size: 3,
padding: 0,
stride: 1,
length: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]],
&Default::default(),
));
}
#[test]
fn test_avg_pool1d_complex() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 2,
kernel_size: 3,
padding: 1,
stride: 2,
length: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
]],
&Default::default(),
));
}
#[test]
fn test_avg_pool1d_complex_dont_count_pad() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 2,
kernel_size: 3,
padding: 1,
stride: 2,
length: 6,
count_include_pad: false,
};
test.assert_output(TestTensor::from_floats(
[[
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
]],
&Default::default(),
));
}
struct AvgPool1dTestCase {
batch_size: usize,
channels: usize,
kernel_size: usize,
padding: usize,
stride: usize,
length: usize,
count_include_pad: bool,
}
impl AvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = avg_pool1d(
x.clone(),
self.kernel_size,
self.stride,
self.padding,
self.count_include_pad,
false,
);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
x_grad
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), tolerance);
}
}

View File

@@ -0,0 +1,129 @@
use super::*;
use burn_tensor::module::avg_pool2d;
use burn_tensor::{Shape, Tolerance};
#[test]
fn test_avg_pool2d_simple() {
let test = AvgPool2dTestCase {
batch_size: 1,
channels: 1,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
height: 6,
width: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[[
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
]]],
&Default::default(),
));
}
#[test]
fn test_avg_pool2d_complex() {
let test = AvgPool2dTestCase {
batch_size: 1,
channels: 1,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
height: 4,
width: 6,
count_include_pad: true,
};
test.assert_output(TestTensor::from_floats(
[[[
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
]]],
&Default::default(),
));
}
#[test]
fn test_avg_pool2d_complex_dont_include_pad() {
let test = AvgPool2dTestCase {
batch_size: 1,
channels: 1,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
height: 4,
width: 6,
count_include_pad: false,
};
test.assert_output(TestTensor::from_floats(
[[[
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
]]],
&Default::default(),
));
}
struct AvgPool2dTestCase {
batch_size: usize,
channels: usize,
kernel_size_1: usize,
kernel_size_2: usize,
padding_1: usize,
padding_2: usize,
stride_1: usize,
stride_2: usize,
height: usize,
width: usize,
count_include_pad: bool,
}
impl AvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = avg_pool2d(
x.clone(),
[self.kernel_size_1, self.kernel_size_2],
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
self.count_include_pad,
false,
);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad.to_data().assert_approx_eq::<FloatElem>(
&x_grad_actual.into_data(),
Tolerance::default().set_half_precision_relative(1e-3),
);
}
}

View File

@@ -0,0 +1,24 @@
use super::*;
use burn_tensor::{Int, Tensor, TensorData, module::embedding};
#[test]
fn test_embedding_backward() {
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indices = TensorData::from([[0, 1], [1, 1]]);
let x = TensorData::from([
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
]);
let device = Default::default();
let weights = Tensor::<TestAutodiffBackend, 2>::from_data(weights, &device).require_grad();
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(indices, &device);
let x = Tensor::<TestAutodiffBackend, 3>::from_data(x, &device).require_grad();
let output = embedding(weights.clone(), indices);
let output = output.matmul(x);
let grads = output.backward();
let grad = weights.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false);
}

View File

@@ -0,0 +1,27 @@
use super::*;
use burn_tensor::{DType, Distribution, Tensor};
#[test]
fn test_full_precision() {
let device = Default::default();
let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
.require_grad();
let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
.require_grad();
let dtype = x1.dtype();
let x3 = x1.clone().cast(DType::F32);
let x4 = x2.clone().cast(DType::F32);
let x5 = x3.matmul(x4);
let x6 = x5.cast(dtype);
let x7 = x6 * x1.clone() / x2.clone();
let grads = x7.backward();
let x1_grad = x1.grad(&grads);
let x2_grad = x2.grad(&grads);
assert!(x1_grad.is_some());
assert!(x2_grad.is_some());
}

View File

@@ -0,0 +1,56 @@
use super::*;
#[test]
fn mul_broadcast() {
test_ops_broadcast_backward(|x, y| x * y);
}
#[test]
fn div_broadcast() {
test_ops_broadcast_backward(|x, y| x / y);
}
#[test]
fn sub_broadcast() {
test_ops_broadcast_backward(|x, y| x - y);
}
#[test]
fn add_broadcast() {
test_ops_broadcast_backward(|x, y| x + y);
}
#[test]
fn matmul_broadcast() {
test_ops_broadcast_backward(|x, y| x.matmul(y));
}
#[test]
fn mask_where_broadcast() {
test_ops_broadcast_backward(|x, y| {
let cond = y.clone().equal_elem(4);
x.mask_where(cond, y)
});
}
fn test_ops_broadcast_backward<F>(func: F)
where
F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
{
let device = Default::default();
let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();
let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();
// Slice isn't a broadcastable operation, so it will fail when the previous backward pass
// of an operation that support broadcast doesn't support it during the backward pass.
let y = func(w.clone().slice([0..1]), x.clone());
// Will panic if broadcast isn't supported!
let grads = y.backward();
let w_grad = w.grad(&grads).unwrap();
let x_grad = x.grad(&grads).unwrap();
assert_eq!(w_grad.shape(), w.shape());
assert_eq!(x_grad.shape(), x.shape());
}

View File

@@ -0,0 +1,28 @@
// Skip on metal - F64 not supported
#![cfg(all(feature = "std", not(feature = "metal")))]
use super::*;
use burn_backend_tests::might_panic;
use burn_tensor::{DType, Tensor, TensorData};
#[might_panic(reason = "Unsupported precision for fusion")]
#[test]
fn cast_keeps_gradient_flow() {
let device = Default::default();
let x = Tensor::<TestAutodiffBackend, 2>::from_data(
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
&device,
)
.require_grad();
let y = x.clone().cast(DType::F64);
let z = y.sum();
let grads = z.backward();
let grad_x = x.grad(&grads).unwrap();
grad_x
.to_data()
.assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);
}

View File

@@ -0,0 +1,110 @@
use super::*;
use burn_tensor::Tolerance;
#[test]
fn should_diff_cat() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let mut tensor_1_list = Vec::new();
let mut tensor_2_list = Vec::new();
for i in 0..2 {
tensor_1_list.push(tensor_1.clone().slice([i..i + 1]));
tensor_2_list.push(tensor_2.clone().slice([i..i + 1]));
}
let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0);
let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0);
let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
let grads = tensor_3_cat.backward();
let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]);
let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]);
let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]);
let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]);
grad_1
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_1_slice_1.to_data(), Tolerance::default());
grad_1
.slice([1..2])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_1_slice_2.to_data(), Tolerance::default());
grad_2
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_2_slice_1.to_data(), Tolerance::default());
grad_2
.slice([1..2])
.to_data()
.assert_approx_eq::<FloatElem>(&grad_2_slice_2.to_data(), Tolerance::default());
}
#[test]
fn should_diff_cat_more_than_1_dim() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device)
.require_grad();
// Concat a tensor [2, 2] with another tensor [3, 2] along dim 0.
// The resulting tensor should be [5, 2]
let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0);
assert_eq!(tensor_3.dims(), [5, 2]);
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(tensor_1.dims(), grad_1.dims());
assert_eq!(tensor_2.dims(), grad_2.dims());
}
#[test]
fn should_slice_grads_correctly_when_some_inputs_not_tracked() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad(); // tracked
let tensor_2 = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device); // not tracked
let tensor_3 =
TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad(); // tracked
let cat = TestAutodiffTensor::cat(
vec![tensor_1.clone(), tensor_2.clone(), tensor_3.clone()],
1,
);
// Make gradient per column unique so wrong slicing shows up.
let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device);
let loss = (cat * weights).sum();
let grads = loss.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&burn_tensor::TensorData::from([[1.0]]), false);
grad_3
.to_data()
.assert_eq(&burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]), false);
}

View File

@@ -0,0 +1,21 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_ceil() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = tensor_1.clone().ceil();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}

View File

@@ -0,0 +1,215 @@
use super::*;
use burn_tensor::{Bool, Tensor, TensorData};
#[test]
fn test_autodiff_checkpoint_complicated_computation() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = compute_bound_eager(tensor_0, tensor_1);
let tensor_6 = compute_bound_lazy(tensor_2, tensor_3.clone());
let tensor_7 = memory_bound_eager(tensor_3, tensor_4);
let tensor_8 = compute_bound_lazy(tensor_6, tensor_7.clone());
let tensor_9 = memory_bound_eager_scalar(tensor_7, 11.);
let tensor_10 = memory_bound_lazy(tensor_5, tensor_8.clone());
let tensor_11 = memory_bound_lazy(tensor_8, tensor_9);
let tensor_12 = compute_bound_lazy(tensor_10, tensor_11);
assert_checkpoint(tensor_12);
}
#[test]
fn test_autodiff_checkpoint_with_missing_requirement() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); // does not require_grad
let tensor_2 = memory_bound_eager(tensor_0, tensor_1);
let tensor_3 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
let tensor_4 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
let tensor_5 = compute_bound_lazy(tensor_3, tensor_4);
let tensor_6 = compute_bound_eager_scalar(tensor_5.clone(), 11.);
let tensor_7 = memory_bound_eager(tensor_5, tensor_2);
let tensor_8 = memory_bound_eager(tensor_6, tensor_7);
assert_checkpoint(tensor_8);
}
#[test]
fn test_autodiff_checkpoint_with_many_duplicates() {
let data_0 = TensorData::from([[4.0, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = memory_bound_eager(tensor_0.clone(), tensor_0.clone());
let tensor_2 = compute_bound_eager(tensor_0.clone(), tensor_0.clone());
let tensor_3 = memory_bound_lazy(tensor_0.clone(), tensor_0.clone());
let tensor_4 = compute_bound_lazy(tensor_0.clone(), tensor_0.clone());
let tensor_5 = memory_bound_eager(tensor_1.clone(), tensor_0.clone());
let tensor_6 = memory_bound_eager(tensor_0.clone(), tensor_5.clone());
let tensor_7 = compute_bound_lazy(tensor_3.clone(), tensor_5.clone());
let tensor_8 = compute_bound_eager(tensor_4.clone(), tensor_2.clone());
let tensor_9 = memory_bound_lazy(tensor_6, tensor_7);
let tensor_10 = memory_bound_eager(tensor_0, tensor_9);
let tensor_11 = memory_bound_eager_scalar(tensor_10, 9.);
let tensor_12 = compute_bound_lazy(tensor_8, tensor_11);
assert_checkpoint(tensor_12);
}
#[test]
fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = memory_bound_eager(tensor_0, tensor_1.clone());
let tensor_6 = memory_bound_eager(tensor_5, tensor_2);
let tensor_7 = memory_bound_eager(tensor_6, tensor_3);
let tensor_8 = memory_bound_eager(tensor_7, tensor_4);
let tensor_9 = memory_bound_lazy(tensor_8, tensor_1);
assert_checkpoint(tensor_9)
}
#[test]
fn test_autodiff_checkpoint_half_sub_graph_not_tracked() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let data_5 = TensorData::from([[0.5, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device);
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device);
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = TestAutodiffTensor::from_data(data_5, &device).require_grad();
let tensor_6 = memory_bound_lazy(tensor_0, tensor_1);
let tensor_7 = compute_bound_eager(tensor_6, tensor_2);
let tensor_8 = memory_bound_eager(tensor_3, tensor_4);
let tensor_9 = compute_bound_lazy(tensor_8, tensor_5);
let tensor_10 = compute_bound_lazy(tensor_7, tensor_9);
assert_checkpoint(tensor_10);
}
#[test]
fn test_autodiff_checkpoint_very_complex() {
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
let device = Default::default();
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
let tensor_5 = memory_bound_eager_scalar(tensor_0, 8.);
let tensor_6 = memory_bound_lazy(tensor_5.clone(), tensor_1.clone());
let tensor_7 = compute_bound_lazy(tensor_6.clone(), tensor_6);
let tensor_8 = memory_bound_lazy(tensor_1.clone(), tensor_5.clone());
let tensor_9 = memory_bound_eager_scalar(tensor_7.clone(), 7.);
let tensor_10 = compute_bound_eager(tensor_5, tensor_8);
let tensor_11 = memory_bound_eager(tensor_2.clone(), tensor_9);
let tensor_12 = memory_bound_lazy(tensor_2.clone(), tensor_2);
let tensor_13 = compute_bound_eager(tensor_10.clone(), tensor_11);
let tensor_14 = compute_bound_eager_scalar(tensor_3, 8.);
let tensor_15 = compute_bound_lazy(tensor_4, tensor_12);
let tensor_16 = memory_bound_lazy(tensor_10, tensor_7);
let tensor_17 = compute_bound_lazy(tensor_13, tensor_1);
let tensor_18 = memory_bound_eager(tensor_15, tensor_16);
let tensor_19 = compute_bound_eager(tensor_14, tensor_17);
let tensor_20 = memory_bound_lazy(tensor_18, tensor_19);
let tensor_21 = memory_bound_eager_scalar(tensor_20, 8.);
assert_checkpoint(tensor_21)
}
fn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {
// Assert is not explicit here, but the test can fail
// - when a tensor is actually required more than n_required, it won't be found and will panic
// - when a tensor is actually required less than n_required, the backward states map won't be
// empty and will fail the assertion within the backward code, same for retro_forwards
tensor.backward();
}
// Does not save its state and does not need its parents
fn memory_bound_eager<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.add(tensor_b)
}
fn memory_bound_eager_scalar<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
b: f32,
) -> TestAutodiffTensor<D> {
tensor_a.add_scalar(b)
}
// Saves its own state and does not need its parents
fn compute_bound_eager<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
tensor_a.mask_where(mask, tensor_b)
}
fn compute_bound_eager_scalar<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
b: f32,
) -> TestAutodiffTensor<D> {
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
tensor_a.mask_fill(mask, b)
}
// Does not save its state and needs its parents
fn memory_bound_lazy<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.mul(tensor_b)
}
// Saves its own state and needs its parents
fn compute_bound_lazy<const D: usize>(
tensor_a: TestAutodiffTensor<D>,
tensor_b: TestAutodiffTensor<D>,
) -> TestAutodiffTensor<D> {
tensor_a.matmul(tensor_b)
}

View File

@@ -0,0 +1,81 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_full_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
let tensor_5 = tensor_4.mul(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[593., 463.0], [487.0, 539.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[734.0, 294.0], [1414.0, 242.0]]), false);
}
#[test]
fn should_diff_full_complex_2() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[166.0, 110.0], [212.0, 156.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[113.0, 141.0], [33.0, 41.0]]), false);
}
#[test]
fn should_diff_full_complex_3() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
let tensor_5 = tensor_4.clone().sub(tensor_2.clone());
let tensor_6 = tensor_5.add(tensor_4);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[332.0, 220.0], [424.0, 312.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[223.0, 279.0], [63.0, 79.0]]), false);
}

View File

@@ -0,0 +1,277 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};
#[test]
fn test_conv1d_basic() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
],
&device,
),
bias: TestTensor::from_floats([8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_different_channels() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 3,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
],
&device,
),
bias: TestTensor::from_floats([8., 8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_with_padding() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 2,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [76., 76., 76.]],
[[44., 44., 44.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([12., 12.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_with_stride() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 2,
dilation: 1,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[10., 20., 24.], [18., 36., 40.]],
[[10., 20., 24.], [18., 36., 40.]],
],
&device,
),
bias: TestTensor::from_floats([4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_dilation() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 2,
groups: 1,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[8., 22., 14.], [16., 38., 22.]],
[[8., 22., 14.], [16., 38., 22.]],
],
&device,
),
bias: TestTensor::from_floats([4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_groups() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 2,
length: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
],
&device,
),
weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device),
bias: TestTensor::from_floats([8., 8.], &device),
};
test.assert_grads(grads);
}
struct Conv1dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size: usize,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
length: usize,
}
struct Grads {
x: TestTensor<3>,
weight: TestTensor<3>,
bias: TestTensor<1>,
}
impl Conv1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<3, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,962 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions};
#[test]
fn test_conv2d_basic() {
let test = Conv2dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[88., 138., 138., 96.],
[150., 234., 234., 162.],
[150., 234., 234., 162.],
[112., 174., 174., 120.],
],
[
[160., 246., 246., 168.],
[258., 396., 396., 270.],
[258., 396., 396., 270.],
[184., 282., 282., 192.],
],
],
[
[
[88., 138., 138., 96.],
[150., 234., 234., 162.],
[150., 234., 234., 162.],
[112., 174., 174., 120.],
],
[
[160., 246., 246., 168.],
[258., 396., 396., 270.],
[258., 396., 396., 270.],
[184., 282., 282., 192.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
],
&device,
),
bias: TestTensor::from_floats([32., 32.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_channels() {
let test = Conv2dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[240., 369., 369., 252.],
[387., 594., 594., 405.],
[387., 594., 594., 405.],
[276., 423., 423., 288.],
],
[
[348., 531., 531., 360.],
[549., 837., 837., 567.],
[549., 837., 837., 567.],
[384., 585., 585., 396.],
],
],
[
[
[240., 369., 369., 252.],
[387., 594., 594., 405.],
[387., 594., 594., 405.],
[276., 423., 423., 288.],
],
[
[348., 531., 531., 360.],
[549., 837., 837., 567.],
[549., 837., 837., 567.],
[384., 585., 585., 396.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
],
&device,
),
bias: TestTensor::from_floats([32., 32., 32.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_kernel_size() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[116., 180., 192., 132.],
[198., 306., 324., 222.],
[198., 306., 324., 222.],
[148., 228., 240., 164.],
],
[
[212., 324., 336., 228.],
[342., 522., 540., 366.],
[342., 522., 540., 366.],
[244., 372., 384., 260.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[27., 45., 54., 39.],
[52., 84., 96., 68.],
[51., 81., 90., 63.],
],
[
[123., 189., 198., 135.],
[180., 276., 288., 196.],
[147., 225., 234., 159.],
],
],
[
[
[27., 45., 54., 39.],
[52., 84., 96., 68.],
[51., 81., 90., 63.],
],
[
[123., 189., 198., 135.],
[180., 276., 288., 196.],
[147., 225., 234., 159.],
],
],
],
&device,
),
bias: TestTensor::from_floats([12., 12.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_padding() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[138., 138., 138., 138.],
[234., 234., 234., 234.],
[234., 234., 234., 234.],
[174., 174., 174., 174.],
],
[
[246., 246., 246., 246.],
[396., 396., 396., 396.],
[396., 396., 396., 396.],
[282., 282., 282., 282.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
],
[
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
],
],
&device,
),
bias: TestTensor::from_floats([24., 24.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_width() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 5,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[88., 138., 138., 138., 96.],
[150., 234., 234., 234., 162.],
[150., 234., 234., 234., 162.],
[112., 174., 174., 174., 120.],
],
[
[160., 246., 246., 246., 168.],
[258., 396., 396., 396., 270.],
[258., 396., 396., 396., 270.],
[184., 282., 282., 282., 192.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
],
[
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
],
],
&device,
),
bias: TestTensor::from_floats([20., 20.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_stride_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 6,
width: 6,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[26., 52., 26., 52., 26., 28.],
[52., 104., 52., 104., 52., 56.],
[26., 52., 26., 52., 26., 28.],
[52., 104., 52., 104., 52., 56.],
[26., 52., 26., 52., 26., 28.],
[32., 64., 32., 64., 32., 34.],
],
[
[44., 88., 44., 88., 44., 46.],
[88., 176., 88., 176., 88., 92.],
[44., 88., 44., 88., 44., 46.],
[88., 176., 88., 176., 88., 92.],
[44., 88., 44., 88., 44., 46.],
[50., 100., 50., 100., 50., 52.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
],
[
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
],
],
&device,
),
bias: TestTensor::from_floats([9., 9.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_stride() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 3,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 8,
width: 8,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
[38., 60., 60., 60., 60., 60., 60., 42.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
[38., 60., 60., 60., 60., 60., 60., 42.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
],
[
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
[74., 114., 114., 114., 114., 114., 114., 78.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
[74., 114., 114., 114., 114., 114., 114., 78.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
[
[1330., 1528., 1344.],
[1911., 2196., 1932.],
[2079., 2388., 2100.],
],
],
[
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
[
[1330., 1528., 1344.],
[1911., 2196., 1932.],
[2079., 2388., 2100.],
],
],
],
&device,
),
bias: TestTensor::from_floats([24., 24.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_dilation_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 2,
groups: 1,
height: 6,
width: 6,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[18., 38., 38., 42., 42., 22.],
[42., 88., 88., 96., 96., 50.],
[42., 88., 88., 96., 96., 50.],
[54., 112., 112., 120., 120., 62.],
[54., 112., 112., 120., 120., 62.],
[30., 62., 62., 66., 66., 34.],
],
[
[36., 74., 74., 78., 78., 40.],
[78., 160., 160., 168., 168., 86.],
[78., 160., 160., 168., 168., 86.],
[90., 184., 184., 192., 192., 98.],
[90., 184., 184., 192., 192., 98.],
[48., 98., 98., 102., 102., 52.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
],
[
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
],
],
&device,
),
bias: TestTensor::from_floats([16., 16.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_dilation() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 3,
groups: 1,
height: 6,
width: 6,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[18., 0., 20., 20., 0., 22.],
[42., 0., 46., 46., 0., 50.],
[42., 0., 46., 46., 0., 50.],
[54., 0., 58., 58., 0., 62.],
[54., 0., 58., 58., 0., 62.],
[30., 0., 32., 32., 0., 34.],
],
[
[36., 0., 38., 38., 0., 40.],
[78., 0., 82., 82., 0., 86.],
[78., 0., 82., 82., 0., 86.],
[90., 0., 94., 94., 0., 98.],
[90., 0., 94., 94., 0., 98.],
[48., 0., 50., 50., 0., 52.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
],
[
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
],
],
&device,
),
bias: TestTensor::from_floats([8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 5,
width: 5,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[0., 1., 3., 3., 2.],
[3., 8., 15., 12., 7.],
[9., 21., 36., 27., 15.],
[9., 20., 33., 24., 13.],
[6., 13., 21., 15., 8.],
],
[
[9., 19., 30., 21., 11.],
[21., 44., 69., 48., 25.],
[36., 75., 117., 81., 42.],
[27., 56., 87., 60., 31.],
[15., 31., 48., 33., 17.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],
[[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],
],
&device,
),
bias: TestTensor::from_floats([9., 9.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_stride_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 4,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[4., 8., 4., 5.],
[8., 16., 8., 10.],
[4., 8., 4., 5.],
[7., 14., 7., 8.],
],
[
[13., 26., 13., 14.],
[26., 52., 26., 28.],
[13., 26., 13., 14.],
[16., 32., 16., 17.],
],
[
[22., 44., 22., 23.],
[44., 88., 44., 46.],
[22., 44., 22., 23.],
[25., 50., 25., 26.],
],
[
[31., 62., 31., 32.],
[62., 124., 62., 64.],
[31., 62., 31., 32.],
[34., 68., 34., 35.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
[[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
[[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
[[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
],
&device,
),
bias: TestTensor::from_floats([4., 4., 4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_different_channels() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 3,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 3,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[9., 20., 24., 13.],
[24., 52., 60., 32.],
[36., 76., 84., 44.],
[21., 44., 48., 25.],
],
[
[45., 92., 96., 49.],
[96., 196., 204., 104.],
[108., 220., 228., 116.],
[57., 116., 120., 61.],
],
[
[81., 164., 168., 85.],
[168., 340., 348., 176.],
[180., 364., 372., 188.],
[93., 188., 192., 97.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
],
&device,
),
bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_complex() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 2,
kernel_size_2: 3,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
dilation_1: 2,
dilation_2: 3,
groups: 1,
height: 4,
width: 5,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[36., 39., 0., 39., 42.],
[81., 87., 0., 87., 93.],
[81., 87., 0., 87., 93.],
[45., 48., 0., 48., 51.],
],
[
[54., 57., 0., 57., 60.],
[117., 123., 0., 123., 129.],
[117., 123., 0., 123., 129.],
[63., 66., 0., 66., 69.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
],
&device,
),
bias: TestTensor::from_floats([8., 8., 8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_stride_2_no_pad() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[0., 1., 2., 0.],
[3., 4., 5., 0.],
[6., 7., 8., 0.],
[0., 0., 0., 0.],
],
[
[9., 10., 11., 0.],
[12., 13., 14., 0.],
[15., 16., 17., 0.],
[0., 0., 0., 0.],
],
[
[18., 19., 20., 0.],
[21., 22., 23., 0.],
[24., 25., 26., 0.],
[0., 0., 0., 0.],
],
[
[27., 28., 29., 0.],
[30., 31., 32., 0.],
[33., 34., 35., 0.],
[0., 0., 0., 0.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
],
[
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
[[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]],
],
],
&device,
),
bias: TestTensor::from_floats([1., 1.], &device),
};
test.assert_grads(grads);
}
struct Conv2dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size_1: usize,
kernel_size_2: usize,
padding_1: usize,
padding_2: usize,
stride_1: usize,
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}
struct Grads {
x: TestTensor<4>,
weight: TestTensor<4>,
bias: TestTensor<1>,
}
impl Conv2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::rel_abs(0.01, 0.01);
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,690 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
#[test]
fn test_conv3d_basic() {
let test = Conv3dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
kernel_size_3: 3,
padding_1: 1,
padding_2: 1,
padding_3: 1,
stride_1: 1,
stride_2: 1,
stride_3: 1,
dilation_1: 1,
dilation_2: 1,
dilation_3: 1,
groups: 1,
depth: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[
[536., 816., 816., 552.],
[840., 1278., 1278., 864.],
[840., 1278., 1278., 864.],
[584., 888., 888., 600.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[680., 1032., 1032., 696.],
[1056., 1602., 1602., 1080.],
[1056., 1602., 1602., 1080.],
[728., 1104., 1104., 744.],
],
],
[
[
[968., 1464., 1464., 984.],
[1488., 2250., 2250., 1512.],
[1488., 2250., 2250., 1512.],
[1016., 1536., 1536., 1032.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1112., 1680., 1680., 1128.],
[1704., 2574., 2574., 1728.],
[1704., 2574., 2574., 1728.],
[1160., 1752., 1752., 1176.],
],
],
],
[
[
[
[536., 816., 816., 552.],
[840., 1278., 1278., 864.],
[840., 1278., 1278., 864.],
[584., 888., 888., 600.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[680., 1032., 1032., 696.],
[1056., 1602., 1602., 1080.],
[1056., 1602., 1602., 1080.],
[728., 1104., 1104., 744.],
],
],
[
[
[968., 1464., 1464., 984.],
[1488., 2250., 2250., 1512.],
[1488., 2250., 2250., 1512.],
[1016., 1536., 1536., 1032.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1112., 1680., 1680., 1128.],
[1704., 2574., 2574., 1728.],
[1704., 2574., 2574., 1728.],
[1160., 1752., 1752., 1176.],
],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[4590., 6156., 4644.],
[6264., 8400., 6336.],
[4806., 6444., 4860.],
],
[
[6696., 8976., 6768.],
[9120., 12224., 9216.],
[6984., 9360., 7056.],
],
[
[5454., 7308., 5508.],
[7416., 9936., 7488.],
[5670., 7596., 5724.],
],
],
[
[
[8046., 10764., 8100.],
[10872., 14544., 10944.],
[8262., 11052., 8316.],
],
[
[11304., 15120., 11376.],
[15264., 20416., 15360.],
[11592., 15504., 11664.],
],
[
[8910., 11916., 8964.],
[12024., 16080., 12096.],
[9126., 12204., 9180.],
],
],
],
[
[
[
[4590., 6156., 4644.],
[6264., 8400., 6336.],
[4806., 6444., 4860.],
],
[
[6696., 8976., 6768.],
[9120., 12224., 9216.],
[6984., 9360., 7056.],
],
[
[5454., 7308., 5508.],
[7416., 9936., 7488.],
[5670., 7596., 5724.],
],
],
[
[
[8046., 10764., 8100.],
[10872., 14544., 10944.],
[8262., 11052., 8316.],
],
[
[11304., 15120., 11376.],
[15264., 20416., 15360.],
[11592., 15504., 11664.],
],
[
[8910., 11916., 8964.],
[12024., 16080., 12096.],
[9126., 12204., 9180.],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([128., 128.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv3d_complex() {
let test = Conv3dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 2,
kernel_size_2: 3,
kernel_size_3: 4,
padding_1: 1,
padding_2: 2,
padding_3: 3,
stride_1: 1,
stride_2: 2,
stride_3: 3,
dilation_1: 2,
dilation_2: 3,
dilation_3: 4,
groups: 1,
depth: 5,
height: 6,
width: 7,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[0., 147., 0., 0., 0., 150., 0.],
[0., 159., 0., 0., 0., 162., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 159., 0., 0., 0., 162., 0.],
[0., 171., 0., 0., 0., 174., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 183., 0., 0., 0., 186., 0.],
[0., 195., 0., 0., 0., 198., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 195., 0., 0., 0., 198., 0.],
[0., 207., 0., 0., 0., 210., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
],
[
[
[0., 219., 0., 0., 0., 222., 0.],
[0., 231., 0., 0., 0., 234., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 231., 0., 0., 0., 234., 0.],
[0., 243., 0., 0., 0., 246., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 255., 0., 0., 0., 258., 0.],
[0., 267., 0., 0., 0., 270., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 267., 0., 0., 0., 270., 0.],
[0., 279., 0., 0., 0., 282., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([10., 10., 10.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv3d_groups_stride_2_no_pad() {
let test = Conv3dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
kernel_size_3: 3,
padding_1: 0,
padding_2: 0,
padding_3: 0,
stride_1: 2,
stride_2: 2,
stride_3: 2,
dilation_1: 1,
dilation_2: 1,
dilation_3: 1,
groups: 2,
depth: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[0., 1., 2., 0.],
[3., 4., 5., 0.],
[6., 7., 8., 0.],
[0., 0., 0., 0.],
],
[
[9., 10., 11., 0.],
[12., 13., 14., 0.],
[15., 16., 17., 0.],
[0., 0., 0., 0.],
],
[
[18., 19., 20., 0.],
[21., 22., 23., 0.],
[24., 25., 26., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[27., 28., 29., 0.],
[30., 31., 32., 0.],
[33., 34., 35., 0.],
[0., 0., 0., 0.],
],
[
[36., 37., 38., 0.],
[39., 40., 41., 0.],
[42., 43., 44., 0.],
[0., 0., 0., 0.],
],
[
[45., 46., 47., 0.],
[48., 49., 50., 0.],
[51., 52., 53., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[54., 55., 56., 0.],
[57., 58., 59., 0.],
[60., 61., 62., 0.],
[0., 0., 0., 0.],
],
[
[63., 64., 65., 0.],
[66., 67., 68., 0.],
[69., 70., 71., 0.],
[0., 0., 0., 0.],
],
[
[72., 73., 74., 0.],
[75., 76., 77., 0.],
[78., 79., 80., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[81., 82., 83., 0.],
[84., 85., 86., 0.],
[87., 88., 89., 0.],
[0., 0., 0., 0.],
],
[
[90., 91., 92., 0.],
[93., 94., 95., 0.],
[96., 97., 98., 0.],
[0., 0., 0., 0.],
],
[
[99., 100., 101., 0.],
[102., 103., 104., 0.],
[105., 106., 107., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
],
[
[[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
[[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
[[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
],
],
[
[
[[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
[[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
[[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
],
[
[[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
[[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
[[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
],
],
],
&device,
),
bias: TestTensor::from_floats([1., 1.], &device),
};
test.assert_grads(grads);
}
struct Conv3dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size_1: usize,
kernel_size_2: usize,
kernel_size_3: usize,
padding_1: usize,
padding_2: usize,
padding_3: usize,
stride_1: usize,
stride_2: usize,
stride_3: usize,
dilation_1: usize,
dilation_2: usize,
dilation_3: usize,
groups: usize,
depth: usize,
height: usize,
width: usize,
}
struct Grads {
x: TestTensor<5>,
weight: TestTensor<5>,
bias: TestTensor<1>,
}
impl Conv3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels_in,
self.depth,
self.height,
self.width,
]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
self.kernel_size_3,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2, self.stride_3],
[self.padding_1, self.padding_2, self.padding_3],
[self.dilation_1, self.dilation_2, self.dilation_3],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,292 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose1d_basic() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 0,
stride: 1,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],
[[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]],
[[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]],
],
&device,
),
bias: TestTensor::from_floats([12., 12.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_padding() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 2,
padding_out: 0,
stride: 1,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[7., 12., 8., 3.], [19., 36., 32., 15.]],
[[7., 12., 8., 3.], [19., 36., 32., 15.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[26., 22., 18.], [26., 22., 18.]],
[[42., 38., 34.], [42., 38., 34.]],
],
&device,
),
bias: TestTensor::from_floats([4., 4.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_stride() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 0,
stride: 2,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [44., 44., 44.]],
[[76., 76., 76.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([18., 18.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_stride_padding_out() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 1,
stride: 2,
dilation: 1,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [44., 44., 44.]],
[[76., 76., 76.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([20., 20.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_dilation() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: 3,
padding: 0,
padding_out: 0,
stride: 1,
dilation: 2,
groups: 1,
size: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
],
&device,
),
weight: TestTensor::from_floats(
[
[[44., 44., 44.], [44., 44., 44.]],
[[76., 76., 76.], [76., 76., 76.]],
],
&device,
),
bias: TestTensor::from_floats([16., 16.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose1d_complex() {
let test = ConvTranspose1dTestCase {
batch_size: 2,
channels: [2, 4],
kernel_size: 3,
padding: 1,
padding_out: 1,
stride: 2,
dilation: 2,
groups: 2,
size: 8,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
[36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
],
[
[12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
[36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]],
[[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]],
],
&device,
),
bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0], &device),
};
test.assert_grads(grads);
}
struct ConvTranspose1dTestCase {
batch_size: usize,
channels: [usize; 2],
kernel_size: usize,
padding: usize,
padding_out: usize,
stride: usize,
dilation: usize,
groups: usize,
size: usize,
}
struct Grads {
x: TestTensor<3>,
weight: TestTensor<3>,
bias: TestTensor<1>,
}
impl ConvTranspose1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<3, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<3, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv_transpose1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
[self.stride],
[self.padding],
[self.padding_out],
[self.dilation],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), Tolerance::default());
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,706 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose2d_basic() {
let test = ConvTranspose2dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
],
[
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
],
],
[
[
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
[153., 153., 153., 153.],
],
[
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
[477., 477., 477., 477.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
[[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
],
[
[
[1264., 1264., 1264.],
[1264., 1264., 1264.],
[1264., 1264., 1264.],
],
[
[1264., 1264., 1264.],
[1264., 1264., 1264.],
[1264., 1264., 1264.],
],
],
],
&device,
),
bias: TestTensor::from_floats([72., 72.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_padding() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [1, 2],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[13., 24., 20., 9.],
[15., 27., 21., 9.],
[15., 27., 21., 9.],
[7., 12., 8., 3.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]],
&device,
),
bias: TestTensor::from_floats([8.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_stride() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [2, 3],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
&device,
),
bias: TestTensor::from_floats([108.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_stride_padding_out() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [1, 2],
stride: [2, 3],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
&device,
),
bias: TestTensor::from_floats([140.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_dilation() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [2, 3],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
&device,
),
bias: TestTensor::from_floats([80.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_channels() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [2, 3],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[351., 351., 351., 351.],
[351., 351., 351., 351.],
[351., 351., 351., 351.],
[351., 351., 351., 351.],
],
[
[1080., 1080., 1080., 1080.],
[1080., 1080., 1080., 1080.],
[1080., 1080., 1080., 1080.],
[1080., 1080., 1080., 1080.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
],
[
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
],
],
&device,
),
bias: TestTensor::from_floats([36., 36., 36.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_kernel_size() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [1, 1],
kernel_size: [3, 5],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 1,
size: [6, 6],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[[
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
[105., 105., 105., 105., 105., 105.],
]]],
&device,
),
weight: TestTensor::from_floats(
[[[
[630., 630., 630., 630., 630.],
[630., 630., 630., 630., 630.],
[630., 630., 630., 630., 630.],
]]],
&device,
),
bias: TestTensor::from_floats([80.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_groups() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [2, 2],
kernel_size: [3, 3],
padding: [0, 0],
padding_out: [0, 0],
stride: [1, 1],
dilation: [1, 1],
groups: 2,
size: [4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
[36., 36., 36., 36.],
],
[
[117., 117., 117., 117.],
[117., 117., 117., 117.],
[117., 117., 117., 117.],
[117., 117., 117., 117.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]],
[[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]],
],
&device,
),
bias: TestTensor::from_floats([36., 36.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_complex_no_groups() {
let test = ConvTranspose2dTestCase {
batch_size: 2,
channels: [2, 3],
kernel_size: [3, 5],
padding: [1, 2],
padding_out: [1, 2],
stride: [2, 3],
dilation: [2, 3],
groups: 1,
size: [6, 8],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[600., 735., 735., 735., 735., 735., 735., 735.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
],
[
[1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
],
],
[
[
[600., 735., 735., 735., 735., 735., 735., 735.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
[810., 990., 990., 990., 990., 990., 990., 990.],
],
[
[1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[5320., 6040., 6040., 6040., 6040.],
[6048., 6864., 6864., 6864., 6864.],
[6048., 6864., 6864., 6864., 6864.],
],
[
[5320., 6040., 6040., 6040., 6040.],
[6048., 6864., 6864., 6864., 6864.],
[6048., 6864., 6864., 6864., 6864.],
],
[
[5320., 6040., 6040., 6040., 6040.],
[6048., 6864., 6864., 6864., 6864.],
[6048., 6864., 6864., 6864., 6864.],
],
],
[
[
[8680., 9880., 9880., 9880., 9880.],
[10080., 11472., 11472., 11472., 11472.],
[10080., 11472., 11472., 11472., 11472.],
],
[
[8680., 9880., 9880., 9880., 9880.],
[10080., 11472., 11472., 11472., 11472.],
[10080., 11472., 11472., 11472., 11472.],
],
[
[8680., 9880., 9880., 9880., 9880.],
[10080., 11472., 11472., 11472., 11472.],
[10080., 11472., 11472., 11472., 11472.],
],
],
],
&device,
),
bias: TestTensor::from_floats([896., 896., 896.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_complex_no_groups_2() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [4, 2],
kernel_size: [2, 3],
padding: [1, 2],
padding_out: [1, 2],
stride: [2, 3],
dilation: [1, 2],
groups: 1,
size: [10, 10],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[30., 42., 42., 42., 42., 42., 42., 42., 42., 42.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
],
[
[78., 114., 114., 114., 114., 114., 114., 114., 114., 114.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
],
[
[126., 186., 186., 186., 186., 186., 186., 186., 186., 186.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
],
[
[174., 258., 258., 258., 258., 258., 258., 258., 258., 258.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[[4455., 4905., 4905.], [4500., 4950., 4950.]],
[[4455., 4905., 4905.], [4500., 4950., 4950.]],
],
[
[[12555., 13905., 13905.], [13500., 14950., 14950.]],
[[12555., 13905., 13905.], [13500., 14950., 14950.]],
],
[
[[20655., 22905., 22905.], [22500., 24950., 24950.]],
[[20655., 22905., 22905.], [22500., 24950., 24950.]],
],
[
[[28755., 31905., 31905.], [31500., 34950., 34950.]],
[[28755., 31905., 31905.], [31500., 34950., 34950.]],
],
],
&device,
),
bias: TestTensor::from_floats([570., 570.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose2d_complex_groups() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels: [4, 2],
kernel_size: [2, 3],
padding: [1, 2],
padding_out: [1, 2],
stride: [2, 3],
dilation: [1, 2],
groups: 2,
size: [10, 10],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[9., 12., 12., 12., 12., 12., 12., 12., 12., 12.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
],
[
[21., 30., 30., 30., 30., 30., 30., 30., 30., 30.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
],
[
[33., 48., 48., 48., 48., 48., 48., 48., 48., 48.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
],
[
[45., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[[4455., 4905., 4905.], [4500., 4950., 4950.]]],
[[[12555., 13905., 13905.], [13500., 14950., 14950.]]],
[[[20655., 22905., 22905.], [22500., 24950., 24950.]]],
[[[28755., 31905., 31905.], [31500., 34950., 34950.]]],
],
&device,
),
bias: TestTensor::from_floats([570., 570.], &device),
};
test.assert_grads(grads);
}
struct ConvTranspose2dTestCase {
batch_size: usize,
channels: [usize; 2],
kernel_size: [usize; 2],
padding: [usize; 2],
padding_out: [usize; 2],
stride: [usize; 2],
dilation: [usize; 2],
groups: usize,
size: [usize; 2],
}
struct Grads {
x: TestTensor<4>,
weight: TestTensor<4>,
bias: TestTensor<1>,
}
impl ConvTranspose2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels[0],
self.size[0],
self.size[1],
]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<4, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv_transpose2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,711 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};
#[test]
fn test_conv_transpose3d_basic() {
let test = ConvTranspose3dTestCase {
batch_size: 2,
channels: [2, 2],
kernel_size: [3, 3, 3],
padding: [0, 0, 0],
padding_out: [0, 0, 0],
stride: [1, 1, 1],
dilation: [1, 1, 1],
groups: 1,
size: [4, 4, 4],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
],
[
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
],
],
[
[
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
[
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
[13.250001, 13.250001, 13.250001, 13.250001],
],
],
[
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
[
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
[40.249992, 40.249992, 40.249992, 40.249992],
],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
],
[
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
[
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
[47.750000, 47.750000, 47.750000],
],
],
],
[
[
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
],
[
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
[
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
[79.750000, 79.750000, 79.750000],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([432., 432.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv_transpose3d_complex_groups() {
let test = ConvTranspose3dTestCase {
batch_size: 1,
channels: [4, 2],
kernel_size: [2, 3, 4],
padding: [1, 2, 3],
padding_out: [1, 2, 3],
stride: [2, 3, 4],
dilation: [1, 2, 3],
groups: 2,
size: [6, 6, 6],
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
[
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
],
],
[
[
[2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
[
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
],
],
[
[
[4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
[
[
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
[
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
],
],
],
[
[
[5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
[
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
[
[
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
[
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[[
[
[18.663193, 22.309027, 22.309027, 22.309027],
[21.875000, 26.145834, 26.145834, 26.145834],
[21.875000, 26.145834, 26.145834, 26.145834],
],
[
[19.270832, 23.020834, 23.020834, 23.020834],
[22.500000, 26.875002, 26.875002, 26.875002],
[22.500000, 26.875002, 26.875002, 26.875002],
],
]],
[[
[
[49.913193, 59.809029, 59.809029, 59.809029],
[59.375000, 71.145836, 71.145836, 71.145836],
[59.375000, 71.145836, 71.145836, 71.145836],
],
[
[56.770836, 68.020836, 68.020836, 68.020836],
[67.500000, 80.875000, 80.875000, 80.875000],
[67.500000, 80.875000, 80.875000, 80.875000],
],
]],
[[
[
[81.163193, 97.309029, 97.309029, 97.309029],
[96.875000, 116.145828, 116.145828, 116.145828],
[96.875000, 116.145828, 116.145828, 116.145828],
],
[
[94.270828, 113.020828, 113.020828, 113.020828],
[112.500000, 134.875000, 134.875000, 134.875000],
[112.500000, 134.875000, 134.875000, 134.875000],
],
]],
[[
[
[112.413200, 134.809021, 134.809021, 134.809021],
[134.375000, 161.145828, 161.145828, 161.145828],
[134.375000, 161.145828, 161.145828, 161.145828],
],
[
[131.770844, 158.020828, 158.020828, 158.020828],
[157.500000, 188.875000, 188.875000, 188.875000],
[157.500000, 188.875000, 188.875000, 188.875000],
],
]],
],
&device,
),
bias: TestTensor::from_floats([5346., 5346.], &device),
};
test.assert_grads(grads);
}
struct ConvTranspose3dTestCase {
batch_size: usize,
channels: [usize; 2],
kernel_size: [usize; 3],
padding: [usize; 3],
padding_out: [usize; 3],
stride: [usize; 3],
dilation: [usize; 3],
groups: usize,
size: [usize; 3],
}
struct Grads {
x: TestTensor<5>,
weight: TestTensor<5>,
bias: TestTensor<1>,
}
impl ConvTranspose3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels[0],
self.size[0],
self.size[1],
self.size[2],
]);
let shape_weight = Shape::new([
self.channels[0],
self.channels[1] / self.groups,
self.kernel_size[0],
self.kernel_size[1],
self.kernel_size[2],
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight.clone())
.into_data(),
&device,
)
.div_scalar(shape_weight.num_elements() as f32)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x.clone())
.into_data(),
&device,
)
.div_scalar(shape_x.num_elements() as f32)
.require_grad();
let output = conv_transpose3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvTransposeOptions::new(
self.stride,
self.padding,
self.padding_out,
self.dilation,
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}

View File

@@ -0,0 +1,103 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[cfg(feature = "std")]
use burn_backend_tests::might_panic;
#[test]
fn backward_basic() {
let device = Default::default();
let a = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let b = TestAutodiffTensor::<2>::from_data(
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
&device,
)
.require_grad();
// Simple cross product; grad is a vector of ones.
let c = a.clone().cross(b.clone(), 1);
let grads = c.backward();
let a_grad = a.grad(&grads).unwrap().to_data();
let b_grad = b.grad(&grads).unwrap().to_data();
// For a: b×grad_out, where grad_out = [1,1,1]
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
// For b: grad_out×a
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
}
#[test]
fn backward_after_sum() {
let device = Default::default();
let a = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let b = TestAutodiffTensor::<2>::from_data(
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
&device,
)
.require_grad();
// Sum reduces to scalar, but the gradient should be the same.
let c = a.clone().cross(b.clone(), 1).sum();
let grads = c.backward();
let a_grad = a.grad(&grads).unwrap().to_data();
let b_grad = b.grad(&grads).unwrap().to_data();
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
}
#[cfg(feature = "std")]
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
#[test]
fn different_dim() {
// Also check when the cross is along a different dimension (e.g. dim 0).
let device = Default::default();
let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];
let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);
let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);
// Cross along dim 0. Some backends (for example CubeCL) may not support
// cross on non-last dimensions and will intentionally panic with a
// message like "Cross product on non-last dimension not yet implemented".
// In that case we treat the panic as a skipped test for that backend.
let out = a.cross(b.clone(), 0);
// Manually compute cross of each column vector using raw arrays
let expected = [
[
a_raw[1][0] * b_raw[2][0] - a_raw[2][0] * b_raw[1][0],
a_raw[1][1] * b_raw[2][1] - a_raw[2][1] * b_raw[1][1],
a_raw[1][2] * b_raw[2][2] - a_raw[2][2] * b_raw[1][2],
],
[
a_raw[2][0] * b_raw[0][0] - a_raw[0][0] * b_raw[2][0],
a_raw[2][1] * b_raw[0][1] - a_raw[0][1] * b_raw[2][1],
a_raw[2][2] * b_raw[0][2] - a_raw[0][2] * b_raw[2][2],
],
[
a_raw[0][0] * b_raw[1][0] - a_raw[1][0] * b_raw[0][0],
a_raw[0][1] * b_raw[1][1] - a_raw[1][1] * b_raw[0][1],
a_raw[0][2] * b_raw[1][2] - a_raw[1][2] * b_raw[0][2],
],
];
out.to_data()
.assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());
}

View File

@@ -0,0 +1,33 @@
use super::*;
use burn_tensor::{Tensor, TensorData, Tolerance, loss};
#[test]
fn test_cross_entropy_loss_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let data_targets = TensorData::from([[0.8, 0.2], [0.9, 0.1]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_targets =
Tensor::<TestAutodiffBackend, 2>::from_data(data_targets, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
let expected = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,117 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cummax() {
// Simple test to verify cummax gradients work
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 2.0, 0.0]
let expected = TensorData::from([1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_2d() {
// Test 2D cummax gradients
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummax(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_duplicate_values() {
// Test with duplicate maximum values - critical edge case
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 3.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// input: [1.0, 3.0, 3.0, 2.0]
// cummax: [1.0, 3.0, 3.0, 3.0]
// PyTorch reference: [1.0, 1.0, 2.0, 0.0]
// Position 2 gets grad from itself + position 3
let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_all_same() {
// Test with all same values
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0]
// Each position matches cummax, so each gets its own gradient
let expected = TensorData::from([1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_increasing() {
// Test with increasing sequence
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cummax(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0, 1.0]
// Each position is a new maximum
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummax_2d_duplicates() {
// Test 2D with duplicate values
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummax(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,117 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cummin() {
// Simple test to verify cummin gradients work
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 2.0, 0.0]
let expected = TensorData::from([1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_2d() {
// Test 2D cummin gradients
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummin(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_duplicate_values() {
// Test with duplicate minimum values - critical edge case
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 2.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// input: [3.0, 2.0, 2.0, 4.0]
// cummin: [3.0, 2.0, 2.0, 2.0]
// PyTorch reference: [1.0, 1.0, 2.0, 0.0]
// Position 2 gets grad from itself + position 3
let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_all_same() {
// Test with all same values
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0]
// Each position matches cummin, so each gets its own gradient
let expected = TensorData::from([1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_decreasing() {
// Test with decreasing sequence
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([5.0, 4.0, 3.0, 2.0]), &device)
.require_grad();
let output = tensor.clone().cummin(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 1.0, 1.0, 1.0]
// Each position is a new minimum
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cummin_2d_duplicates() {
// Test 2D with duplicate values
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cummin(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,132 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cumprod() {
// Simple test to verify cumprod gradients work
let device = Default::default();
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [16.0, 10.0, 6.0]
let expected = TensorData::from([16.0, 10.0, 6.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cumprod_2d() {
// Test 2D cumprod gradients
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let output = tensor.clone().cumprod(1);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]
let expected = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
// TODO: The following tests are currently ignored due to a known limitation
// in the cumprod gradient implementation. The current implementation uses
// division (grad / input), which produces NaN when the input contains zeros.
//
// A proper fix requires implementing a zero-safe algorithm using exclusive
// cumulative products (similar to PyTorch's cumprod_backward or JAX's
// associative_scan approach). This is a non-trivial implementation that
// requires careful handling of cumulative products in both forward and
// reverse directions.
//
// See: https://github.com/tracel-ai/burn/issues/3864
//
// References:
// - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward)
// - JAX PR #2596: Parallel prefix scan implementation
// - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_in_middle() {
// Test cumprod with zero in the middle - edge case for division
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 32.0, 0.0, 0.0]
let expected = TensorData::from([1.0, 32.0, 0.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_start() {
// Test cumprod with zero at the beginning
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([0.0, 2.0, 3.0, 4.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [33.0, 0.0, 0.0, 0.0]
let expected = TensorData::from([33.0, 0.0, 0.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_zero_at_end() {
// Test cumprod with zero at the end
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0, 0.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [16.0, 10.0, 6.0, 24.0]
let expected = TensorData::from([16.0, 10.0, 6.0, 24.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
fn should_diff_cumprod_multiple_zeros() {
// Test cumprod with multiple zeros
let device = Default::default();
let tensor =
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]), &device)
.require_grad();
let output = tensor.clone().cumprod(0);
let grads = output.sum().backward();
let grad = tensor.grad(&grads).unwrap();
// PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0]
let expected = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,89 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_cumsum_dim0() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.cumsum(0);
let tensor_5 = tensor_1.clone().mul(tensor_4);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Expected gradients computed with PyTorch
let expected = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cumsum_dim1() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.cumsum(1);
let tensor_5 = tensor_1.clone().mul(tensor_4);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Expected gradients computed with PyTorch
let expected = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[18.0, 13.0], [71.0, 58.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_cumsum_complex() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().cumsum(1);
let tensor_5 = tensor_4.mul(tensor_3);
let grads = tensor_5.sum().backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Expected gradients computed with PyTorch
let expected = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[507.0, 528.0], [704.0, 733.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,105 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_div() {
let data_1 = TensorData::from([1.0, 7.0]);
let data_2 = TensorData::from([4.0, 7.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().div(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([0.25, 0.14285715]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([-0.0625, -0.14285715]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_div_scalar() {
let data = TensorData::from([1.0, 7.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().div_scalar(4.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([0.25, 0.25]), false);
}
#[test]
fn test_div_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().div(tensor_2.clone());
let tensor_5 = tensor_4.div(tensor_3.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
let expected = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]);
grad_3
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn test_div_complex_2() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.div(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_absolute(2e-3);
let expected = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_erf() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[8.0, 8.0], [8.0, 8.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_exp() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default();
let expected = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,39 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_expand() {
// Python code to generate the test case values
// import torch
// x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
// x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
// y = x1.expand(4, 4)
// z = (x2 * y).sum()
// z.backward()
// print("x1", x1.grad)
// print("x2", x2.grad)
let device = Default::default();
let data_1 = TensorData::from([4.0, 7.0, 2.0, 3.0]);
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
let data_2 = TensorData::from([2.0, 4.5, 7.0, 3.0]);
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().expand([4, 4]);
// Use unsqueeze to make tensor_2 have the same shape as tensor_3
let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([8., 18., 28., 12.]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([16., 28., 8., 12.]), false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_flip() {
let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2 = TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<3>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().flip([1, 2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2
grad_2.into_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]),
tolerance,
); // 1x2x3
}

View File

@@ -0,0 +1,21 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_floor() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = tensor_1.clone().floor();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}

View File

@@ -0,0 +1,99 @@
use super::*;
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
#[test]
fn test_gather_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]),
&device,
);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().gather(1, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[94., 150., 187.], [242., 305., 304.]]),
false,
);
}
#[test]
fn test_scatter_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let values = TestAutodiffTensor::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
TensorData::from([[2, 1, 0], [2, 0, 1]]),
&device,
);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1
.clone()
.scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[127., 181., 235.], [226., 316., 406.]]),
false,
);
grad_2
.to_data()
.assert_eq(&TensorData::from([[19., 19., 19.], [64., 64., 64.]]), false);
}
#[test]
fn test_scatter_add_grad_partial_indices() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::from_data(TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]), &device)
.require_grad();
let tensor_2 =
TestAutodiffTensor::from_data(TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), &device)
.require_grad();
let values =
TestAutodiffTensor::from_data(TensorData::from([[4.0, 5.0, 6.0]]), &device).require_grad();
let indices =
Tensor::<TestAutodiffBackend, 2, Int>::from_data(TensorData::from([[2, 1, 0]]), &device);
let tensor_3 = tensor_1.clone().mul(tensor_2);
let tensor_4 = tensor_3
.clone()
.scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[1., 2., 3., 4., 5., 6.]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[1., 1., 1.]]), false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::{TensorData, Tolerance, activation};
#[test]
fn should_diff_gelu() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad();
let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));
let x = tensor_1.clone().matmul(x);
let grads = x.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::permissive();
let expected = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,24 @@
use super::*;
use burn_tensor::{Distribution, activation};
#[test]
fn should_update_tensor_when_grad_replace() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad();
let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default, &device);
let x = tensor_1.clone().matmul(activation::gelu(tensor_2));
let mut grads = x.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_1_updated =
TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad();
tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner());
let grad_1_new = tensor_1.grad(&grads).unwrap();
assert_ne!(grad_1_new.to_data(), grad_1.into_data());
assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data());
}

View File

@@ -0,0 +1,30 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_diff_log() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,28 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_log1p() {
let tensor_1 = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();
let tensor_2 = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,19 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{TensorData, activation};
#[test]
fn should_diff_log_sigmoid() {
let data = TensorData::from([[0.8762, -0.1423], [-300., 200.]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = activation::log_sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,65 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Bool, Tensor, TensorData};
#[test]
fn should_diff_mask_fill() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let mask = TensorData::from([[true, false], [false, true]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device);
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.mask_fill(mask, 2.0);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[7.0, 3.0], [4.0, 2.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[2.0, 1.0], [3.0, 7.0]]), false);
}
#[test]
fn should_diff_mask_where() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 =
TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad();
let mask =
Tensor::<TestAutodiffBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.clone().matmul(tensor_3.clone());
let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone());
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[121.8, 55.0], [110.8, 50.0]]);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[27.4, 33.4], [95.0, 115.0]]);
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[15., 18.], [23., 29.]]);
grad_3
.into_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,83 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_matmul() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);
tensor_3
.to_data()
.assert_eq(&TensorData::from([[18.0, 28.0], [14.0, 23.0]]), false);
}
#[test]
fn test_matmul_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[44.0, 20.0], [44.0, 20.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[56.0, 56.0], [16.0, 16.0]]), false);
}
#[test]
fn test_matmul_complex_2() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3.clone());
let tensor_6 = tensor_1.clone().matmul(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[800.0, 792.0], [360.0, 592.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[264., 264.0], [344.0, 344.0]]), false);
}

View File

@@ -0,0 +1,82 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_max_dim() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[50.0, 34.0], [40.0, -10.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[8.0, 10.0], [56.0, 15.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_min_dim() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[10.0, 8.0], [15.0, 56.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_min_dim_3d_dim1() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device).require_grad();
let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
let tensor_4 = tensor_3.min_dim(1);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[[0., -7.], [2., 0.]]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[[0., 7.], [-2., -0.]]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,134 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::module::max_pool1d;
#[test]
fn test_max_pool1d_simple() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]],
&device,
)
.require_grad();
let x_grad_expected =
TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool1d_with_dilation() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 2;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
[[[
0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
0., 0., 1.,
]]],
&device,
);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool1d_complex() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
[[[
0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
1., 1., 1.,
]]],
&device,
);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool1d_complex_with_padding() {
let kernel_size = 4;
let padding = 2;
let stride = 1;
let dilation = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
[[[
1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
1., 1., 3.,
]]],
&device,
);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}

View File

@@ -0,0 +1,271 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::module::max_pool2d;
#[test]
fn test_max_pool2d_simple_1() {
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let padding_1 = 0;
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 2.0],
[0.0, 2.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_simple_2() {
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[1., 3., 0., 2.],
[3., 0., 0., 4.],
[1., 4., 0., 1.],
[2., 0., 3., 1.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_with_dilation() {
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 2;
let dilation_2 = 2;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0., 0., 0., 0.],
[1., 1., 1., 2.],
[0., 4., 4., 0.],
[0., 1., 2., 0.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_complex() {
let kernel_size_1 = 4;
let kernel_size_2 = 2;
let padding_1 = 2;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
let x = TestAutodiffTensor::from_floats(
[[[
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
[0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
[0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
]]],
&device,
)
.require_grad();
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0., 0., 0., 3., 0.],
[4., 0., 2., 1., 0.],
[0., 0., 0., 0., 0.],
[2., 4., 0., 0., 0.],
[0., 0., 0., 0., 2.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
false,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}
#[test]
fn test_max_pool2d_ceil_mode() {
// Test ceil_mode=true with gradient computation
// Using 1x1x6x6 input with kernel 3x3, stride 2x2, padding 0
// Floor mode: output 2x2
// Ceil mode: output 3x3
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let padding_1 = 0;
let padding_2 = 0;
let stride_1 = 2;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let device = Default::default();
// Input (values 1-36):
let x = TestAutodiffTensor::from_floats(
[[[
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
[19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
[25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
[31.0, 32.0, 33.0, 34.0, 35.0, 36.0],
]]],
&device,
)
.require_grad();
// Expected gradients for ceil_mode output 3x3:
// Output positions and their max value positions:
// (0,0): max at (2,2)=15 -> grad[2,2] += 1
// (0,1): max at (2,4)=17 -> grad[2,4] += 1
// (0,2): max at (2,5)=18 -> grad[2,5] += 1
// (1,0): max at (4,2)=27 -> grad[4,2] += 1
// (1,1): max at (4,4)=29 -> grad[4,4] += 1
// (1,2): max at (4,5)=30 -> grad[4,5] += 1
// (2,0): max at (5,2)=33 -> grad[5,2] += 1
// (2,1): max at (5,4)=35 -> grad[5,4] += 1
// (2,2): max at (5,5)=36 -> grad[5,5] += 1
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
[[[
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 1., 1.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 1., 1.],
[0., 0., 1., 0., 1., 1.],
]]],
&device,
);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
true,
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
}

View File

@@ -0,0 +1,290 @@
use super::*;
use burn_tensor::{Tensor, TensorData};
#[test]
fn test_mm_independent_trees() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7.clone() * tensor_8.clone();
let tensor_12 = tensor_9.clone() * tensor_10.clone();
let tensor_13 = tensor_11 * tensor_12;
let _grads = tensor_6.backward();
let grads = tensor_13.backward();
assert!(tensor_7.grad(&grads).is_some());
assert!(tensor_8.grad(&grads).is_some());
assert!(tensor_9.grad(&grads).is_some());
assert!(tensor_10.grad(&grads).is_some());
}
#[test]
#[should_panic]
fn test_mm_crossover_trees_root_unavailable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4.clone() * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = tensor_7.clone() * tensor_8.clone();
let tensor_10 = tensor_4 * tensor_9;
let _grads = tensor_6.backward();
let _grads = tensor_10.backward();
}
#[test]
fn test_mm_crossover_trees_with_referred_subtree() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4.clone() * tensor_5;
// Second tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = tensor_7.clone() * tensor_8.clone();
let _tensor_10 = tensor_4 * tensor_9.clone();
let _grads = tensor_6.backward();
let _grads = tensor_9.backward();
}
#[test]
fn test_mm_three_crossover_trees_last_still_usable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5.clone();
// Third tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7 * tensor_8;
let tensor_12 = tensor_9 * tensor_10;
let tensor_13 = tensor_11 * tensor_12.clone();
// Second tree (in between)
let _tensor_14 = tensor_5 * tensor_12;
let _grads = tensor_6.backward();
let _grads = tensor_13.backward();
}
#[test]
#[should_panic]
fn test_mm_three_crossover_trees_middle_one_unavailable() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3;
let tensor_6 = tensor_4 * tensor_5.clone();
// Third tree
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_11 = tensor_7 * tensor_8;
let tensor_12 = tensor_9 * tensor_10;
let _tensor_13 = tensor_11 * tensor_12.clone();
// Second tree (in between)
let tensor_14 = tensor_5 * tensor_12;
let _grads = tensor_6.backward();
let _grads = tensor_14.backward();
}
#[test]
fn test_mm_self_referencing_tree() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// First tree
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
let tensor_3 = tensor_0 * tensor_1;
let tensor_5 = tensor_2 * tensor_3.clone();
let tensor_6 = tensor_3 * tensor_5;
let _grads = tensor_6.backward();
}
#[test]
fn test_mm_with_non_impacting_detach() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4.detach() * tensor_3.clone();
let grads = tensor_5.backward();
assert!(tensor_3.grad(&grads).is_some());
}
#[test]
fn test_mm_with_missing_require_grad_after_cleanup() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_none());
assert!(tensor_3.grad(&grads).is_none());
}
#[test]
fn test_mm_with_detach_after_cleanup() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_3 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_4 = tensor_1.clone() * tensor_2.clone();
let tensor_5 = tensor_4 * tensor_3.clone().detach();
// Trivial backward, just to trigger cleanup
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
.require_grad()
.backward();
let grads = tensor_5.backward();
assert!(tensor_1.grad(&grads).is_some());
assert!(tensor_2.grad(&grads).is_some());
assert!(tensor_3.grad(&grads).is_none());
}
#[test]
#[should_panic]
fn test_mm_deletables_propagate_well() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_0 * tensor_1;
let tensor_3 = tensor_2.clone().exp();
let _tensor_4 = tensor_3.clone().log();
let _grads = tensor_2.backward();
// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but
// the intermediate tensor_3 as well
let _grads = tensor_3.backward();
}
#[test]
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative
// By repeating it many times it becomes almost impossible that it passes if it shouldn't
for _ in 0..12 {
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().exp();
let tensor_3 = tensor_0.exp();
let _tensor_4 = tensor_3.clone() * tensor_2.clone();
let tensor_5 = tensor_2.exp();
let tensor_6 = tensor_5.exp();
let tensor_7 = tensor_6.exp();
let tensor_8 = tensor_7.exp();
// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8
// which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search
tensor_3.backward();
let grads = tensor_8.backward();
assert!(tensor_1.grad(&grads).is_some());
}
}

View File

@@ -0,0 +1,74 @@
#[allow(unused_imports)] // required for re-included modules
pub use super::*;
mod abs;
mod adaptive_avgpool1d;
mod adaptive_avgpool2d;
mod add;
mod aggregation;
mod avgpool1d;
mod avgpool2d;
mod backward;
mod bridge;
mod broadcast;
mod cast;
mod cat;
mod ceil;
mod checkpoint;
mod complex;
mod conv1d;
mod conv2d;
mod conv3d;
mod conv_transpose1d;
mod conv_transpose2d;
mod conv_transpose3d;
mod cross;
mod cross_entropy;
mod cummax;
mod cummin;
mod cumprod;
mod cumsum;
mod deform_conv2d;
mod div;
mod erf;
mod exp;
mod expand;
mod flip;
mod floor;
mod gather_scatter;
mod gelu;
mod gradients;
mod log;
mod log1p;
mod log_sigmoid;
mod mask;
mod matmul;
mod maxmin;
mod maxpool1d;
mod maxpool2d;
mod memory_management;
mod mul;
mod multithread;
mod nearest_interpolate;
mod neg;
mod nonzero;
mod permute;
mod pow;
mod recip;
mod relu;
mod remainder;
mod repeat_dim;
mod reshape;
mod round;
mod select;
mod sigmoid;
mod sign;
mod slice;
mod slice_assign;
mod softmax;
mod sort;
mod sqrt;
mod sub;
mod transpose;
mod trig;
mod unfold;

View File

@@ -0,0 +1,68 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_mul() {
let data_1 = TensorData::from([1.0, 7.0]);
let data_2 = TensorData::from([4.0, 7.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let _grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_eq(&data_2, false);
tensor_3
.into_data()
.assert_eq(&TensorData::from([4.0, 49.0]), false);
}
#[test]
fn should_diff_mul_scalar() {
let data = TensorData::from([2.0, 5.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().mul_scalar(4.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
tensor_out
.into_data()
.assert_eq(&TensorData::from([8.0, 20.0]), false);
grad.to_data()
.assert_eq(&TensorData::from([4.0, 4.0]), false);
}
#[test]
fn test_mul_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().mul(tensor_2.clone());
let tensor_5 = tensor_4.mul(tensor_3);
let tensor_6 = tensor_1.clone().mul(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[16.0, 196.0], [104.0, -36.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[2.0, 98.0], [338.0, 18.0]]), false);
}

View File

@@ -0,0 +1,88 @@
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[test]
fn should_behave_the_same_with_multithread() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let with_move = || {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
// Task 1
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5.clone();
let first_call = move || {
let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned);
tensor_6_1.matmul(tensor_1_cloned)
};
// Task 2
let tensor_1_cloned = tensor_1.clone();
let tensor_2_cloned = tensor_2.clone();
let tensor_5_cloned = tensor_5;
let second_call = move || {
let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned);
tensor_6_2.matmul(tensor_2_cloned)
};
let tensor_7_1_handle = std::thread::spawn(first_call);
let tensor_7_2_handle = std::thread::spawn(second_call);
let tensor_7_1 = tensor_7_1_handle.join().unwrap();
let tensor_7_2 = tensor_7_2_handle.join().unwrap();
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
(grad_1, grad_2)
};
let without_move = || {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_4.matmul(tensor_3);
// Task 1
let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone());
let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone());
// Task 2
let tensor_6_2 = tensor_5.matmul(tensor_1.clone());
let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone());
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
(grad_1, grad_2)
};
let (grad_1, grad_2) = without_move();
let (grad_1_moved, grad_2_moved) = with_move();
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&grad_1_moved.into_data(), Tolerance::default());
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&grad_2_moved.into_data(), Tolerance::default());
}

View File

@@ -0,0 +1,97 @@
use super::*;
use burn_tensor::Shape;
use burn_tensor::Tolerance;
use burn_tensor::module::interpolate;
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
#[test]
fn test_upsample_interpolation() {
let test = InterpolateTestCase {
batch_size: 2,
channels: 1,
height: 7,
width: 5,
height_out: 8,
width_out: 7,
};
test.assert_output(TestTensor::from([
[[
[4., 2., 4., 2., 2.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
]],
[[
[4., 2., 4., 2., 2.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
]],
]));
}
#[test]
fn test_downsample_interpolation() {
let test = InterpolateTestCase {
batch_size: 1,
channels: 1,
height: 8,
width: 8,
height_out: 4,
width_out: 6,
};
test.assert_output(TestTensor::from([[[
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
]]]));
}
struct InterpolateTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
height_out: usize,
width_out: usize,
}
impl InterpolateTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device())
.reshape::<4, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = interpolate(
x.clone(),
[self.height_out, self.width_out],
InterpolateOptions::new(InterpolateMode::Nearest),
);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), Tolerance::permissive());
}
}

View File

@@ -0,0 +1,26 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_neg() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg());
let tensor_4 = tensor_3.neg();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);
}

View File

@@ -0,0 +1,41 @@
use super::*;
use burn_tensor::{Bool, Tensor, TensorData};
#[test]
fn should_diff_nonzero() {
let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let data_2 = TensorData::from([-1.0, 1.0]);
let mask = TensorData::from([[false, true], [true, false]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
// Multi-dimensional tensor indexing isn't really supported yet so the easiest way to do
// this is to flatten the mask and tensor to get proper indexing. Anyway the returned tensor would
// have dimensions different from the input, so this is somewhat equivalent.
let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device).flatten::<1>(0, 1);
let indices = mask.nonzero();
let tensor_3 = tensor_1
.clone()
.flatten::<1>(0, 1)
.select(0, indices[0].clone());
// Vector dot product not supported (only 2D matmuls) so unsqueeze for test purposes
let tensor_4 = tensor_2
.clone()
.unsqueeze_dim::<2>(0)
.matmul(tensor_3.unsqueeze_dim(1));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[0.0, -1.0], [1.0, 0.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([2.0, 3.0]), false);
}

View File

@@ -0,0 +1,29 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_permute() {
let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2 = TensorData::from([[[1.0, 7.0], [3.2, 2.0], [3.0, 3.0]]]); // 1x3x2
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().permute([0, 2, 1]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2
grad_2.into_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]),
tolerance,
); // 1x3x2
}

View File

@@ -0,0 +1,93 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_powf_scalar() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf_scalar(0.4));
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(2e-3);
let expected = TensorData::from([[68.0, 79.0328], [68.0, 79.0328]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[23.5081, 25.2779], [26.0502, 28.6383]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn should_diff_powf() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([32.0, 14.0]);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([11.09035, 95.34960]);
grad_2
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([16.0, 49.0]);
tensor_3
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_powf_with_untracked_lhs() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device);
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
let grads = tensor_3.backward();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([11.09035, 95.34960]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_powf_with_untracked_rhs() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device);
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([32.0, 14.0]);
grad_1
.into_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,22 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_recip() {
let data = TensorData::from([2.0, 5.0, 0.4]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().recip();
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
tensor_out
.into_data()
.assert_eq(&TensorData::from([0.5, 0.2, 2.5]), false);
grad.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([-0.25, -0.04, -6.25]),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,27 @@
use super::*;
use burn_tensor::{TensorData, activation};
#[test]
fn should_diff_relu() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::relu(tensor_3);
let tensor_5 = tensor_4.matmul(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[-47.0, 9.0], [-35.0, 15.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[15.0, 13.0], [-2.0, 39.0]]), false);
}

View File

@@ -0,0 +1,41 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_remainder() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(
TensorData::from([
0.9742, 0.3676, 0.0905, 0.8066, 0.7072, 0.7883, 0.6987, 0.1560, 0.7179, 0.7874, 0.9032,
0.1845,
]),
&device,
)
.require_grad();
let tensor_2 = TestAutodiffTensor::<1>::from_data(
TensorData::from([
0.3357, 0.0285, 0.4115, 0.5511, 0.8637, 0.3593, 0.3885, 0.2569, 0.0936, 0.7172, 0.4792,
0.4898,
]),
&device,
)
.require_grad();
let tensor_3 = tensor_1.clone().remainder(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([
-2.0, -12.0, -0.0, -1.0, -0.0, -2.0, -1.0, -0.0, -7.0, -1.0, -1.0, -0.0,
]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,44 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_repeat() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0], [2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_2
.to_data()
.assert_eq(&TensorData::from([[-3.0], [12.0]]), false);
}
#[test]
fn should_diff_repeat_multi_dim() {
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 2.0], [2.0, 4.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_2
.to_data()
.assert_eq(&TensorData::from([[-3.0, -3.0], [12.0, 12.0]]), false);
}

View File

@@ -0,0 +1,26 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_reshape() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([4.0, 7.0, 2.0, 3.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().reshape([2, 2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([3.0, 3.0, 10.0, 10.0]), false);
}

View File

@@ -0,0 +1,20 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_round() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().round();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}

View File

@@ -0,0 +1,89 @@
use super::*;
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
#[test]
fn test_select_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let indices =
Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().select(0, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.into_data().assert_eq(
&TensorData::from([[109., 148., 187.], [37., 58., 79.]]),
false,
);
}
#[test]
fn test_select_add_grad() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
&device,
)
.require_grad();
let values = TestAutodiffTensor::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let indices =
Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 =
tensor_1
.clone()
.select_assign(0, indices, values.clone(), IndexingUpdateOp::Add);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
grad_1.into_data().assert_eq(
&TensorData::from([[127., 199., 271.], [172., 244., 316.]]),
false,
);
grad_2
.into_data()
.assert_eq(&TensorData::from([[64., 64., 64.], [19., 19., 19.]]), false);
}
#[test]
fn test_select_add_grad_different_shapes() {
let device = Default::default();
let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);
let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();
let y = Tensor::ones([2, 1], &device).require_grad();
let w = y
.clone()
.select_assign(0, indices, x.clone(), IndexingUpdateOp::Add);
let w = w.matmul(y.clone().transpose());
let grads = w.backward();
let x_grad = x.grad(&grads).unwrap();
let y_grad = y.grad(&grads).unwrap();
x_grad
.into_data()
.assert_eq(&TensorData::from([[2.0]]), false);
y_grad
.into_data()
.assert_eq(&TensorData::from([[5.0], [5.0]]), false);
}

View File

@@ -0,0 +1,35 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{TensorData, activation};
#[test]
fn should_diff_sigmoid() {
let data = TensorData::from([0.8762]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
let tensor_2 = activation::sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([0.207549]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn small_neg_val_should_not_cause_grad_overflow() {
let data = TensorData::from([-90.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
let tensor_2 = activation::sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
let expected = TensorData::from([0.0]);
grad.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,42 @@
use super::*;
use burn_tensor::TensorData;
/// Example using the sign function with PyTorch:
// >>> import torch
// >>> # Create a tensor with requires_grad=True
// >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
// >>> # Forward pass: Apply the sign function
// >>> y = torch.sign(x)
// >>> print("Forward pass:")
// Forward pass:
// >>> print("x:", x)
// x: tensor([-2., -1., 0., 1., 2.], requires_grad=True)
// >>> print("y:", y)
// y: tensor([-1., -1., 0., 1., 1.], grad_fn=<SignBackward0>)
// >>> # Compute the loss (just an example)
// >>> loss = y.sum()
// >>> # Backward pass: Compute the gradients
// >>> loss.backward()
// >>> print("\nBackward pass:")
// Backward pass:
// >>> print("x.grad:", x.grad)
// x.grad: tensor([0., 0., 0., 0., 0.])
#[test]
fn should_diff_sign() {
let data = TensorData::from([-2.0, -1.0, 0.0, 1.0, 2.0]);
let device = Default::default();
let x = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
let y = x.clone().sign();
let loss = y.clone().sum();
let grads = loss.backward();
let grad = x.grad(&grads).unwrap();
y.to_data()
.assert_eq(&TensorData::from([-1., -1., 0., 1., 1.]), false);
grad.to_data()
.assert_eq(&TensorData::from([0., 0., 0., 0., 0.]), false);
}

View File

@@ -0,0 +1,67 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_matmul_with_slice() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().slice([0..2, 0..2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
grad_2.to_data().assert_eq(
&TensorData::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]),
false,
);
}
#[test]
fn should_diff_matmul_with_slice_stepped() {
use burn_tensor::s;
let data_1 = TensorData::from([[1.0, 7.0], [100.0, 100.0], [2.0, 3.0], [100.0, 100.0]]);
let data_2 = TensorData::from([[4.0, 100.0, 7.0, 100.0], [2.0, 100.0, 3.0, 15.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().slice(s![0..;2, 0..2]); // [[1., 7.], [2., 3.]]
let tensor_4 = tensor_2.clone().slice(s![0..2, 0..;2]); // [[4., 7.], [2., 3.]]
let tensor_5 = tensor_3.clone().matmul(tensor_4);
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[11., 5.], [0., 0.], [11., 5.], [0., 0.]]),
false,
);
grad_2.to_data().assert_eq(
&TensorData::from([[3., 0., 3., 0.], [10., 0., 10., 0.]]),
false,
);
}
#[test]
fn should_panic_on_slice_with_step() {
use burn_tensor::s;
let data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
// This should panic because step is 2
let _sliced = tensor.slice(s![.., 0..4; 2]);
}

View File

@@ -0,0 +1,163 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_matmul_with_slice_assign() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_assigned = TensorData::from([[9.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_assigned = TestAutodiffTensor::from_data(data_assigned, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned);
let tensor_5 = tensor_4.matmul(tensor_1.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[58.0, 38.0], [118.0, 82.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[16.0, 15.0], [24.0, 50.0]]), false);
}
#[test]
fn should_diff_matmul_with_slice_assign_complex() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[9.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_5 = tensor_2.clone().slice([0..1, 0..1]);
let tensor_6 = tensor_5.mul(tensor_3.clone());
let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6);
let tensor_8 = tensor_7.matmul(tensor_1.clone());
let grads = tensor_8.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
grad_3
.to_data()
.assert_eq(&TensorData::from([[32.0]]), false);
grad_1
.to_data()
.assert_eq(&TensorData::from([[85.0, 65.0], [118.0, 82.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[88.0, 15.0], [24.0, 50.0]]), false);
}
#[test]
fn slice_assign_diff_should_give_same_results_as_cat() {
let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[5.0, 6.0], [7.0, 8.0]]);
let data_3 = TensorData::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());
let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
let slice_assign_output = slice_assign_output / tensor_3.clone();
let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
let cat_output = cat_output / tensor_3;
slice_assign_output
.to_data()
.assert_approx_eq::<FloatElem>(&cat_output.to_data(), Tolerance::default());
let slice_assign_grads = slice_assign_output.backward();
let cat_grads = cat_output.backward();
let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
slice_assign_grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&cat_grad_1.to_data(), Tolerance::default());
slice_assign_grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&cat_grad_2.to_data(), Tolerance::default());
}
#[test]
fn should_diff_slice_assign_with_step() {
use burn_tensor::s;
let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let value_data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();
// Assign with step=2
let result = tensor.clone().slice_assign(s![.., 0..4; 2], value.clone());
let result = result * 2.0; // Scale to create gradients
let grads = result.backward();
let grad_tensor = tensor.grad(&grads).unwrap();
let grad_value = value.grad(&grads).unwrap();
// The gradient for tensor should be 2.0 everywhere except the assigned positions
grad_tensor.to_data().assert_eq(
&TensorData::from([[0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0]]),
false,
);
// The gradient for value should be 2.0 at all positions
grad_value
.to_data()
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
}
#[test]
fn should_diff_slice_assign_with_negative_step() {
use burn_tensor::s;
let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let value_data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
let device = Default::default();
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();
// Assign with step=-1 (reverse order, all elements)
let result = tensor.clone().slice_assign(s![.., ..;-1], value.clone());
let result = result * 2.0; // Scale to create gradients
let grads = result.backward();
let grad_tensor = tensor.grad(&grads).unwrap();
let grad_value = value.grad(&grads).unwrap();
// The gradient for tensor should be 0 since all values were replaced
grad_tensor.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
// The gradient for value should be 2.0 at all positions
grad_value.to_data().assert_eq(
&TensorData::from([[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]),
false,
);
}

View File

@@ -0,0 +1,90 @@
use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Tensor, TensorData, activation};
#[test]
fn test_softmax_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.5));
let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.05));
}
#[test]
fn test_log_softmax_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]);
// f16 gradients from log-softmax + matmul amplify error, so we increase the tolerance
// to account for limited precision and large representable step sizes in this range.
let tolerance = Tolerance::permissive().set_half_precision_relative(6e-2);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[30.5984, -47.2267], [55.9631, -56.5914]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn test_quiet_softmax_grad() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);
// Precision is quite bad yet on softmax grad especially with half precision.
let tolerance = Tolerance::rel_abs(0.5, 0.2);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,82 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_sort() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sort_with_indices() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let (values, _indices) = tensor_3.sort_with_indices(1);
let tensor_4 = tensor_1.clone().mul(values);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}
#[test]
fn should_diff_sort_3d_dim1() {
let device = Default::default();
let tensor_1 =
TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
let tensor_2 =
TestAutodiffTensor::from_floats([[[4.0, -7.0], [2.0, 3.0]]], &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let expected = TensorData::from([[[-1., -8.], [-27., 37.]]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
let expected = TensorData::from([[[-4., -17.], [-17., -42.]]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
}

View File

@@ -0,0 +1,31 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_sqrt() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
let expected = TensorData::from([[82.112640, 99.083275], [82.112640, 99.083275]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[30.309311, 33.120457], [34.581974, 38.769463]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}

View File

@@ -0,0 +1,73 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_sub() {
let data_1 = TensorData::from([2.0, 5.0]);
let data_2 = TensorData::from([4.0, 1.0]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().sub(tensor_2.clone());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([-1.0, -1.0]), false);
tensor_3
.into_data()
.assert_eq(&TensorData::from([-2.0, 4.0]), false);
}
#[test]
fn should_diff_sub_scalar() {
let data = TensorData::from([2.0, 10.0]);
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
let tensor_out = tensor.clone().sub_scalar(5.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
grad.to_data()
.assert_eq(&TensorData::from([1.0, 1.0]), false);
tensor_out
.into_data()
.assert_eq(&TensorData::from([-3.0, 5.0]), false);
}
#[test]
fn test_sub_complex_1() {
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1.clone().sub(tensor_2.clone());
let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0);
let tensor_6 = tensor_1.clone().sub(tensor_5);
let grads = tensor_6.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_eq(&TensorData::from([[0.0, 0.0], [0.0, 0.0]]), false);
grad_2
.to_data()
.assert_eq(&TensorData::from([[1.0, 1.0], [1.0, 1.0]]), false);
}

View File

@@ -0,0 +1,60 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_transpose() {
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose());
let tensor_4 = tensor_3.transpose();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[6.0, 10.0], [6.0, 10.0]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[3.0, 10.0], [3.0, 10.0]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_swap_dims() {
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<3>::from_floats(
[[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]],
&device,
)
.require_grad();
let tensor_2 = TestAutodiffTensor::from_floats(
[[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]],
&device,
)
.require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2));
let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,371 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::Tolerance;
#[test]
fn should_diff_cos() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Metal has less precise trigonometric functions
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[26.8063, -27.7870], [26.8063, -27.7870]]),
tolerance,
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),
tolerance,
);
}
#[test]
fn should_diff_sin() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
// Metal has less precise trigonometric functions
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
let expected = TensorData::from([[8.8500, -4.9790], [8.8500, -4.9790]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[38.668987, 44.194775], [-59.97261, -80.46094]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn should_diff_tanh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let tolerance = Tolerance::default().set_half_precision_relative(8e-3);
let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
grad_1
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
let expected = TensorData::from([[8.00092, 8.000153], [8.000003, 7.999995]]);
grad_2
.to_data()
.assert_approx_eq::<FloatElem>(&expected, tolerance);
}
#[test]
fn should_diff_cosh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cosh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[7.092221, 16.696301], [7.092221, 16.696301]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[17.489855, 27.484539], [39.409813, 86.910278]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_sinh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sinh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[4.894847, 15.887931], [4.894847, 15.887931]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[17.284000, 28.412029], [39.302979, 87.498329]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_tan() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [0.3, 0.8]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tan());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[2.532602, 1.596607], [2.532602, 1.596607]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[9.028598, 14.489801], [18.038082, 21.151270]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_asin() {
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asin());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.435841, 0.969651], [0.435841, 0.969651]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.475300, 0.668141], [0.701834, 1.100658]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_acos() {
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acos());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[2.077433, 1.543624], [2.077433, 1.543624]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.781337, 0.588496], [0.554804, 0.155979]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_atan() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atan());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[3.444365, 5.349211], [3.444365, 5.349211]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[9.904911, 11.554912], [10.199631, 11.391938]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_asinh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asinh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[3.806625, 6.844869], [3.806625, 6.844869]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[11.442373, 14.842072], [14.022551, 17.688538]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_acosh() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[1.5, 2.0], [2.5, 3.0]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acosh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[10.611752, 15.178907], [10.611752, 15.178907]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[20.112753, 20.247547], [20.402235, 22.487328]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_atanh() {
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atanh());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.441838, 1.037115], [0.441838, 1.037115]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[0.491723, 0.698110], [0.772763, 1.298805]]),
Tolerance::default(),
);
}
#[test]
fn should_diff_atan2() {
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
let data_3 = TensorData::from([[1.0, 0.5], [2.0, 1.5]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
let tensor_4 = tensor_1
.clone()
.matmul(tensor_2.clone().atan2(tensor_3.clone()));
let tensor_5 = tensor_4.matmul(tensor_2.clone());
let grads = tensor_5.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
let grad_3 = tensor_3.grad(&grads).unwrap();
grad_1.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[4.570492, 4.210785], [4.570492, 4.210785]]),
Tolerance::default(),
);
grad_2.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[8.208448, 8.808449], [10.357923, 12.157923]]),
Tolerance::default(),
);
grad_3.to_data().assert_approx_eq::<FloatElem>(
&TensorData::from([[-1.8, -8.4], [-1.8, -5.6]]),
Tolerance::default(),
);
}

View File

@@ -0,0 +1,18 @@
use super::*;
use burn_tensor::TensorData;
#[test]
fn unfold_backward_accumulates_overlaps() {
let device = Default::default();
let x = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device).require_grad();
let y = x.clone().unfold::<3, _>(1, 2, 1);
let loss = y.sum();
let grads = loss.backward();
let grad_x = x.grad(&grads).unwrap();
grad_x
.to_data()
.assert_eq(&TensorData::from([[1., 2., 2., 1.]]), false);
}