feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance, cast::ToElement};
|
||||
|
||||
#[test]
|
||||
fn should_diff_abs() {
|
||||
let data_1 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_abs_no_nans() {
|
||||
let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);
|
||||
let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let contains_nan = grad_2.contains_nan();
|
||||
assert!(!contains_nan.into_scalar().to_bool());
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::adaptive_avg_pool1d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_simple() {
|
||||
let test = AdaptiveAvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
length: 5,
|
||||
output_size: 3,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
|
||||
[0.5000, 0.83333, 0.33333, 0.83333, 0.5000],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
length: usize,
|
||||
output_size: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool1dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = adaptive_avg_pool1d(x.clone(), self.output_size);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&x_grad_actual.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::adaptive_avg_pool2d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_simple() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 5,
|
||||
width: 3,
|
||||
output_size_1: 3,
|
||||
output_size_2: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.16667, 0.33333, 0.16667],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
],
|
||||
[
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.16667, 0.33333, 0.16667],
|
||||
[0.41667, 0.83333, 0.41667],
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_output_1() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 4,
|
||||
width: 8,
|
||||
output_size_1: 1,
|
||||
output_size_2: 1,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
[
|
||||
0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,
|
||||
],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
output_size_1: usize,
|
||||
output_size_2: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&x_grad_actual.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_add() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone() + tensor_2.clone();
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
tensor_3
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([6.0, 6.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_add_scalar() {
|
||||
let data = TensorData::from([2.0, 10.0]);
|
||||
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
|
||||
let tensor_out = tensor.clone().add_scalar(5.0);
|
||||
let grads = tensor_out.backward();
|
||||
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
tensor_out
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([7.0, 15.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().add(tensor_2.clone());
|
||||
let tensor_5 = tensor_4
|
||||
.add(tensor_3)
|
||||
.add_scalar(5.0)
|
||||
.add(tensor_1.clone())
|
||||
.add(tensor_2.clone());
|
||||
let tensor_6 = tensor_1.clone().add(tensor_5);
|
||||
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_mean() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[3.5, 9.5], [3.5, 9.5]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[-0.75, -0.75], [3.0, 3.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sum_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[14.0, 38.0], [14.0, 38.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[-3.0, -3.0], [12.0, 12.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sum_2() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.clone().sum_dim(1);
|
||||
let tensor_5 = tensor_4.mul(tensor_3);
|
||||
|
||||
let grads = tensor_5.sum().backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[690.0, 690.0], [958.0, 958.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_mean_dim() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[4.0, 36.0], [3.0, -17.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[9.0, 9.0], [35.5, 35.5]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sum_dim() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[8.0, 72.0], [6.0, -34.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[18.0, 18.0], [71.0, 71.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::avg_pool1d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_simple() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
length: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_complex() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
length: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
[0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_complex_dont_count_pad() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
length: 6,
|
||||
count_include_pad: false,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[
|
||||
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
[0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],
|
||||
]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AvgPool1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
length: usize,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool1dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = avg_pool1d(
|
||||
x.clone(),
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.count_include_pad,
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
use super::*;
|
||||
use burn_tensor::module::avg_pool2d;
|
||||
use burn_tensor::{Shape, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_simple() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
|
||||
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
|
||||
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
|
||||
[0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],
|
||||
[0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],
|
||||
[0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_complex() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
height: 4,
|
||||
width: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
|
||||
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
|
||||
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
|
||||
[0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_complex_dont_include_pad() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
height: 4,
|
||||
width: 6,
|
||||
count_include_pad: false,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[
|
||||
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
|
||||
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
|
||||
[0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],
|
||||
[0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],
|
||||
]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct AvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool2dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = avg_pool2d(
|
||||
x.clone(),
|
||||
[self.kernel_size_1, self.kernel_size_2],
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
self.count_include_pad,
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&x_grad_actual.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Int, Tensor, TensorData, module::embedding};
|
||||
|
||||
#[test]
|
||||
fn test_embedding_backward() {
|
||||
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indices = TensorData::from([[0, 1], [1, 1]]);
|
||||
let x = TensorData::from([
|
||||
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
|
||||
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weights = Tensor::<TestAutodiffBackend, 2>::from_data(weights, &device).require_grad();
|
||||
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(indices, &device);
|
||||
let x = Tensor::<TestAutodiffBackend, 3>::from_data(x, &device).require_grad();
|
||||
|
||||
let output = embedding(weights.clone(), indices);
|
||||
let output = output.matmul(x);
|
||||
let grads = output.backward();
|
||||
|
||||
let grad = weights.grad(&grads).unwrap();
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::*;
|
||||
use burn_tensor::{DType, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_full_precision() {
|
||||
let device = Default::default();
|
||||
let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
|
||||
.require_grad();
|
||||
let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
|
||||
.require_grad();
|
||||
let dtype = x1.dtype();
|
||||
|
||||
let x3 = x1.clone().cast(DType::F32);
|
||||
let x4 = x2.clone().cast(DType::F32);
|
||||
|
||||
let x5 = x3.matmul(x4);
|
||||
let x6 = x5.cast(dtype);
|
||||
let x7 = x6 * x1.clone() / x2.clone();
|
||||
|
||||
let grads = x7.backward();
|
||||
|
||||
let x1_grad = x1.grad(&grads);
|
||||
let x2_grad = x2.grad(&grads);
|
||||
|
||||
assert!(x1_grad.is_some());
|
||||
assert!(x2_grad.is_some());
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mul_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x * y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn div_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x / y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x - y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x + y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matmul_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| x.matmul(y));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_broadcast() {
|
||||
test_ops_broadcast_backward(|x, y| {
|
||||
let cond = y.clone().equal_elem(4);
|
||||
x.mask_where(cond, y)
|
||||
});
|
||||
}
|
||||
|
||||
fn test_ops_broadcast_backward<F>(func: F)
|
||||
where
|
||||
F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
|
||||
{
|
||||
let device = Default::default();
|
||||
let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();
|
||||
let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();
|
||||
|
||||
// Slice isn't a broadcastable operation, so it will fail when the previous backward pass
|
||||
// of an operation that support broadcast doesn't support it during the backward pass.
|
||||
let y = func(w.clone().slice([0..1]), x.clone());
|
||||
|
||||
// Will panic if broadcast isn't supported!
|
||||
let grads = y.backward();
|
||||
|
||||
let w_grad = w.grad(&grads).unwrap();
|
||||
let x_grad = x.grad(&grads).unwrap();
|
||||
|
||||
assert_eq!(w_grad.shape(), w.shape());
|
||||
assert_eq!(x_grad.shape(), x.shape());
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
// Skip on metal - F64 not supported
|
||||
#![cfg(all(feature = "std", not(feature = "metal")))]
|
||||
|
||||
use super::*;
|
||||
use burn_backend_tests::might_panic;
|
||||
use burn_tensor::{DType, Tensor, TensorData};
|
||||
|
||||
#[might_panic(reason = "Unsupported precision for fusion")]
|
||||
#[test]
|
||||
fn cast_keeps_gradient_flow() {
|
||||
let device = Default::default();
|
||||
|
||||
let x = Tensor::<TestAutodiffBackend, 2>::from_data(
|
||||
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let y = x.clone().cast(DType::F64);
|
||||
let z = y.sum();
|
||||
|
||||
let grads = z.backward();
|
||||
let grad_x = x.grad(&grads).unwrap();
|
||||
|
||||
grad_x
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
use super::*;
|
||||
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_cat() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let mut tensor_1_list = Vec::new();
|
||||
let mut tensor_2_list = Vec::new();
|
||||
|
||||
for i in 0..2 {
|
||||
tensor_1_list.push(tensor_1.clone().slice([i..i + 1]));
|
||||
tensor_2_list.push(tensor_2.clone().slice([i..i + 1]));
|
||||
}
|
||||
|
||||
let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0);
|
||||
let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0);
|
||||
|
||||
let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
|
||||
let grads = tensor_3_cat.backward();
|
||||
|
||||
let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]);
|
||||
let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]);
|
||||
|
||||
let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]);
|
||||
let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]);
|
||||
|
||||
grad_1
|
||||
.clone()
|
||||
.slice([0..1])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&grad_1_slice_1.to_data(), Tolerance::default());
|
||||
grad_1
|
||||
.slice([1..2])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&grad_1_slice_2.to_data(), Tolerance::default());
|
||||
|
||||
grad_2
|
||||
.clone()
|
||||
.slice([0..1])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&grad_2_slice_1.to_data(), Tolerance::default());
|
||||
grad_2
|
||||
.slice([1..2])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&grad_2_slice_2.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cat_more_than_1_dim() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device)
|
||||
.require_grad();
|
||||
|
||||
// Concat a tensor [2, 2] with another tensor [3, 2] along dim 0.
|
||||
// The resulting tensor should be [5, 2]
|
||||
let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0);
|
||||
assert_eq!(tensor_3.dims(), [5, 2]);
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
assert_eq!(tensor_1.dims(), grad_1.dims());
|
||||
assert_eq!(tensor_2.dims(), grad_2.dims());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_slice_grads_correctly_when_some_inputs_not_tracked() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad(); // tracked
|
||||
let tensor_2 = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device); // not tracked
|
||||
let tensor_3 =
|
||||
TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad(); // tracked
|
||||
|
||||
let cat = TestAutodiffTensor::cat(
|
||||
vec![tensor_1.clone(), tensor_2.clone(), tensor_3.clone()],
|
||||
1,
|
||||
);
|
||||
|
||||
// Make gradient per column unique so wrong slicing shows up.
|
||||
let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device);
|
||||
let loss = (cat * weights).sum();
|
||||
|
||||
let grads = loss.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_3 = tensor_3.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&burn_tensor::TensorData::from([[1.0]]), false);
|
||||
grad_3
|
||||
.to_data()
|
||||
.assert_eq(&burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_ceil() {
|
||||
let data = TensorData::from([
|
||||
[-1.9751, 0.0714, 0.0643, 0.2406],
|
||||
[-1.3172, 0.1252, -0.1119, -0.0127],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||
let tensor_2 = tensor_1.clone().ceil();
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(
|
||||
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Bool, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_autodiff_checkpoint_complicated_computation() {
|
||||
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
|
||||
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
|
||||
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
|
||||
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
|
||||
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
|
||||
|
||||
let tensor_5 = compute_bound_eager(tensor_0, tensor_1);
|
||||
let tensor_6 = compute_bound_lazy(tensor_2, tensor_3.clone());
|
||||
let tensor_7 = memory_bound_eager(tensor_3, tensor_4);
|
||||
let tensor_8 = compute_bound_lazy(tensor_6, tensor_7.clone());
|
||||
let tensor_9 = memory_bound_eager_scalar(tensor_7, 11.);
|
||||
let tensor_10 = memory_bound_lazy(tensor_5, tensor_8.clone());
|
||||
let tensor_11 = memory_bound_lazy(tensor_8, tensor_9);
|
||||
let tensor_12 = compute_bound_lazy(tensor_10, tensor_11);
|
||||
|
||||
assert_checkpoint(tensor_12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autodiff_checkpoint_with_missing_requirement() {
|
||||
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
|
||||
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); // does not require_grad
|
||||
|
||||
let tensor_2 = memory_bound_eager(tensor_0, tensor_1);
|
||||
let tensor_3 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
|
||||
let tensor_4 = memory_bound_eager_scalar(tensor_2.clone(), 11.);
|
||||
let tensor_5 = compute_bound_lazy(tensor_3, tensor_4);
|
||||
let tensor_6 = compute_bound_eager_scalar(tensor_5.clone(), 11.);
|
||||
let tensor_7 = memory_bound_eager(tensor_5, tensor_2);
|
||||
let tensor_8 = memory_bound_eager(tensor_6, tensor_7);
|
||||
|
||||
assert_checkpoint(tensor_8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autodiff_checkpoint_with_many_duplicates() {
|
||||
let data_0 = TensorData::from([[4.0, 7.0], [7.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
|
||||
|
||||
let tensor_1 = memory_bound_eager(tensor_0.clone(), tensor_0.clone());
|
||||
let tensor_2 = compute_bound_eager(tensor_0.clone(), tensor_0.clone());
|
||||
let tensor_3 = memory_bound_lazy(tensor_0.clone(), tensor_0.clone());
|
||||
let tensor_4 = compute_bound_lazy(tensor_0.clone(), tensor_0.clone());
|
||||
|
||||
let tensor_5 = memory_bound_eager(tensor_1.clone(), tensor_0.clone());
|
||||
let tensor_6 = memory_bound_eager(tensor_0.clone(), tensor_5.clone());
|
||||
let tensor_7 = compute_bound_lazy(tensor_3.clone(), tensor_5.clone());
|
||||
let tensor_8 = compute_bound_eager(tensor_4.clone(), tensor_2.clone());
|
||||
let tensor_9 = memory_bound_lazy(tensor_6, tensor_7);
|
||||
let tensor_10 = memory_bound_eager(tensor_0, tensor_9);
|
||||
let tensor_11 = memory_bound_eager_scalar(tensor_10, 9.);
|
||||
let tensor_12 = compute_bound_lazy(tensor_8, tensor_11);
|
||||
|
||||
assert_checkpoint(tensor_12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {
|
||||
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
|
||||
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
|
||||
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
|
||||
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
|
||||
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
|
||||
|
||||
let tensor_5 = memory_bound_eager(tensor_0, tensor_1.clone());
|
||||
let tensor_6 = memory_bound_eager(tensor_5, tensor_2);
|
||||
let tensor_7 = memory_bound_eager(tensor_6, tensor_3);
|
||||
let tensor_8 = memory_bound_eager(tensor_7, tensor_4);
|
||||
let tensor_9 = memory_bound_lazy(tensor_8, tensor_1);
|
||||
|
||||
assert_checkpoint(tensor_9)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autodiff_checkpoint_half_sub_graph_not_tracked() {
|
||||
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
|
||||
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
|
||||
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
|
||||
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
|
||||
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
|
||||
let data_5 = TensorData::from([[0.5, 7.0], [7.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device);
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device);
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
|
||||
let tensor_5 = TestAutodiffTensor::from_data(data_5, &device).require_grad();
|
||||
|
||||
let tensor_6 = memory_bound_lazy(tensor_0, tensor_1);
|
||||
let tensor_7 = compute_bound_eager(tensor_6, tensor_2);
|
||||
|
||||
let tensor_8 = memory_bound_eager(tensor_3, tensor_4);
|
||||
let tensor_9 = compute_bound_lazy(tensor_8, tensor_5);
|
||||
|
||||
let tensor_10 = compute_bound_lazy(tensor_7, tensor_9);
|
||||
|
||||
assert_checkpoint(tensor_10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autodiff_checkpoint_very_complex() {
|
||||
let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);
|
||||
let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);
|
||||
let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);
|
||||
let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);
|
||||
let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();
|
||||
|
||||
let tensor_5 = memory_bound_eager_scalar(tensor_0, 8.);
|
||||
let tensor_6 = memory_bound_lazy(tensor_5.clone(), tensor_1.clone());
|
||||
let tensor_7 = compute_bound_lazy(tensor_6.clone(), tensor_6);
|
||||
let tensor_8 = memory_bound_lazy(tensor_1.clone(), tensor_5.clone());
|
||||
let tensor_9 = memory_bound_eager_scalar(tensor_7.clone(), 7.);
|
||||
let tensor_10 = compute_bound_eager(tensor_5, tensor_8);
|
||||
let tensor_11 = memory_bound_eager(tensor_2.clone(), tensor_9);
|
||||
let tensor_12 = memory_bound_lazy(tensor_2.clone(), tensor_2);
|
||||
let tensor_13 = compute_bound_eager(tensor_10.clone(), tensor_11);
|
||||
let tensor_14 = compute_bound_eager_scalar(tensor_3, 8.);
|
||||
let tensor_15 = compute_bound_lazy(tensor_4, tensor_12);
|
||||
let tensor_16 = memory_bound_lazy(tensor_10, tensor_7);
|
||||
let tensor_17 = compute_bound_lazy(tensor_13, tensor_1);
|
||||
let tensor_18 = memory_bound_eager(tensor_15, tensor_16);
|
||||
let tensor_19 = compute_bound_eager(tensor_14, tensor_17);
|
||||
let tensor_20 = memory_bound_lazy(tensor_18, tensor_19);
|
||||
let tensor_21 = memory_bound_eager_scalar(tensor_20, 8.);
|
||||
|
||||
assert_checkpoint(tensor_21)
|
||||
}
|
||||
|
||||
fn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {
|
||||
// Assert is not explicit here, but the test can fail
|
||||
// - when a tensor is actually required more than n_required, it won't be found and will panic
|
||||
// - when a tensor is actually required less than n_required, the backward states map won't be
|
||||
// empty and will fail the assertion within the backward code, same for retro_forwards
|
||||
tensor.backward();
|
||||
}
|
||||
|
||||
// Does not save its state and does not need its parents
|
||||
fn memory_bound_eager<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.add(tensor_b)
|
||||
}
|
||||
fn memory_bound_eager_scalar<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
b: f32,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.add_scalar(b)
|
||||
}
|
||||
|
||||
// Saves its own state and does not need its parents
|
||||
fn compute_bound_eager<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
|
||||
tensor_a.mask_where(mask, tensor_b)
|
||||
}
|
||||
fn compute_bound_eager_scalar<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
b: f32,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());
|
||||
tensor_a.mask_fill(mask, b)
|
||||
}
|
||||
|
||||
// Does not save its state and needs its parents
|
||||
fn memory_bound_lazy<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.mul(tensor_b)
|
||||
}
|
||||
|
||||
// Saves its own state and needs its parents
|
||||
fn compute_bound_lazy<const D: usize>(
|
||||
tensor_a: TestAutodiffTensor<D>,
|
||||
tensor_b: TestAutodiffTensor<D>,
|
||||
) -> TestAutodiffTensor<D> {
|
||||
tensor_a.matmul(tensor_b)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_full_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.matmul(tensor_1.clone());
|
||||
let tensor_5 = tensor_4.mul(tensor_2.clone());
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[593., 463.0], [487.0, 539.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[734.0, 294.0], [1414.0, 242.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_full_complex_2() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.matmul(tensor_1.clone());
|
||||
let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone());
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[166.0, 110.0], [212.0, 156.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[113.0, 141.0], [33.0, 41.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_full_complex_3() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.matmul(tensor_1.clone());
|
||||
let tensor_5 = tensor_4.clone().sub(tensor_2.clone());
|
||||
let tensor_6 = tensor_5.add(tensor_4);
|
||||
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[332.0, 220.0], [424.0, 312.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[223.0, 279.0], [63.0, 79.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_basic() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
|
||||
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([8., 8.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_different_channels() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
|
||||
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([8., 8., 8.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_with_padding() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 2,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
|
||||
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[44., 44., 44.], [76., 76., 76.]],
|
||||
[[44., 44., 44.], [76., 76., 76.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([12., 12.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_with_stride() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
|
||||
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[10., 20., 24.], [18., 36., 40.]],
|
||||
[[10., 20., 24.], [18., 36., 40.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([4., 4.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_dilation() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
|
||||
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[8., 22., 14.], [16., 38., 22.]],
|
||||
[[8., 22., 14.], [16., 38., 22.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([4., 4.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_groups() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 2,
|
||||
length: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
|
||||
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device),
|
||||
bias: TestTensor::from_floats([8., 8.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct Conv1dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<3>,
|
||||
weight: TestTensor<3>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl Conv1dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = conv1d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,962 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions};
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_basic() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[88., 138., 138., 96.],
|
||||
[150., 234., 234., 162.],
|
||||
[150., 234., 234., 162.],
|
||||
[112., 174., 174., 120.],
|
||||
],
|
||||
[
|
||||
[160., 246., 246., 168.],
|
||||
[258., 396., 396., 270.],
|
||||
[258., 396., 396., 270.],
|
||||
[184., 282., 282., 192.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[88., 138., 138., 96.],
|
||||
[150., 234., 234., 162.],
|
||||
[150., 234., 234., 162.],
|
||||
[112., 174., 174., 120.],
|
||||
],
|
||||
[
|
||||
[160., 246., 246., 168.],
|
||||
[258., 396., 396., 270.],
|
||||
[258., 396., 396., 270.],
|
||||
[184., 282., 282., 192.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
|
||||
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
|
||||
],
|
||||
[
|
||||
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
|
||||
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([32., 32.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_different_channels() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[240., 369., 369., 252.],
|
||||
[387., 594., 594., 405.],
|
||||
[387., 594., 594., 405.],
|
||||
[276., 423., 423., 288.],
|
||||
],
|
||||
[
|
||||
[348., 531., 531., 360.],
|
||||
[549., 837., 837., 567.],
|
||||
[549., 837., 837., 567.],
|
||||
[384., 585., 585., 396.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[240., 369., 369., 252.],
|
||||
[387., 594., 594., 405.],
|
||||
[387., 594., 594., 405.],
|
||||
[276., 423., 423., 288.],
|
||||
],
|
||||
[
|
||||
[348., 531., 531., 360.],
|
||||
[549., 837., 837., 567.],
|
||||
[549., 837., 837., 567.],
|
||||
[384., 585., 585., 396.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
|
||||
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
|
||||
],
|
||||
[
|
||||
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
|
||||
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
|
||||
],
|
||||
[
|
||||
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
|
||||
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([32., 32., 32.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_different_kernel_size() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[116., 180., 192., 132.],
|
||||
[198., 306., 324., 222.],
|
||||
[198., 306., 324., 222.],
|
||||
[148., 228., 240., 164.],
|
||||
],
|
||||
[
|
||||
[212., 324., 336., 228.],
|
||||
[342., 522., 540., 366.],
|
||||
[342., 522., 540., 366.],
|
||||
[244., 372., 384., 260.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[27., 45., 54., 39.],
|
||||
[52., 84., 96., 68.],
|
||||
[51., 81., 90., 63.],
|
||||
],
|
||||
[
|
||||
[123., 189., 198., 135.],
|
||||
[180., 276., 288., 196.],
|
||||
[147., 225., 234., 159.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[27., 45., 54., 39.],
|
||||
[52., 84., 96., 68.],
|
||||
[51., 81., 90., 63.],
|
||||
],
|
||||
[
|
||||
[123., 189., 198., 135.],
|
||||
[180., 276., 288., 196.],
|
||||
[147., 225., 234., 159.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([12., 12.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_different_padding() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[138., 138., 138., 138.],
|
||||
[234., 234., 234., 234.],
|
||||
[234., 234., 234., 234.],
|
||||
[174., 174., 174., 174.],
|
||||
],
|
||||
[
|
||||
[246., 246., 246., 246.],
|
||||
[396., 396., 396., 396.],
|
||||
[396., 396., 396., 396.],
|
||||
[282., 282., 282., 282.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
|
||||
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
|
||||
],
|
||||
[
|
||||
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
|
||||
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([24., 24.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_different_width() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 5,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[88., 138., 138., 138., 96.],
|
||||
[150., 234., 234., 234., 162.],
|
||||
[150., 234., 234., 234., 162.],
|
||||
[112., 174., 174., 174., 120.],
|
||||
],
|
||||
[
|
||||
[160., 246., 246., 246., 168.],
|
||||
[258., 396., 396., 396., 270.],
|
||||
[258., 396., 396., 396., 270.],
|
||||
[184., 282., 282., 282., 192.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
|
||||
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
|
||||
],
|
||||
[
|
||||
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
|
||||
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([20., 20.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_stride_2() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[26., 52., 26., 52., 26., 28.],
|
||||
[52., 104., 52., 104., 52., 56.],
|
||||
[26., 52., 26., 52., 26., 28.],
|
||||
[52., 104., 52., 104., 52., 56.],
|
||||
[26., 52., 26., 52., 26., 28.],
|
||||
[32., 64., 32., 64., 32., 34.],
|
||||
],
|
||||
[
|
||||
[44., 88., 44., 88., 44., 46.],
|
||||
[88., 176., 88., 176., 88., 92.],
|
||||
[44., 88., 44., 88., 44., 46.],
|
||||
[88., 176., 88., 176., 88., 92.],
|
||||
[44., 88., 44., 88., 44., 46.],
|
||||
[50., 100., 50., 100., 50., 52.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
|
||||
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
|
||||
],
|
||||
[
|
||||
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
|
||||
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([9., 9.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_different_stride() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 3,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 8,
|
||||
width: 8,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[50., 78., 78., 78., 78., 78., 78., 54.],
|
||||
[62., 96., 96., 96., 96., 96., 96., 66.],
|
||||
[38., 60., 60., 60., 60., 60., 60., 42.],
|
||||
[50., 78., 78., 78., 78., 78., 78., 54.],
|
||||
[62., 96., 96., 96., 96., 96., 96., 66.],
|
||||
[38., 60., 60., 60., 60., 60., 60., 42.],
|
||||
[50., 78., 78., 78., 78., 78., 78., 54.],
|
||||
[62., 96., 96., 96., 96., 96., 96., 66.],
|
||||
],
|
||||
[
|
||||
[86., 132., 132., 132., 132., 132., 132., 90.],
|
||||
[98., 150., 150., 150., 150., 150., 150., 102.],
|
||||
[74., 114., 114., 114., 114., 114., 114., 78.],
|
||||
[86., 132., 132., 132., 132., 132., 132., 90.],
|
||||
[98., 150., 150., 150., 150., 150., 150., 102.],
|
||||
[74., 114., 114., 114., 114., 114., 114., 78.],
|
||||
[86., 132., 132., 132., 132., 132., 132., 90.],
|
||||
[98., 150., 150., 150., 150., 150., 150., 102.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
|
||||
[
|
||||
[1330., 1528., 1344.],
|
||||
[1911., 2196., 1932.],
|
||||
[2079., 2388., 2100.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
|
||||
[
|
||||
[1330., 1528., 1344.],
|
||||
[1911., 2196., 1932.],
|
||||
[2079., 2388., 2100.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([24., 24.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_dilation_2() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 2,
|
||||
dilation_2: 2,
|
||||
groups: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[18., 38., 38., 42., 42., 22.],
|
||||
[42., 88., 88., 96., 96., 50.],
|
||||
[42., 88., 88., 96., 96., 50.],
|
||||
[54., 112., 112., 120., 120., 62.],
|
||||
[54., 112., 112., 120., 120., 62.],
|
||||
[30., 62., 62., 66., 66., 34.],
|
||||
],
|
||||
[
|
||||
[36., 74., 74., 78., 78., 40.],
|
||||
[78., 160., 160., 168., 168., 86.],
|
||||
[78., 160., 160., 168., 168., 86.],
|
||||
[90., 184., 184., 192., 192., 98.],
|
||||
[90., 184., 184., 192., 192., 98.],
|
||||
[48., 98., 98., 102., 102., 52.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
|
||||
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
|
||||
],
|
||||
[
|
||||
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
|
||||
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([16., 16.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_different_dilation() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 2,
|
||||
dilation_2: 3,
|
||||
groups: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[18., 0., 20., 20., 0., 22.],
|
||||
[42., 0., 46., 46., 0., 50.],
|
||||
[42., 0., 46., 46., 0., 50.],
|
||||
[54., 0., 58., 58., 0., 62.],
|
||||
[54., 0., 58., 58., 0., 62.],
|
||||
[30., 0., 32., 32., 0., 34.],
|
||||
],
|
||||
[
|
||||
[36., 0., 38., 38., 0., 40.],
|
||||
[78., 0., 82., 82., 0., 86.],
|
||||
[78., 0., 82., 82., 0., 86.],
|
||||
[90., 0., 94., 94., 0., 98.],
|
||||
[90., 0., 94., 94., 0., 98.],
|
||||
[48., 0., 50., 50., 0., 52.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
|
||||
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
|
||||
],
|
||||
[
|
||||
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
|
||||
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([8., 8.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[0., 1., 3., 3., 2.],
|
||||
[3., 8., 15., 12., 7.],
|
||||
[9., 21., 36., 27., 15.],
|
||||
[9., 20., 33., 24., 13.],
|
||||
[6., 13., 21., 15., 8.],
|
||||
],
|
||||
[
|
||||
[9., 19., 30., 21., 11.],
|
||||
[21., 44., 69., 48., 25.],
|
||||
[36., 75., 117., 81., 42.],
|
||||
[27., 56., 87., 60., 31.],
|
||||
[15., 31., 48., 33., 17.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],
|
||||
[[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([9., 9.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups_stride_2() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 4,
|
||||
channels_out: 4,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 4,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[4., 8., 4., 5.],
|
||||
[8., 16., 8., 10.],
|
||||
[4., 8., 4., 5.],
|
||||
[7., 14., 7., 8.],
|
||||
],
|
||||
[
|
||||
[13., 26., 13., 14.],
|
||||
[26., 52., 26., 28.],
|
||||
[13., 26., 13., 14.],
|
||||
[16., 32., 16., 17.],
|
||||
],
|
||||
[
|
||||
[22., 44., 22., 23.],
|
||||
[44., 88., 44., 46.],
|
||||
[22., 44., 22., 23.],
|
||||
[25., 50., 25., 26.],
|
||||
],
|
||||
[
|
||||
[31., 62., 31., 32.],
|
||||
[62., 124., 62., 64.],
|
||||
[31., 62., 31., 32.],
|
||||
[34., 68., 34., 35.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
|
||||
[[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
|
||||
[[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
|
||||
[[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([4., 4., 4., 4.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups_different_channels() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 3,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[9., 20., 24., 13.],
|
||||
[24., 52., 60., 32.],
|
||||
[36., 76., 84., 44.],
|
||||
[21., 44., 48., 25.],
|
||||
],
|
||||
[
|
||||
[45., 92., 96., 49.],
|
||||
[96., 196., 204., 104.],
|
||||
[108., 220., 228., 116.],
|
||||
[57., 116., 120., 61.],
|
||||
],
|
||||
[
|
||||
[81., 164., 168., 85.],
|
||||
[168., 340., 348., 176.],
|
||||
[180., 364., 372., 188.],
|
||||
[93., 188., 192., 97.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
|
||||
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
|
||||
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
|
||||
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
|
||||
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
|
||||
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_complex() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 2,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
dilation_1: 2,
|
||||
dilation_2: 3,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 5,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[36., 39., 0., 39., 42.],
|
||||
[81., 87., 0., 87., 93.],
|
||||
[81., 87., 0., 87., 93.],
|
||||
[45., 48., 0., 48., 51.],
|
||||
],
|
||||
[
|
||||
[54., 57., 0., 57., 60.],
|
||||
[117., 123., 0., 123., 129.],
|
||||
[117., 123., 0., 123., 129.],
|
||||
[63., 66., 0., 66., 69.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[15., 42., 27.], [30., 72., 42.]],
|
||||
[[75., 162., 87.], [90., 192., 102.]],
|
||||
],
|
||||
[
|
||||
[[15., 42., 27.], [30., 72., 42.]],
|
||||
[[75., 162., 87.], [90., 192., 102.]],
|
||||
],
|
||||
[
|
||||
[[15., 42., 27.], [30., 72., 42.]],
|
||||
[[75., 162., 87.], [90., 192., 102.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([8., 8., 8.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups_stride_2_no_pad() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 4,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[0., 1., 2., 0.],
|
||||
[3., 4., 5., 0.],
|
||||
[6., 7., 8., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[9., 10., 11., 0.],
|
||||
[12., 13., 14., 0.],
|
||||
[15., 16., 17., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[18., 19., 20., 0.],
|
||||
[21., 22., 23., 0.],
|
||||
[24., 25., 26., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[27., 28., 29., 0.],
|
||||
[30., 31., 32., 0.],
|
||||
[33., 34., 35., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
|
||||
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
|
||||
],
|
||||
[
|
||||
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
|
||||
[[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([1., 1.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct Conv2dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
groups: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<4>,
|
||||
weight: TestTensor<4>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl Conv2dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv2d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvOptions::new(
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::rel_abs(0.01, 0.01);
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,690 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
|
||||
|
||||
#[test]
|
||||
fn test_conv3d_basic() {
|
||||
let test = Conv3dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 1,
|
||||
depth: 4,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[
|
||||
[536., 816., 816., 552.],
|
||||
[840., 1278., 1278., 864.],
|
||||
[840., 1278., 1278., 864.],
|
||||
[584., 888., 888., 600.],
|
||||
],
|
||||
[
|
||||
[912., 1386., 1386., 936.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[984., 1494., 1494., 1008.],
|
||||
],
|
||||
[
|
||||
[912., 1386., 1386., 936.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[984., 1494., 1494., 1008.],
|
||||
],
|
||||
[
|
||||
[680., 1032., 1032., 696.],
|
||||
[1056., 1602., 1602., 1080.],
|
||||
[1056., 1602., 1602., 1080.],
|
||||
[728., 1104., 1104., 744.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[968., 1464., 1464., 984.],
|
||||
[1488., 2250., 2250., 1512.],
|
||||
[1488., 2250., 2250., 1512.],
|
||||
[1016., 1536., 1536., 1032.],
|
||||
],
|
||||
[
|
||||
[1560., 2358., 2358., 1584.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[1632., 2466., 2466., 1656.],
|
||||
],
|
||||
[
|
||||
[1560., 2358., 2358., 1584.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[1632., 2466., 2466., 1656.],
|
||||
],
|
||||
[
|
||||
[1112., 1680., 1680., 1128.],
|
||||
[1704., 2574., 2574., 1728.],
|
||||
[1704., 2574., 2574., 1728.],
|
||||
[1160., 1752., 1752., 1176.],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[536., 816., 816., 552.],
|
||||
[840., 1278., 1278., 864.],
|
||||
[840., 1278., 1278., 864.],
|
||||
[584., 888., 888., 600.],
|
||||
],
|
||||
[
|
||||
[912., 1386., 1386., 936.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[984., 1494., 1494., 1008.],
|
||||
],
|
||||
[
|
||||
[912., 1386., 1386., 936.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[1422., 2160., 2160., 1458.],
|
||||
[984., 1494., 1494., 1008.],
|
||||
],
|
||||
[
|
||||
[680., 1032., 1032., 696.],
|
||||
[1056., 1602., 1602., 1080.],
|
||||
[1056., 1602., 1602., 1080.],
|
||||
[728., 1104., 1104., 744.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[968., 1464., 1464., 984.],
|
||||
[1488., 2250., 2250., 1512.],
|
||||
[1488., 2250., 2250., 1512.],
|
||||
[1016., 1536., 1536., 1032.],
|
||||
],
|
||||
[
|
||||
[1560., 2358., 2358., 1584.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[1632., 2466., 2466., 1656.],
|
||||
],
|
||||
[
|
||||
[1560., 2358., 2358., 1584.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[2394., 3618., 3618., 2430.],
|
||||
[1632., 2466., 2466., 1656.],
|
||||
],
|
||||
[
|
||||
[1112., 1680., 1680., 1128.],
|
||||
[1704., 2574., 2574., 1728.],
|
||||
[1704., 2574., 2574., 1728.],
|
||||
[1160., 1752., 1752., 1176.],
|
||||
],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[
|
||||
[4590., 6156., 4644.],
|
||||
[6264., 8400., 6336.],
|
||||
[4806., 6444., 4860.],
|
||||
],
|
||||
[
|
||||
[6696., 8976., 6768.],
|
||||
[9120., 12224., 9216.],
|
||||
[6984., 9360., 7056.],
|
||||
],
|
||||
[
|
||||
[5454., 7308., 5508.],
|
||||
[7416., 9936., 7488.],
|
||||
[5670., 7596., 5724.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[8046., 10764., 8100.],
|
||||
[10872., 14544., 10944.],
|
||||
[8262., 11052., 8316.],
|
||||
],
|
||||
[
|
||||
[11304., 15120., 11376.],
|
||||
[15264., 20416., 15360.],
|
||||
[11592., 15504., 11664.],
|
||||
],
|
||||
[
|
||||
[8910., 11916., 8964.],
|
||||
[12024., 16080., 12096.],
|
||||
[9126., 12204., 9180.],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[4590., 6156., 4644.],
|
||||
[6264., 8400., 6336.],
|
||||
[4806., 6444., 4860.],
|
||||
],
|
||||
[
|
||||
[6696., 8976., 6768.],
|
||||
[9120., 12224., 9216.],
|
||||
[6984., 9360., 7056.],
|
||||
],
|
||||
[
|
||||
[5454., 7308., 5508.],
|
||||
[7416., 9936., 7488.],
|
||||
[5670., 7596., 5724.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[8046., 10764., 8100.],
|
||||
[10872., 14544., 10944.],
|
||||
[8262., 11052., 8316.],
|
||||
],
|
||||
[
|
||||
[11304., 15120., 11376.],
|
||||
[15264., 20416., 15360.],
|
||||
[11592., 15504., 11664.],
|
||||
],
|
||||
[
|
||||
[8910., 11916., 8964.],
|
||||
[12024., 16080., 12096.],
|
||||
[9126., 12204., 9180.],
|
||||
],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([128., 128.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv3d_complex() {
|
||||
let test = Conv3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 2,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
padding_3: 3,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
stride_3: 3,
|
||||
dilation_1: 2,
|
||||
dilation_2: 3,
|
||||
dilation_3: 4,
|
||||
groups: 1,
|
||||
depth: 5,
|
||||
height: 6,
|
||||
width: 7,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[
|
||||
[0., 147., 0., 0., 0., 150., 0.],
|
||||
[0., 159., 0., 0., 0., 162., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 159., 0., 0., 0., 162., 0.],
|
||||
[0., 171., 0., 0., 0., 174., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 330., 0., 0., 0., 336., 0.],
|
||||
[0., 354., 0., 0., 0., 360., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 354., 0., 0., 0., 360., 0.],
|
||||
[0., 378., 0., 0., 0., 384., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 330., 0., 0., 0., 336., 0.],
|
||||
[0., 354., 0., 0., 0., 360., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 354., 0., 0., 0., 360., 0.],
|
||||
[0., 378., 0., 0., 0., 384., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 330., 0., 0., 0., 336., 0.],
|
||||
[0., 354., 0., 0., 0., 360., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 354., 0., 0., 0., 360., 0.],
|
||||
[0., 378., 0., 0., 0., 384., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 183., 0., 0., 0., 186., 0.],
|
||||
[0., 195., 0., 0., 0., 198., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 195., 0., 0., 0., 198., 0.],
|
||||
[0., 207., 0., 0., 0., 210., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0., 219., 0., 0., 0., 222., 0.],
|
||||
[0., 231., 0., 0., 0., 234., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 231., 0., 0., 0., 234., 0.],
|
||||
[0., 243., 0., 0., 0., 246., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 474., 0., 0., 0., 480., 0.],
|
||||
[0., 498., 0., 0., 0., 504., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 498., 0., 0., 0., 504., 0.],
|
||||
[0., 522., 0., 0., 0., 528., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 474., 0., 0., 0., 480., 0.],
|
||||
[0., 498., 0., 0., 0., 504., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 498., 0., 0., 0., 504., 0.],
|
||||
[0., 522., 0., 0., 0., 528., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 474., 0., 0., 0., 480., 0.],
|
||||
[0., 498., 0., 0., 0., 504., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 498., 0., 0., 0., 504., 0.],
|
||||
[0., 522., 0., 0., 0., 528., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 255., 0., 0., 0., 258., 0.],
|
||||
[0., 267., 0., 0., 0., 270., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 267., 0., 0., 0., 270., 0.],
|
||||
[0., 279., 0., 0., 0., 282., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[
|
||||
[0., 256., 272., 0.],
|
||||
[0., 624., 656., 0.],
|
||||
[0., 368., 384., 0.],
|
||||
],
|
||||
[
|
||||
[0., 424., 440., 0.],
|
||||
[0., 960., 992., 0.],
|
||||
[0., 536., 552., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0., 1096., 1112., 0.],
|
||||
[0., 2304., 2336., 0.],
|
||||
[0., 1208., 1224., 0.],
|
||||
],
|
||||
[
|
||||
[0., 1264., 1280., 0.],
|
||||
[0., 2640., 2672., 0.],
|
||||
[0., 1376., 1392., 0.],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[0., 256., 272., 0.],
|
||||
[0., 624., 656., 0.],
|
||||
[0., 368., 384., 0.],
|
||||
],
|
||||
[
|
||||
[0., 424., 440., 0.],
|
||||
[0., 960., 992., 0.],
|
||||
[0., 536., 552., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0., 1096., 1112., 0.],
|
||||
[0., 2304., 2336., 0.],
|
||||
[0., 1208., 1224., 0.],
|
||||
],
|
||||
[
|
||||
[0., 1264., 1280., 0.],
|
||||
[0., 2640., 2672., 0.],
|
||||
[0., 1376., 1392., 0.],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[0., 256., 272., 0.],
|
||||
[0., 624., 656., 0.],
|
||||
[0., 368., 384., 0.],
|
||||
],
|
||||
[
|
||||
[0., 424., 440., 0.],
|
||||
[0., 960., 992., 0.],
|
||||
[0., 536., 552., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0., 1096., 1112., 0.],
|
||||
[0., 2304., 2336., 0.],
|
||||
[0., 1208., 1224., 0.],
|
||||
],
|
||||
[
|
||||
[0., 1264., 1280., 0.],
|
||||
[0., 2640., 2672., 0.],
|
||||
[0., 1376., 1392., 0.],
|
||||
],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([10., 10., 10.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv3d_groups_stride_2_no_pad() {
|
||||
let test = Conv3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 4,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_3: 0,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
stride_3: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 2,
|
||||
depth: 4,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[
|
||||
[0., 1., 2., 0.],
|
||||
[3., 4., 5., 0.],
|
||||
[6., 7., 8., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[9., 10., 11., 0.],
|
||||
[12., 13., 14., 0.],
|
||||
[15., 16., 17., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[18., 19., 20., 0.],
|
||||
[21., 22., 23., 0.],
|
||||
[24., 25., 26., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[27., 28., 29., 0.],
|
||||
[30., 31., 32., 0.],
|
||||
[33., 34., 35., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[36., 37., 38., 0.],
|
||||
[39., 40., 41., 0.],
|
||||
[42., 43., 44., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[45., 46., 47., 0.],
|
||||
[48., 49., 50., 0.],
|
||||
[51., 52., 53., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[54., 55., 56., 0.],
|
||||
[57., 58., 59., 0.],
|
||||
[60., 61., 62., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[63., 64., 65., 0.],
|
||||
[66., 67., 68., 0.],
|
||||
[69., 70., 71., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[72., 73., 74., 0.],
|
||||
[75., 76., 77., 0.],
|
||||
[78., 79., 80., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[81., 82., 83., 0.],
|
||||
[84., 85., 86., 0.],
|
||||
[87., 88., 89., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[90., 91., 92., 0.],
|
||||
[93., 94., 95., 0.],
|
||||
[96., 97., 98., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[99., 100., 101., 0.],
|
||||
[102., 103., 104., 0.],
|
||||
[105., 106., 107., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
[
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
|
||||
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
|
||||
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
|
||||
],
|
||||
[
|
||||
[[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
|
||||
[[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
|
||||
[[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
|
||||
[[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
|
||||
[[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
|
||||
],
|
||||
[
|
||||
[[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
|
||||
[[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
|
||||
[[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([1., 1.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct Conv3dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
kernel_size_3: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
padding_3: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
stride_3: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
dilation_3: usize,
|
||||
groups: usize,
|
||||
depth: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<5>,
|
||||
weight: TestTensor<5>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl Conv3dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels_in,
|
||||
self.depth,
|
||||
self.height,
|
||||
self.width,
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
self.kernel_size_3,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv3d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvOptions::new(
|
||||
[self.stride_1, self.stride_2, self.stride_3],
|
||||
[self.padding_1, self.padding_2, self.padding_3],
|
||||
[self.dilation_1, self.dilation_2, self.dilation_3],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions};
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_basic() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
size: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],
|
||||
[[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]],
|
||||
[[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([12., 12.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_padding() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: 3,
|
||||
padding: 2,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
size: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[7., 12., 8., 3.], [19., 36., 32., 15.]],
|
||||
[[7., 12., 8., 3.], [19., 36., 32., 15.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[26., 22., 18.], [26., 22., 18.]],
|
||||
[[42., 38., 34.], [42., 38., 34.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([4., 4.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_stride() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
padding_out: 0,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
size: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
|
||||
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[44., 44., 44.], [44., 44., 44.]],
|
||||
[[76., 76., 76.], [76., 76., 76.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([18., 18.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_stride_padding_out() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
padding_out: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
size: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
|
||||
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[44., 44., 44.], [44., 44., 44.]],
|
||||
[[76., 76., 76.], [76., 76., 76.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([20., 20.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_dilation() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
groups: 1,
|
||||
size: 4,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
|
||||
[[15., 15., 15., 15.], [51., 51., 51., 51.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[44., 44., 44.], [44., 44., 44.]],
|
||||
[[76., 76., 76.], [76., 76., 76.]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([16., 16.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_complex() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 4],
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
padding_out: 1,
|
||||
stride: 2,
|
||||
dilation: 2,
|
||||
groups: 2,
|
||||
size: 8,
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
|
||||
[36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
|
||||
],
|
||||
[
|
||||
[12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],
|
||||
[36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]],
|
||||
[[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct ConvTranspose1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<3>,
|
||||
weight: TestTensor<3>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl ConvTranspose1dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv_transpose1d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvTransposeOptions::new(
|
||||
[self.stride],
|
||||
[self.padding],
|
||||
[self.padding_out],
|
||||
[self.dilation],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), Tolerance::default());
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,706 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions};
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_basic() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: [3, 3],
|
||||
padding: [0, 0],
|
||||
padding_out: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
groups: 1,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[153., 153., 153., 153.],
|
||||
[153., 153., 153., 153.],
|
||||
[153., 153., 153., 153.],
|
||||
[153., 153., 153., 153.],
|
||||
],
|
||||
[
|
||||
[477., 477., 477., 477.],
|
||||
[477., 477., 477., 477.],
|
||||
[477., 477., 477., 477.],
|
||||
[477., 477., 477., 477.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[153., 153., 153., 153.],
|
||||
[153., 153., 153., 153.],
|
||||
[153., 153., 153., 153.],
|
||||
[153., 153., 153., 153.],
|
||||
],
|
||||
[
|
||||
[477., 477., 477., 477.],
|
||||
[477., 477., 477., 477.],
|
||||
[477., 477., 477., 477.],
|
||||
[477., 477., 477., 477.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
|
||||
[[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1264., 1264., 1264.],
|
||||
[1264., 1264., 1264.],
|
||||
[1264., 1264., 1264.],
|
||||
],
|
||||
[
|
||||
[1264., 1264., 1264.],
|
||||
[1264., 1264., 1264.],
|
||||
[1264., 1264., 1264.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([72., 72.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_padding() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [1, 1],
|
||||
kernel_size: [3, 3],
|
||||
padding: [1, 2],
|
||||
padding_out: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
groups: 1,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[[
|
||||
[13., 24., 20., 9.],
|
||||
[15., 27., 21., 9.],
|
||||
[15., 27., 21., 9.],
|
||||
[7., 12., 8., 3.],
|
||||
]]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([8.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_stride() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [1, 1],
|
||||
kernel_size: [3, 3],
|
||||
padding: [0, 0],
|
||||
padding_out: [0, 0],
|
||||
stride: [2, 3],
|
||||
dilation: [1, 1],
|
||||
groups: 1,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[[
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
]]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([108.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_stride_padding_out() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [1, 1],
|
||||
kernel_size: [3, 3],
|
||||
padding: [0, 0],
|
||||
padding_out: [1, 2],
|
||||
stride: [2, 3],
|
||||
dilation: [1, 1],
|
||||
groups: 1,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[[
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
]]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([140.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_dilation() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [1, 1],
|
||||
kernel_size: [3, 3],
|
||||
padding: [0, 0],
|
||||
padding_out: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [2, 3],
|
||||
groups: 1,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[[
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
]]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([80.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_channels() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [2, 3],
|
||||
kernel_size: [3, 3],
|
||||
padding: [0, 0],
|
||||
padding_out: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
groups: 1,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[351., 351., 351., 351.],
|
||||
[351., 351., 351., 351.],
|
||||
[351., 351., 351., 351.],
|
||||
[351., 351., 351., 351.],
|
||||
],
|
||||
[
|
||||
[1080., 1080., 1080., 1080.],
|
||||
[1080., 1080., 1080., 1080.],
|
||||
[1080., 1080., 1080., 1080.],
|
||||
[1080., 1080., 1080., 1080.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
|
||||
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
|
||||
[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],
|
||||
],
|
||||
[
|
||||
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
|
||||
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
|
||||
[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([36., 36., 36.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_kernel_size() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [1, 1],
|
||||
kernel_size: [3, 5],
|
||||
padding: [0, 0],
|
||||
padding_out: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
groups: 1,
|
||||
size: [6, 6],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[[
|
||||
[105., 105., 105., 105., 105., 105.],
|
||||
[105., 105., 105., 105., 105., 105.],
|
||||
[105., 105., 105., 105., 105., 105.],
|
||||
[105., 105., 105., 105., 105., 105.],
|
||||
[105., 105., 105., 105., 105., 105.],
|
||||
[105., 105., 105., 105., 105., 105.],
|
||||
]]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[[[
|
||||
[630., 630., 630., 630., 630.],
|
||||
[630., 630., 630., 630., 630.],
|
||||
[630., 630., 630., 630., 630.],
|
||||
]]],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([80.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_groups() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [2, 2],
|
||||
kernel_size: [3, 3],
|
||||
padding: [0, 0],
|
||||
padding_out: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
groups: 2,
|
||||
size: [4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
[36., 36., 36., 36.],
|
||||
],
|
||||
[
|
||||
[117., 117., 117., 117.],
|
||||
[117., 117., 117., 117.],
|
||||
[117., 117., 117., 117.],
|
||||
[117., 117., 117., 117.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]],
|
||||
[[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([36., 36.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_complex_no_groups() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 3],
|
||||
kernel_size: [3, 5],
|
||||
padding: [1, 2],
|
||||
padding_out: [1, 2],
|
||||
stride: [2, 3],
|
||||
dilation: [2, 3],
|
||||
groups: 1,
|
||||
size: [6, 8],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[600., 735., 735., 735., 735., 735., 735., 735.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
],
|
||||
[
|
||||
[1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[600., 735., 735., 735., 735., 735., 735., 735.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
[810., 990., 990., 990., 990., 990., 990., 990.],
|
||||
],
|
||||
[
|
||||
[1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
[2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[5320., 6040., 6040., 6040., 6040.],
|
||||
[6048., 6864., 6864., 6864., 6864.],
|
||||
[6048., 6864., 6864., 6864., 6864.],
|
||||
],
|
||||
[
|
||||
[5320., 6040., 6040., 6040., 6040.],
|
||||
[6048., 6864., 6864., 6864., 6864.],
|
||||
[6048., 6864., 6864., 6864., 6864.],
|
||||
],
|
||||
[
|
||||
[5320., 6040., 6040., 6040., 6040.],
|
||||
[6048., 6864., 6864., 6864., 6864.],
|
||||
[6048., 6864., 6864., 6864., 6864.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[8680., 9880., 9880., 9880., 9880.],
|
||||
[10080., 11472., 11472., 11472., 11472.],
|
||||
[10080., 11472., 11472., 11472., 11472.],
|
||||
],
|
||||
[
|
||||
[8680., 9880., 9880., 9880., 9880.],
|
||||
[10080., 11472., 11472., 11472., 11472.],
|
||||
[10080., 11472., 11472., 11472., 11472.],
|
||||
],
|
||||
[
|
||||
[8680., 9880., 9880., 9880., 9880.],
|
||||
[10080., 11472., 11472., 11472., 11472.],
|
||||
[10080., 11472., 11472., 11472., 11472.],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([896., 896., 896.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_complex_no_groups_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [4, 2],
|
||||
kernel_size: [2, 3],
|
||||
padding: [1, 2],
|
||||
padding_out: [1, 2],
|
||||
stride: [2, 3],
|
||||
dilation: [1, 2],
|
||||
groups: 1,
|
||||
size: [10, 10],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[30., 42., 42., 42., 42., 42., 42., 42., 42., 42.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
],
|
||||
[
|
||||
[78., 114., 114., 114., 114., 114., 114., 114., 114., 114.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
[144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],
|
||||
],
|
||||
[
|
||||
[126., 186., 186., 186., 186., 186., 186., 186., 186., 186.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
[240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],
|
||||
],
|
||||
[
|
||||
[174., 258., 258., 258., 258., 258., 258., 258., 258., 258.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
[336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[[4455., 4905., 4905.], [4500., 4950., 4950.]],
|
||||
[[4455., 4905., 4905.], [4500., 4950., 4950.]],
|
||||
],
|
||||
[
|
||||
[[12555., 13905., 13905.], [13500., 14950., 14950.]],
|
||||
[[12555., 13905., 13905.], [13500., 14950., 14950.]],
|
||||
],
|
||||
[
|
||||
[[20655., 22905., 22905.], [22500., 24950., 24950.]],
|
||||
[[20655., 22905., 22905.], [22500., 24950., 24950.]],
|
||||
],
|
||||
[
|
||||
[[28755., 31905., 31905.], [31500., 34950., 34950.]],
|
||||
[[28755., 31905., 31905.], [31500., 34950., 34950.]],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([570., 570.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_complex_groups() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [4, 2],
|
||||
kernel_size: [2, 3],
|
||||
padding: [1, 2],
|
||||
padding_out: [1, 2],
|
||||
stride: [2, 3],
|
||||
dilation: [1, 2],
|
||||
groups: 2,
|
||||
size: [10, 10],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[9., 12., 12., 12., 12., 12., 12., 12., 12., 12.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
[12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
|
||||
],
|
||||
[
|
||||
[21., 30., 30., 30., 30., 30., 30., 30., 30., 30.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
[36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
|
||||
],
|
||||
[
|
||||
[33., 48., 48., 48., 48., 48., 48., 48., 48., 48.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
[60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],
|
||||
],
|
||||
[
|
||||
[45., 66., 66., 66., 66., 66., 66., 66., 66., 66.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
[84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[[4455., 4905., 4905.], [4500., 4950., 4950.]]],
|
||||
[[[12555., 13905., 13905.], [13500., 14950., 14950.]]],
|
||||
[[[20655., 22905., 22905.], [22500., 24950., 24950.]]],
|
||||
[[[28755., 31905., 31905.], [31500., 34950., 34950.]]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([570., 570.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct ConvTranspose2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
groups: usize,
|
||||
size: [usize; 2],
|
||||
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<4>,
|
||||
weight: TestTensor<4>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl ConvTranspose2dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels[0],
|
||||
self.size[0],
|
||||
self.size[1],
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let output = conv_transpose2d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvTransposeOptions::new(
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,711 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_basic() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 2,
|
||||
channels: [2, 2],
|
||||
kernel_size: [3, 3, 3],
|
||||
padding: [0, 0, 0],
|
||||
padding_out: [0, 0, 0],
|
||||
stride: [1, 1, 1],
|
||||
dilation: [1, 1, 1],
|
||||
groups: 1,
|
||||
size: [4, 4, 4],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
[
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
[13.250001, 13.250001, 13.250001, 13.250001],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
[
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
[40.249992, 40.249992, 40.249992, 40.249992],
|
||||
],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[
|
||||
[
|
||||
[
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
],
|
||||
[
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
],
|
||||
[
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
],
|
||||
[
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
],
|
||||
[
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
[47.750000, 47.750000, 47.750000],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
],
|
||||
[
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
],
|
||||
[
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
],
|
||||
[
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
],
|
||||
[
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
[79.750000, 79.750000, 79.750000],
|
||||
],
|
||||
],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([432., 432.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_complex_groups() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels: [4, 2],
|
||||
kernel_size: [2, 3, 4],
|
||||
padding: [1, 2, 3],
|
||||
padding_out: [1, 2, 3],
|
||||
stride: [2, 3, 4],
|
||||
dilation: [1, 2, 3],
|
||||
groups: 2,
|
||||
size: [6, 6, 6],
|
||||
};
|
||||
let device = Default::default();
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats(
|
||||
[[
|
||||
[
|
||||
[
|
||||
[1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000],
|
||||
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
|
||||
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
|
||||
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
|
||||
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
|
||||
[1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
|
||||
],
|
||||
[
|
||||
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
],
|
||||
[
|
||||
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
],
|
||||
[
|
||||
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
],
|
||||
[
|
||||
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
],
|
||||
[
|
||||
[1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
[2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000],
|
||||
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
|
||||
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
|
||||
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
|
||||
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
|
||||
[3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
|
||||
],
|
||||
[
|
||||
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
],
|
||||
[
|
||||
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
],
|
||||
[
|
||||
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
],
|
||||
[
|
||||
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
],
|
||||
[
|
||||
[4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
[6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000],
|
||||
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
|
||||
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
|
||||
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
|
||||
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
|
||||
[6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
|
||||
],
|
||||
[
|
||||
[
|
||||
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
[
|
||||
11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000],
|
||||
[
|
||||
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
|
||||
],
|
||||
[
|
||||
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
|
||||
],
|
||||
[
|
||||
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
|
||||
],
|
||||
[
|
||||
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
|
||||
],
|
||||
[
|
||||
8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
[
|
||||
15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
|
||||
],
|
||||
],
|
||||
],
|
||||
]],
|
||||
&device,
|
||||
),
|
||||
weight: TestTensor::from_floats(
|
||||
[
|
||||
[[
|
||||
[
|
||||
[18.663193, 22.309027, 22.309027, 22.309027],
|
||||
[21.875000, 26.145834, 26.145834, 26.145834],
|
||||
[21.875000, 26.145834, 26.145834, 26.145834],
|
||||
],
|
||||
[
|
||||
[19.270832, 23.020834, 23.020834, 23.020834],
|
||||
[22.500000, 26.875002, 26.875002, 26.875002],
|
||||
[22.500000, 26.875002, 26.875002, 26.875002],
|
||||
],
|
||||
]],
|
||||
[[
|
||||
[
|
||||
[49.913193, 59.809029, 59.809029, 59.809029],
|
||||
[59.375000, 71.145836, 71.145836, 71.145836],
|
||||
[59.375000, 71.145836, 71.145836, 71.145836],
|
||||
],
|
||||
[
|
||||
[56.770836, 68.020836, 68.020836, 68.020836],
|
||||
[67.500000, 80.875000, 80.875000, 80.875000],
|
||||
[67.500000, 80.875000, 80.875000, 80.875000],
|
||||
],
|
||||
]],
|
||||
[[
|
||||
[
|
||||
[81.163193, 97.309029, 97.309029, 97.309029],
|
||||
[96.875000, 116.145828, 116.145828, 116.145828],
|
||||
[96.875000, 116.145828, 116.145828, 116.145828],
|
||||
],
|
||||
[
|
||||
[94.270828, 113.020828, 113.020828, 113.020828],
|
||||
[112.500000, 134.875000, 134.875000, 134.875000],
|
||||
[112.500000, 134.875000, 134.875000, 134.875000],
|
||||
],
|
||||
]],
|
||||
[[
|
||||
[
|
||||
[112.413200, 134.809021, 134.809021, 134.809021],
|
||||
[134.375000, 161.145828, 161.145828, 161.145828],
|
||||
[134.375000, 161.145828, 161.145828, 161.145828],
|
||||
],
|
||||
[
|
||||
[131.770844, 158.020828, 158.020828, 158.020828],
|
||||
[157.500000, 188.875000, 188.875000, 188.875000],
|
||||
[157.500000, 188.875000, 188.875000, 188.875000],
|
||||
],
|
||||
]],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
bias: TestTensor::from_floats([5346., 5346.], &device),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct ConvTranspose3dTestCase {
|
||||
batch_size: usize,
|
||||
channels: [usize; 2],
|
||||
kernel_size: [usize; 3],
|
||||
padding: [usize; 3],
|
||||
padding_out: [usize; 3],
|
||||
stride: [usize; 3],
|
||||
dilation: [usize; 3],
|
||||
groups: usize,
|
||||
size: [usize; 3],
|
||||
}
|
||||
|
||||
struct Grads {
|
||||
x: TestTensor<5>,
|
||||
weight: TestTensor<5>,
|
||||
bias: TestTensor<1>,
|
||||
}
|
||||
|
||||
impl ConvTranspose3dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels[0],
|
||||
self.size[0],
|
||||
self.size[1],
|
||||
self.size[2],
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels[0],
|
||||
self.channels[1] / self.groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
self.kernel_size[2],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_weight.clone())
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.div_scalar(shape_weight.num_elements() as f32)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_x.clone())
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.div_scalar(shape_x.num_elements() as f32)
|
||||
.require_grad();
|
||||
let output = conv_transpose3d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
ConvTransposeOptions::new(
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Assert
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
let weight_grad_actual = weight.grad(&grads).unwrap();
|
||||
let bias_grad_actual = bias.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
expected_grads
|
||||
.bias
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.x
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
||||
expected_grads
|
||||
.weight
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use burn_backend_tests::might_panic;
|
||||
|
||||
#[test]
|
||||
fn backward_basic() {
|
||||
let device = Default::default();
|
||||
let a = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let b = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
// Simple cross product; grad is a vector of ones.
|
||||
let c = a.clone().cross(b.clone(), 1);
|
||||
let grads = c.backward();
|
||||
|
||||
let a_grad = a.grad(&grads).unwrap().to_data();
|
||||
let b_grad = b.grad(&grads).unwrap().to_data();
|
||||
|
||||
// For a: b×grad_out, where grad_out = [1,1,1]
|
||||
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
|
||||
// For b: grad_out×a
|
||||
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
|
||||
|
||||
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
|
||||
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn backward_after_sum() {
|
||||
let device = Default::default();
|
||||
let a = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let b = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
// Sum reduces to scalar, but the gradient should be the same.
|
||||
let c = a.clone().cross(b.clone(), 1).sum();
|
||||
let grads = c.backward();
|
||||
|
||||
let a_grad = a.grad(&grads).unwrap().to_data();
|
||||
let b_grad = b.grad(&grads).unwrap().to_data();
|
||||
|
||||
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
|
||||
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
|
||||
|
||||
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
|
||||
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
|
||||
#[test]
|
||||
fn different_dim() {
|
||||
// Also check when the cross is along a different dimension (e.g. dim 0).
|
||||
let device = Default::default();
|
||||
let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
|
||||
let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];
|
||||
|
||||
let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);
|
||||
let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);
|
||||
// Cross along dim 0. Some backends (for example CubeCL) may not support
|
||||
// cross on non-last dimensions and will intentionally panic with a
|
||||
// message like "Cross product on non-last dimension not yet implemented".
|
||||
// In that case we treat the panic as a skipped test for that backend.
|
||||
let out = a.cross(b.clone(), 0);
|
||||
|
||||
// Manually compute cross of each column vector using raw arrays
|
||||
let expected = [
|
||||
[
|
||||
a_raw[1][0] * b_raw[2][0] - a_raw[2][0] * b_raw[1][0],
|
||||
a_raw[1][1] * b_raw[2][1] - a_raw[2][1] * b_raw[1][1],
|
||||
a_raw[1][2] * b_raw[2][2] - a_raw[2][2] * b_raw[1][2],
|
||||
],
|
||||
[
|
||||
a_raw[2][0] * b_raw[0][0] - a_raw[0][0] * b_raw[2][0],
|
||||
a_raw[2][1] * b_raw[0][1] - a_raw[0][1] * b_raw[2][1],
|
||||
a_raw[2][2] * b_raw[0][2] - a_raw[0][2] * b_raw[2][2],
|
||||
],
|
||||
[
|
||||
a_raw[0][0] * b_raw[1][0] - a_raw[1][0] * b_raw[0][0],
|
||||
a_raw[0][1] * b_raw[1][1] - a_raw[1][1] * b_raw[0][1],
|
||||
a_raw[0][2] * b_raw[1][2] - a_raw[1][2] * b_raw[0][2],
|
||||
],
|
||||
];
|
||||
|
||||
out.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Tensor, TensorData, Tolerance, loss};
|
||||
|
||||
#[test]
|
||||
fn test_cross_entropy_loss_grad() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
let data_targets = TensorData::from([[0.8, 0.2], [0.9, 0.1]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
|
||||
let tensor_targets =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data_targets, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
let expected = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummax() {
|
||||
// Simple test to verify cummax gradients work
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 2.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummax(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 2.0, 0.0]
|
||||
let expected = TensorData::from([1.0, 2.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummax_2d() {
|
||||
// Test 2D cummax gradients
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummax(1);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
|
||||
let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummax_duplicate_values() {
|
||||
// Test with duplicate maximum values - critical edge case
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 3.0, 2.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummax(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// input: [1.0, 3.0, 3.0, 2.0]
|
||||
// cummax: [1.0, 3.0, 3.0, 3.0]
|
||||
// PyTorch reference: [1.0, 1.0, 2.0, 0.0]
|
||||
// Position 2 gets grad from itself + position 3
|
||||
let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummax_all_same() {
|
||||
// Test with all same values
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummax(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 1.0, 1.0]
|
||||
// Each position matches cummax, so each gets its own gradient
|
||||
let expected = TensorData::from([1.0, 1.0, 1.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummax_increasing() {
|
||||
// Test with increasing sequence
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummax(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 1.0, 1.0, 1.0]
|
||||
// Each position is a new maximum
|
||||
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummax_2d_duplicates() {
|
||||
// Test 2D with duplicate values
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummax(1);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
|
||||
let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummin() {
|
||||
// Simple test to verify cummin gradients work
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 4.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummin(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 2.0, 0.0]
|
||||
let expected = TensorData::from([1.0, 2.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummin_2d() {
|
||||
// Test 2D cummin gradients
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummin(1);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
|
||||
let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummin_duplicate_values() {
|
||||
// Test with duplicate minimum values - critical edge case
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 2.0, 4.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummin(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// input: [3.0, 2.0, 2.0, 4.0]
|
||||
// cummin: [3.0, 2.0, 2.0, 2.0]
|
||||
// PyTorch reference: [1.0, 1.0, 2.0, 0.0]
|
||||
// Position 2 gets grad from itself + position 3
|
||||
let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummin_all_same() {
|
||||
// Test with all same values
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummin(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 1.0, 1.0]
|
||||
// Each position matches cummin, so each gets its own gradient
|
||||
let expected = TensorData::from([1.0, 1.0, 1.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummin_decreasing() {
|
||||
// Test with decreasing sequence
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([5.0, 4.0, 3.0, 2.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummin(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 1.0, 1.0, 1.0]
|
||||
// Each position is a new minimum
|
||||
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cummin_2d_duplicates() {
|
||||
// Test 2D with duplicate values
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cummin(1);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]
|
||||
let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_cumprod() {
|
||||
// Simple test to verify cumprod gradients work
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cumprod(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [16.0, 10.0, 6.0]
|
||||
let expected = TensorData::from([16.0, 10.0, 6.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cumprod_2d() {
|
||||
// Test 2D cumprod gradients
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cumprod(1);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]
|
||||
let expected = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
// TODO: The following tests are currently ignored due to a known limitation
|
||||
// in the cumprod gradient implementation. The current implementation uses
|
||||
// division (grad / input), which produces NaN when the input contains zeros.
|
||||
//
|
||||
// A proper fix requires implementing a zero-safe algorithm using exclusive
|
||||
// cumulative products (similar to PyTorch's cumprod_backward or JAX's
|
||||
// associative_scan approach). This is a non-trivial implementation that
|
||||
// requires careful handling of cumulative products in both forward and
|
||||
// reverse directions.
|
||||
//
|
||||
// See: https://github.com/tracel-ai/burn/issues/3864
|
||||
//
|
||||
// References:
|
||||
// - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward)
|
||||
// - JAX PR #2596: Parallel prefix scan implementation
|
||||
// - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros
|
||||
|
||||
#[test]
|
||||
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
|
||||
fn should_diff_cumprod_zero_in_middle() {
|
||||
// Test cumprod with zero in the middle - edge case for division
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 4.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cumprod(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 32.0, 0.0, 0.0]
|
||||
let expected = TensorData::from([1.0, 32.0, 0.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
|
||||
fn should_diff_cumprod_zero_at_start() {
|
||||
// Test cumprod with zero at the beginning
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([0.0, 2.0, 3.0, 4.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cumprod(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [33.0, 0.0, 0.0, 0.0]
|
||||
let expected = TensorData::from([33.0, 0.0, 0.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
|
||||
fn should_diff_cumprod_zero_at_end() {
|
||||
// Test cumprod with zero at the end
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0, 0.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cumprod(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [16.0, 10.0, 6.0, 24.0]
|
||||
let expected = TensorData::from([16.0, 10.0, 6.0, 24.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"]
|
||||
fn should_diff_cumprod_multiple_zeros() {
|
||||
// Test cumprod with multiple zeros
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]), &device)
|
||||
.require_grad();
|
||||
|
||||
let output = tensor.clone().cumprod(0);
|
||||
let grads = output.sum().backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
// PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0]
|
||||
let expected = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_cumsum_dim0() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.cumsum(0);
|
||||
let tensor_5 = tensor_1.clone().mul(tensor_4);
|
||||
let grads = tensor_5.sum().backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
// Expected gradients computed with PyTorch
|
||||
let expected = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cumsum_dim1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.cumsum(1);
|
||||
let tensor_5 = tensor_1.clone().mul(tensor_4);
|
||||
let grads = tensor_5.sum().backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
// Expected gradients computed with PyTorch
|
||||
let expected = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[18.0, 13.0], [71.0, 58.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cumsum_complex() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.clone().cumsum(1);
|
||||
let tensor_5 = tensor_4.mul(tensor_3);
|
||||
|
||||
let grads = tensor_5.sum().backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
// Expected gradients computed with PyTorch
|
||||
let expected = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[507.0, 528.0], [704.0, 733.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,105 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_div() {
|
||||
let data_1 = TensorData::from([1.0, 7.0]);
|
||||
let data_2 = TensorData::from([4.0, 7.0]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().div(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([0.25, 0.14285715]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([-0.0625, -0.14285715]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_div_scalar() {
|
||||
let data = TensorData::from([1.0, 7.0]);
|
||||
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
|
||||
let tensor_out = tensor.clone().div_scalar(4.0);
|
||||
|
||||
let grads = tensor_out.backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([0.25, 0.25]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().div(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.div(tensor_3.clone());
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
let grad_3 = tensor_3.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
let expected = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]);
|
||||
grad_3
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div_complex_2() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.div(tensor_2.clone());
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_absolute(2e-3);
|
||||
let expected = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_erf() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[8.0, 8.0], [8.0, 8.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_exp() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default();
|
||||
let expected = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_expand() {
|
||||
// Python code to generate the test case values
|
||||
// import torch
|
||||
// x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)
|
||||
// x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)
|
||||
// y = x1.expand(4, 4)
|
||||
// z = (x2 * y).sum()
|
||||
// z.backward()
|
||||
// print("x1", x1.grad)
|
||||
// print("x2", x2.grad)
|
||||
|
||||
let device = Default::default();
|
||||
|
||||
let data_1 = TensorData::from([4.0, 7.0, 2.0, 3.0]);
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
|
||||
|
||||
let data_2 = TensorData::from([2.0, 4.5, 7.0, 3.0]);
|
||||
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().expand([4, 4]);
|
||||
|
||||
// Use unsqueeze to make tensor_2 have the same shape as tensor_3
|
||||
let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum();
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([8., 18., 28., 12.]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([16., 28., 8., 12.]), false);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_flip() {
|
||||
let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
|
||||
let data_2 = TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<3>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().flip([1, 2]);
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_3);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
grad_1
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2
|
||||
grad_2.into_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]),
|
||||
tolerance,
|
||||
); // 1x2x3
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_floor() {
|
||||
let data = TensorData::from([
|
||||
[-1.9751, 0.0714, 0.0643, 0.2406],
|
||||
[-1.3172, 0.1252, -0.1119, -0.0127],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||
let tensor_2 = tensor_1.clone().floor();
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(
|
||||
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use super::*;
|
||||
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_gather_grad() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(
|
||||
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
|
||||
TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]),
|
||||
&device,
|
||||
);
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1.clone().gather(1, indices);
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(
|
||||
&TensorData::from([[94., 150., 187.], [242., 305., 304.]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_grad() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(
|
||||
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let values = TestAutodiffTensor::from_data(
|
||||
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(
|
||||
TensorData::from([[2, 1, 0], [2, 0, 1]]),
|
||||
&device,
|
||||
);
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1
|
||||
.clone()
|
||||
.scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = values.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(
|
||||
&TensorData::from([[127., 181., 235.], [226., 316., 406.]]),
|
||||
false,
|
||||
);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[19., 19., 19.], [64., 64., 64.]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_add_grad_partial_indices() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::from_data(TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]), &device)
|
||||
.require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_data(TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), &device)
|
||||
.require_grad();
|
||||
let values =
|
||||
TestAutodiffTensor::from_data(TensorData::from([[4.0, 5.0, 6.0]]), &device).require_grad();
|
||||
let indices =
|
||||
Tensor::<TestAutodiffBackend, 2, Int>::from_data(TensorData::from([[2, 1, 0]]), &device);
|
||||
|
||||
let tensor_3 = tensor_1.clone().mul(tensor_2);
|
||||
let tensor_4 = tensor_3
|
||||
.clone()
|
||||
.scatter(1, indices, values.clone(), IndexingUpdateOp::Add);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = values.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1., 2., 3., 4., 5., 6.]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1., 1., 1.]]), false);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance, activation};
|
||||
|
||||
#[test]
|
||||
fn should_diff_gelu() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad();
|
||||
|
||||
let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));
|
||||
let x = tensor_1.clone().matmul(x);
|
||||
let grads = x.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
let expected = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Distribution, activation};
|
||||
|
||||
#[test]
|
||||
fn should_update_tensor_when_grad_replace() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default, &device);
|
||||
|
||||
let x = tensor_1.clone().matmul(activation::gelu(tensor_2));
|
||||
let mut grads = x.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
let grad_1_updated =
|
||||
TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad();
|
||||
tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner());
|
||||
|
||||
let grad_1_new = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
assert_ne!(grad_1_new.to_data(), grad_1.into_data());
|
||||
assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data());
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_diff_log() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
let expected = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_log1p() {
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
let expected = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]);
|
||||
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn should_diff_log_sigmoid() {
|
||||
let data = TensorData::from([[0.8762, -0.1423], [-300., 200.]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||
let tensor_2 = activation::log_sigmoid(tensor_1.clone());
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
let grad = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{Bool, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_diff_mask_fill() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let mask = TensorData::from([[true, false], [false, true]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device);
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.mask_fill(mask, 2.0);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[7.0, 3.0], [4.0, 2.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 1.0], [3.0, 7.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_mask_where() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad();
|
||||
let tensor_3 =
|
||||
TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad();
|
||||
let mask =
|
||||
Tensor::<TestAutodiffBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
|
||||
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.clone().matmul(tensor_3.clone());
|
||||
let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone());
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
let grad_3 = tensor_3.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
let expected = TensorData::from([[121.8, 55.0], [110.8, 50.0]]);
|
||||
grad_1
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[27.4, 33.4], [95.0, 115.0]]);
|
||||
grad_2
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[15., 18.], [23., 29.]]);
|
||||
grad_3
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);
|
||||
tensor_3
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[18.0, 28.0], [14.0, 23.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[44.0, 20.0], [44.0, 20.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[56.0, 56.0], [16.0, 16.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_complex_2() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.matmul(tensor_3.clone());
|
||||
let tensor_6 = tensor_1.clone().matmul(tensor_5);
|
||||
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[800.0, 792.0], [360.0, 592.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[264., 264.0], [344.0, 344.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_max_dim() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[50.0, 34.0], [40.0, -10.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[8.0, 10.0], [56.0, 15.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_min_dim() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[10.0, 8.0], [15.0, 56.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_min_dim_3d_dim1() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.min_dim(1);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[[0., -7.], [2., 0.]]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[[0., 7.], [-2., -0.]]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::max_pool1d;
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_simple() {
|
||||
let kernel_size = 4;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected =
|
||||
TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_dilation() {
|
||||
let kernel_size = 4;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 2;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
|
||||
0.4610, 0.5365, 0.6880,
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
|
||||
[[[
|
||||
0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
|
||||
0., 0., 1.,
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_complex() {
|
||||
let kernel_size = 4;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
|
||||
0.4610, 0.5365, 0.6880,
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
|
||||
[[[
|
||||
0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
|
||||
1., 1., 1.,
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_complex_with_padding() {
|
||||
let kernel_size = 4;
|
||||
let padding = 2;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
|
||||
0.4610, 0.5365, 0.6880,
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<3>::from_floats(
|
||||
[[[
|
||||
1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
|
||||
1., 1., 3.,
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::max_pool2d;
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_simple_1() {
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
[0.7065, 0.1940, 0.6305, 0.8959],
|
||||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
|
||||
[[[
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 2.0],
|
||||
[0.0, 2.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_simple_2() {
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
[0.7065, 0.1940, 0.6305, 0.8959],
|
||||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
|
||||
[[[
|
||||
[1., 3., 0., 2.],
|
||||
[3., 0., 0., 4.],
|
||||
[1., 4., 0., 1.],
|
||||
[2., 0., 3., 1.],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_dilation() {
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 2;
|
||||
let dilation_2 = 2;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
[0.7065, 0.1940, 0.6305, 0.8959],
|
||||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
|
||||
[[[
|
||||
[0., 0., 0., 0.],
|
||||
[1., 1., 1., 2.],
|
||||
[0., 4., 4., 0.],
|
||||
[0., 1., 2., 0.],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_complex() {
|
||||
let kernel_size_1 = 4;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 2;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
|
||||
[0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
|
||||
[0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
|
||||
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
|
||||
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
|
||||
[[[
|
||||
[0., 0., 0., 3., 0.],
|
||||
[4., 0., 2., 1., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[2., 4., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 2.],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_ceil_mode() {
|
||||
// Test ceil_mode=true with gradient computation
|
||||
// Using 1x1x6x6 input with kernel 3x3, stride 2x2, padding 0
|
||||
// Floor mode: output 2x2
|
||||
// Ceil mode: output 3x3
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 2;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let device = Default::default();
|
||||
// Input (values 1-36):
|
||||
let x = TestAutodiffTensor::from_floats(
|
||||
[[[
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
||||
[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
|
||||
[13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
|
||||
[19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
|
||||
[25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
|
||||
[31.0, 32.0, 33.0, 34.0, 35.0, 36.0],
|
||||
]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
// Expected gradients for ceil_mode output 3x3:
|
||||
// Output positions and their max value positions:
|
||||
// (0,0): max at (2,2)=15 -> grad[2,2] += 1
|
||||
// (0,1): max at (2,4)=17 -> grad[2,4] += 1
|
||||
// (0,2): max at (2,5)=18 -> grad[2,5] += 1
|
||||
// (1,0): max at (4,2)=27 -> grad[4,2] += 1
|
||||
// (1,1): max at (4,4)=29 -> grad[4,4] += 1
|
||||
// (1,2): max at (4,5)=30 -> grad[4,5] += 1
|
||||
// (2,0): max at (5,2)=33 -> grad[5,2] += 1
|
||||
// (2,1): max at (5,4)=35 -> grad[5,4] += 1
|
||||
// (2,2): max at (5,5)=36 -> grad[5,5] += 1
|
||||
let x_grad_expected = TestAutodiffTensor::<4>::from_floats(
|
||||
[[[
|
||||
[0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 1., 0., 1., 1.],
|
||||
[0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 1., 0., 1., 1.],
|
||||
[0., 0., 1., 0., 1., 1.],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
true,
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_mm_independent_trees() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// First tree
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_0 * tensor_1;
|
||||
let tensor_5 = tensor_2 * tensor_3;
|
||||
let tensor_6 = tensor_4 * tensor_5;
|
||||
|
||||
// Second tree
|
||||
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_11 = tensor_7.clone() * tensor_8.clone();
|
||||
let tensor_12 = tensor_9.clone() * tensor_10.clone();
|
||||
let tensor_13 = tensor_11 * tensor_12;
|
||||
|
||||
let _grads = tensor_6.backward();
|
||||
let grads = tensor_13.backward();
|
||||
|
||||
assert!(tensor_7.grad(&grads).is_some());
|
||||
assert!(tensor_8.grad(&grads).is_some());
|
||||
assert!(tensor_9.grad(&grads).is_some());
|
||||
assert!(tensor_10.grad(&grads).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mm_crossover_trees_root_unavailable() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// First tree
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_0 * tensor_1;
|
||||
let tensor_5 = tensor_2 * tensor_3;
|
||||
let tensor_6 = tensor_4.clone() * tensor_5;
|
||||
|
||||
// Second tree
|
||||
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_9 = tensor_7.clone() * tensor_8.clone();
|
||||
let tensor_10 = tensor_4 * tensor_9;
|
||||
|
||||
let _grads = tensor_6.backward();
|
||||
let _grads = tensor_10.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_crossover_trees_with_referred_subtree() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// First tree
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_0 * tensor_1;
|
||||
let tensor_5 = tensor_2 * tensor_3;
|
||||
let tensor_6 = tensor_4.clone() * tensor_5;
|
||||
|
||||
// Second tree
|
||||
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_9 = tensor_7.clone() * tensor_8.clone();
|
||||
let _tensor_10 = tensor_4 * tensor_9.clone();
|
||||
|
||||
let _grads = tensor_6.backward();
|
||||
let _grads = tensor_9.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_three_crossover_trees_last_still_usable() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// First tree
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_0 * tensor_1;
|
||||
let tensor_5 = tensor_2 * tensor_3;
|
||||
let tensor_6 = tensor_4 * tensor_5.clone();
|
||||
|
||||
// Third tree
|
||||
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_11 = tensor_7 * tensor_8;
|
||||
let tensor_12 = tensor_9 * tensor_10;
|
||||
let tensor_13 = tensor_11 * tensor_12.clone();
|
||||
|
||||
// Second tree (in between)
|
||||
let _tensor_14 = tensor_5 * tensor_12;
|
||||
|
||||
let _grads = tensor_6.backward();
|
||||
let _grads = tensor_13.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mm_three_crossover_trees_middle_one_unavailable() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// First tree
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_0 * tensor_1;
|
||||
let tensor_5 = tensor_2 * tensor_3;
|
||||
let tensor_6 = tensor_4 * tensor_5.clone();
|
||||
|
||||
// Third tree
|
||||
let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_11 = tensor_7 * tensor_8;
|
||||
let tensor_12 = tensor_9 * tensor_10;
|
||||
let _tensor_13 = tensor_11 * tensor_12.clone();
|
||||
|
||||
// Second tree (in between)
|
||||
let tensor_14 = tensor_5 * tensor_12;
|
||||
|
||||
let _grads = tensor_6.backward();
|
||||
let _grads = tensor_14.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_self_referencing_tree() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// First tree
|
||||
let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_0 * tensor_1;
|
||||
let tensor_5 = tensor_2 * tensor_3.clone();
|
||||
let tensor_6 = tensor_3 * tensor_5;
|
||||
|
||||
let _grads = tensor_6.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_with_non_impacting_detach() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone() * tensor_2.clone();
|
||||
let tensor_5 = tensor_4.detach() * tensor_3.clone();
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
assert!(tensor_3.grad(&grads).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_with_missing_require_grad_after_cleanup() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
|
||||
let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);
|
||||
|
||||
let tensor_4 = tensor_1.clone() * tensor_2.clone();
|
||||
let tensor_5 = tensor_4 * tensor_3.clone();
|
||||
|
||||
// Trivial backward, just to trigger cleanup
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
|
||||
.require_grad()
|
||||
.backward();
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
assert!(tensor_1.grad(&grads).is_some());
|
||||
assert!(tensor_2.grad(&grads).is_none());
|
||||
assert!(tensor_3.grad(&grads).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_with_detach_after_cleanup() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_3 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone() * tensor_2.clone();
|
||||
let tensor_5 = tensor_4 * tensor_3.clone().detach();
|
||||
|
||||
// Trivial backward, just to trigger cleanup
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)
|
||||
.require_grad()
|
||||
.backward();
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
assert!(tensor_1.grad(&grads).is_some());
|
||||
assert!(tensor_2.grad(&grads).is_some());
|
||||
assert!(tensor_3.grad(&grads).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mm_deletables_propagate_well() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
let tensor_0 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_2 = tensor_0 * tensor_1;
|
||||
let tensor_3 = tensor_2.clone().exp();
|
||||
let _tensor_4 = tensor_3.clone().log();
|
||||
|
||||
let _grads = tensor_2.backward();
|
||||
|
||||
// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but
|
||||
// the intermediate tensor_3 as well
|
||||
let _grads = tensor_3.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
|
||||
let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative
|
||||
// By repeating it many times it becomes almost impossible that it passes if it shouldn't
|
||||
for _ in 0..12 {
|
||||
let tensor_0 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_2 = tensor_1.clone().exp();
|
||||
let tensor_3 = tensor_0.exp();
|
||||
let _tensor_4 = tensor_3.clone() * tensor_2.clone();
|
||||
let tensor_5 = tensor_2.exp();
|
||||
let tensor_6 = tensor_5.exp();
|
||||
let tensor_7 = tensor_6.exp();
|
||||
let tensor_8 = tensor_7.exp();
|
||||
|
||||
// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8
|
||||
// which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search
|
||||
tensor_3.backward();
|
||||
let grads = tensor_8.backward();
|
||||
|
||||
assert!(tensor_1.grad(&grads).is_some());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
#[allow(unused_imports)] // required for re-included modules
|
||||
pub use super::*;
|
||||
|
||||
mod abs;
|
||||
mod adaptive_avgpool1d;
|
||||
mod adaptive_avgpool2d;
|
||||
mod add;
|
||||
mod aggregation;
|
||||
mod avgpool1d;
|
||||
mod avgpool2d;
|
||||
mod backward;
|
||||
mod bridge;
|
||||
mod broadcast;
|
||||
mod cast;
|
||||
mod cat;
|
||||
mod ceil;
|
||||
mod checkpoint;
|
||||
mod complex;
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv3d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod cross;
|
||||
mod cross_entropy;
|
||||
mod cummax;
|
||||
mod cummin;
|
||||
mod cumprod;
|
||||
mod cumsum;
|
||||
mod deform_conv2d;
|
||||
mod div;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod expand;
|
||||
mod flip;
|
||||
mod floor;
|
||||
mod gather_scatter;
|
||||
mod gelu;
|
||||
mod gradients;
|
||||
mod log;
|
||||
mod log1p;
|
||||
mod log_sigmoid;
|
||||
mod mask;
|
||||
mod matmul;
|
||||
mod maxmin;
|
||||
mod maxpool1d;
|
||||
mod maxpool2d;
|
||||
mod memory_management;
|
||||
mod mul;
|
||||
mod multithread;
|
||||
mod nearest_interpolate;
|
||||
mod neg;
|
||||
mod nonzero;
|
||||
mod permute;
|
||||
mod pow;
|
||||
mod recip;
|
||||
mod relu;
|
||||
mod remainder;
|
||||
mod repeat_dim;
|
||||
mod reshape;
|
||||
mod round;
|
||||
mod select;
|
||||
mod sigmoid;
|
||||
mod sign;
|
||||
mod slice;
|
||||
mod slice_assign;
|
||||
mod softmax;
|
||||
mod sort;
|
||||
mod sqrt;
|
||||
mod sub;
|
||||
mod transpose;
|
||||
mod trig;
|
||||
mod unfold;
|
||||
@@ -0,0 +1,68 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_mul() {
|
||||
let data_1 = TensorData::from([1.0, 7.0]);
|
||||
let data_2 = TensorData::from([4.0, 7.0]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().mul(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let _grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(&data_2, false);
|
||||
tensor_3
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([4.0, 49.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_mul_scalar() {
|
||||
let data = TensorData::from([2.0, 5.0]);
|
||||
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
|
||||
let tensor_out = tensor.clone().mul_scalar(4.0);
|
||||
|
||||
let grads = tensor_out.backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
tensor_out
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([8.0, 20.0]), false);
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([4.0, 4.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.mul(tensor_3);
|
||||
let tensor_6 = tensor_1.clone().mul(tensor_5);
|
||||
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[16.0, 196.0], [104.0, -36.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 98.0], [338.0, 18.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn should_behave_the_same_with_multithread() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let with_move = || {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.matmul(tensor_3);
|
||||
|
||||
// Task 1
|
||||
let tensor_1_cloned = tensor_1.clone();
|
||||
let tensor_2_cloned = tensor_2.clone();
|
||||
let tensor_5_cloned = tensor_5.clone();
|
||||
|
||||
let first_call = move || {
|
||||
let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned);
|
||||
tensor_6_1.matmul(tensor_1_cloned)
|
||||
};
|
||||
|
||||
// Task 2
|
||||
let tensor_1_cloned = tensor_1.clone();
|
||||
let tensor_2_cloned = tensor_2.clone();
|
||||
let tensor_5_cloned = tensor_5;
|
||||
|
||||
let second_call = move || {
|
||||
let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned);
|
||||
tensor_6_2.matmul(tensor_2_cloned)
|
||||
};
|
||||
|
||||
let tensor_7_1_handle = std::thread::spawn(first_call);
|
||||
let tensor_7_2_handle = std::thread::spawn(second_call);
|
||||
|
||||
let tensor_7_1 = tensor_7_1_handle.join().unwrap();
|
||||
let tensor_7_2 = tensor_7_2_handle.join().unwrap();
|
||||
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
|
||||
|
||||
let grads = tensor_8.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
(grad_1, grad_2)
|
||||
};
|
||||
let without_move = || {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.matmul(tensor_3);
|
||||
|
||||
// Task 1
|
||||
let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone());
|
||||
let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone());
|
||||
|
||||
// Task 2
|
||||
let tensor_6_2 = tensor_5.matmul(tensor_1.clone());
|
||||
let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone());
|
||||
|
||||
let tensor_8 = tensor_7_1.matmul(tensor_7_2);
|
||||
|
||||
let grads = tensor_8.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
(grad_1, grad_2)
|
||||
};
|
||||
|
||||
let (grad_1, grad_2) = without_move();
|
||||
let (grad_1_moved, grad_2_moved) = with_move();
|
||||
|
||||
grad_1
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&grad_1_moved.into_data(), Tolerance::default());
|
||||
grad_2
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&grad_2_moved.into_data(), Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::interpolate;
|
||||
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
|
||||
|
||||
#[test]
|
||||
fn test_upsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 2,
|
||||
channels: 1,
|
||||
height: 7,
|
||||
width: 5,
|
||||
height_out: 8,
|
||||
width_out: 7,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[
|
||||
[4., 2., 4., 2., 2.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
]],
|
||||
[[
|
||||
[4., 2., 4., 2., 2.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
[2., 1., 2., 1., 1.],
|
||||
]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_downsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 8,
|
||||
width: 8,
|
||||
height_out: 4,
|
||||
width_out: 6,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[1., 1., 1., 0., 1., 1., 1., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0., 0.],
|
||||
[1., 1., 1., 0., 1., 1., 1., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0., 0.],
|
||||
[1., 1., 1., 0., 1., 1., 1., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0., 0.],
|
||||
[1., 1., 1., 0., 1., 1., 1., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0., 0.],
|
||||
]]]));
|
||||
}
|
||||
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl InterpolateTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = interpolate(
|
||||
x.clone(),
|
||||
[self.height_out, self.width_out],
|
||||
InterpolateOptions::new(InterpolateMode::Nearest),
|
||||
);
|
||||
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), Tolerance::permissive());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_neg() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg());
|
||||
let tensor_4 = tensor_3.neg();
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
use super::*;
|
||||
use burn_tensor::{Bool, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_diff_nonzero() {
|
||||
let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([-1.0, 1.0]);
|
||||
let mask = TensorData::from([[false, true], [true, false]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
|
||||
|
||||
// Multi-dimensional tensor indexing isn't really supported yet so the easiest way to do
|
||||
// this is to flatten the mask and tensor to get proper indexing. Anyway the returned tensor would
|
||||
// have dimensions different from the input, so this is somewhat equivalent.
|
||||
let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device).flatten::<1>(0, 1);
|
||||
let indices = mask.nonzero();
|
||||
let tensor_3 = tensor_1
|
||||
.clone()
|
||||
.flatten::<1>(0, 1)
|
||||
.select(0, indices[0].clone());
|
||||
|
||||
// Vector dot product not supported (only 2D matmuls) so unsqueeze for test purposes
|
||||
let tensor_4 = tensor_2
|
||||
.clone()
|
||||
.unsqueeze_dim::<2>(0)
|
||||
.matmul(tensor_3.unsqueeze_dim(1));
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[0.0, -1.0], [1.0, 0.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([2.0, 3.0]), false);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_permute() {
|
||||
let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
|
||||
let data_2 = TensorData::from([[[1.0, 7.0], [3.2, 2.0], [3.0, 3.0]]]); // 1x3x2
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().permute([0, 2, 1]);
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_3);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
grad_1
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2
|
||||
grad_2.into_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]),
|
||||
tolerance,
|
||||
); // 1x3x2
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_powf_scalar() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf_scalar(0.4));
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(2e-3);
|
||||
let expected = TensorData::from([[68.0, 79.0328], [68.0, 79.0328]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[23.5081, 25.2779], [26.0502, 28.6383]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_powf() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([32.0, 14.0]);
|
||||
grad_1
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([11.09035, 95.34960]);
|
||||
grad_2
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([16.0, 49.0]);
|
||||
tensor_3
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_powf_with_untracked_lhs() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device);
|
||||
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([11.09035, 95.34960]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_powf_with_untracked_rhs() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device);
|
||||
|
||||
let tensor_3 = tensor_1.clone().powf(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([32.0, 14.0]);
|
||||
grad_1
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_recip() {
|
||||
let data = TensorData::from([2.0, 5.0, 0.4]);
|
||||
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
|
||||
let tensor_out = tensor.clone().recip();
|
||||
|
||||
let grads = tensor_out.backward();
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
tensor_out
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([0.5, 0.2, 2.5]), false);
|
||||
grad.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([-0.25, -0.04, -6.25]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn should_diff_relu() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = activation::relu(tensor_3);
|
||||
let tensor_5 = tensor_4.matmul(tensor_2.clone());
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[-47.0, 9.0], [-35.0, 15.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[15.0, 13.0], [-2.0, 39.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_remainder() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(
|
||||
TensorData::from([
|
||||
0.9742, 0.3676, 0.0905, 0.8066, 0.7072, 0.7883, 0.6987, 0.1560, 0.7179, 0.7874, 0.9032,
|
||||
0.1845,
|
||||
]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::<1>::from_data(
|
||||
TensorData::from([
|
||||
0.3357, 0.0285, 0.4115, 0.5511, 0.8637, 0.3593, 0.3885, 0.2569, 0.0936, 0.7172, 0.4792,
|
||||
0.4898,
|
||||
]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let tensor_3 = tensor_1.clone().remainder(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([
|
||||
-2.0, -12.0, -0.0, -1.0, -0.0, -2.0, -1.0, -0.0, -7.0, -1.0, -1.0, -0.0,
|
||||
]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_repeat() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0], [2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
|
||||
|
||||
let tensor_3 = tensor_1.matmul(tensor_3);
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[-3.0], [12.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_repeat_multi_dim() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 2.0], [2.0, 4.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
|
||||
|
||||
let tensor_3 = tensor_1.matmul(tensor_3);
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[-3.0, -3.0], [12.0, 12.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_reshape() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([4.0, 7.0, 2.0, 3.0]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().reshape([2, 2]);
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_3);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([3.0, 3.0, 10.0, 10.0]), false);
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_round() {
|
||||
let data = TensorData::from([
|
||||
[-1.9751, 0.0714, 0.0643, 0.2406],
|
||||
[-1.3172, 0.1252, -0.1119, -0.0127],
|
||||
]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_2 = tensor_1.clone().round();
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
grad_1.to_data().assert_eq(
|
||||
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
use super::*;
|
||||
use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_select_grad() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let indices =
|
||||
Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1.clone().select(0, indices);
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
grad_1.into_data().assert_eq(
|
||||
&TensorData::from([[109., 148., 187.], [37., 58., 79.]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_add_grad() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(
|
||||
TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let values = TestAutodiffTensor::from_data(
|
||||
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let indices =
|
||||
Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 =
|
||||
tensor_1
|
||||
.clone()
|
||||
.select_assign(0, indices, values.clone(), IndexingUpdateOp::Add);
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = values.grad(&grads).unwrap();
|
||||
|
||||
grad_1.into_data().assert_eq(
|
||||
&TensorData::from([[127., 199., 271.], [172., 244., 316.]]),
|
||||
false,
|
||||
);
|
||||
grad_2
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[64., 64., 64.], [19., 19., 19.]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_add_grad_different_shapes() {
|
||||
let device = Default::default();
|
||||
|
||||
let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);
|
||||
let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();
|
||||
let y = Tensor::ones([2, 1], &device).require_grad();
|
||||
|
||||
let w = y
|
||||
.clone()
|
||||
.select_assign(0, indices, x.clone(), IndexingUpdateOp::Add);
|
||||
let w = w.matmul(y.clone().transpose());
|
||||
|
||||
let grads = w.backward();
|
||||
let x_grad = x.grad(&grads).unwrap();
|
||||
let y_grad = y.grad(&grads).unwrap();
|
||||
|
||||
x_grad
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[2.0]]), false);
|
||||
y_grad
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[5.0], [5.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn should_diff_sigmoid() {
|
||||
let data = TensorData::from([0.8762]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
|
||||
let tensor_2 = activation::sigmoid(tensor_1.clone());
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
let grad = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([0.207549]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn small_neg_val_should_not_cause_grad_overflow() {
|
||||
let data = TensorData::from([-90.0]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
|
||||
let tensor_2 = activation::sigmoid(tensor_1.clone());
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
let grad = tensor_1.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([0.0]);
|
||||
grad.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
/// Example using the sign function with PyTorch:
|
||||
// >>> import torch
|
||||
// >>> # Create a tensor with requires_grad=True
|
||||
// >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
||||
// >>> # Forward pass: Apply the sign function
|
||||
// >>> y = torch.sign(x)
|
||||
// >>> print("Forward pass:")
|
||||
// Forward pass:
|
||||
// >>> print("x:", x)
|
||||
// x: tensor([-2., -1., 0., 1., 2.], requires_grad=True)
|
||||
// >>> print("y:", y)
|
||||
// y: tensor([-1., -1., 0., 1., 1.], grad_fn=<SignBackward0>)
|
||||
// >>> # Compute the loss (just an example)
|
||||
// >>> loss = y.sum()
|
||||
// >>> # Backward pass: Compute the gradients
|
||||
// >>> loss.backward()
|
||||
// >>> print("\nBackward pass:")
|
||||
// Backward pass:
|
||||
// >>> print("x.grad:", x.grad)
|
||||
// x.grad: tensor([0., 0., 0., 0., 0.])
|
||||
|
||||
#[test]
|
||||
fn should_diff_sign() {
|
||||
let data = TensorData::from([-2.0, -1.0, 0.0, 1.0, 2.0]);
|
||||
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();
|
||||
|
||||
let y = x.clone().sign();
|
||||
|
||||
let loss = y.clone().sum();
|
||||
let grads = loss.backward();
|
||||
let grad = x.grad(&grads).unwrap();
|
||||
|
||||
y.to_data()
|
||||
.assert_eq(&TensorData::from([-1., -1., 0., 1., 1.]), false);
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([0., 0., 0., 0., 0.]), false);
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_slice() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().slice([0..2, 0..2]);
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_3);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);
|
||||
grad_2.to_data().assert_eq(
|
||||
&TensorData::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_slice_stepped() {
|
||||
use burn_tensor::s;
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [100.0, 100.0], [2.0, 3.0], [100.0, 100.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 100.0, 7.0, 100.0], [2.0, 100.0, 3.0, 15.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().slice(s![0..;2, 0..2]); // [[1., 7.], [2., 3.]]
|
||||
let tensor_4 = tensor_2.clone().slice(s![0..2, 0..;2]); // [[4., 7.], [2., 3.]]
|
||||
let tensor_5 = tensor_3.clone().matmul(tensor_4);
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_eq(
|
||||
&TensorData::from([[11., 5.], [0., 0.], [11., 5.], [0., 0.]]),
|
||||
false,
|
||||
);
|
||||
grad_2.to_data().assert_eq(
|
||||
&TensorData::from([[3., 0., 3., 0.], [10., 0., 10., 0.]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_panic_on_slice_with_step() {
|
||||
use burn_tensor::s;
|
||||
|
||||
let data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||
|
||||
// This should panic because step is 2
|
||||
let _sliced = tensor.slice(s![.., 0..4; 2]);
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_slice_assign() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_assigned = TensorData::from([[9.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_assigned = TestAutodiffTensor::from_data(data_assigned, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned);
|
||||
let tensor_5 = tensor_4.matmul(tensor_1.clone());
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[58.0, 38.0], [118.0, 82.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[16.0, 15.0], [24.0, 50.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_slice_assign_complex() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[9.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_2.clone().slice([0..1, 0..1]);
|
||||
let tensor_6 = tensor_5.mul(tensor_3.clone());
|
||||
let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6);
|
||||
let tensor_8 = tensor_7.matmul(tensor_1.clone());
|
||||
|
||||
let grads = tensor_8.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
let grad_3 = tensor_3.grad(&grads).unwrap();
|
||||
|
||||
grad_3
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[32.0]]), false);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[85.0, 65.0], [118.0, 82.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[88.0, 15.0], [24.0, 50.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slice_assign_diff_should_give_same_results_as_cat() {
|
||||
let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[5.0, 6.0], [7.0, 8.0]]);
|
||||
let data_3 = TensorData::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
|
||||
|
||||
let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());
|
||||
let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
|
||||
let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
|
||||
let slice_assign_output = slice_assign_output / tensor_3.clone();
|
||||
|
||||
let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
|
||||
let cat_output = cat_output / tensor_3;
|
||||
|
||||
slice_assign_output
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&cat_output.to_data(), Tolerance::default());
|
||||
|
||||
let slice_assign_grads = slice_assign_output.backward();
|
||||
let cat_grads = cat_output.backward();
|
||||
|
||||
let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
|
||||
let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
|
||||
let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
|
||||
let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
|
||||
|
||||
slice_assign_grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&cat_grad_1.to_data(), Tolerance::default());
|
||||
slice_assign_grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&cat_grad_2.to_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_slice_assign_with_step() {
|
||||
use burn_tensor::s;
|
||||
let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
|
||||
let value_data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||
let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();
|
||||
|
||||
// Assign with step=2
|
||||
let result = tensor.clone().slice_assign(s![.., 0..4; 2], value.clone());
|
||||
let result = result * 2.0; // Scale to create gradients
|
||||
let grads = result.backward();
|
||||
|
||||
let grad_tensor = tensor.grad(&grads).unwrap();
|
||||
let grad_value = value.grad(&grads).unwrap();
|
||||
|
||||
// The gradient for tensor should be 2.0 everywhere except the assigned positions
|
||||
grad_tensor.to_data().assert_eq(
|
||||
&TensorData::from([[0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0]]),
|
||||
false,
|
||||
);
|
||||
// The gradient for value should be 2.0 at all positions
|
||||
grad_value
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_slice_assign_with_negative_step() {
|
||||
use burn_tensor::s;
|
||||
|
||||
let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
|
||||
let value_data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
|
||||
let device = Default::default();
|
||||
let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||
let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();
|
||||
|
||||
// Assign with step=-1 (reverse order, all elements)
|
||||
let result = tensor.clone().slice_assign(s![.., ..;-1], value.clone());
|
||||
let result = result * 2.0; // Scale to create gradients
|
||||
let grads = result.backward();
|
||||
|
||||
let grad_tensor = tensor.grad(&grads).unwrap();
|
||||
let grad_value = value.grad(&grads).unwrap();
|
||||
|
||||
// The gradient for tensor should be 0 since all values were replaced
|
||||
grad_tensor.to_data().assert_eq(
|
||||
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||
false,
|
||||
);
|
||||
// The gradient for value should be 2.0 at all positions
|
||||
grad_value.to_data().assert_eq(
|
||||
&TensorData::from([[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{Tensor, TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_softmax_grad() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.5));
|
||||
|
||||
let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.05));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log_softmax_grad() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone());
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]);
|
||||
// f16 gradients from log-softmax + matmul amplify error, so we increase the tolerance
|
||||
// to account for limited precision and large representable step sizes in this range.
|
||||
let tolerance = Tolerance::permissive().set_half_precision_relative(6e-2);
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[30.5984, -47.2267], [55.9631, -56.5914]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quiet_softmax_grad() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);
|
||||
|
||||
// Precision is quite bad yet on softmax grad especially with half precision.
|
||||
let tolerance = Tolerance::rel_abs(0.5, 0.2);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_sort() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sort_with_indices() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let (values, _indices) = tensor_3.sort_with_indices(1);
|
||||
let tensor_4 = tensor_1.clone().mul(values);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sort_3d_dim1() {
|
||||
let device = Default::default();
|
||||
let tensor_1 =
|
||||
TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();
|
||||
let tensor_2 =
|
||||
TestAutodiffTensor::from_floats([[[4.0, -7.0], [2.0, 3.0]]], &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let expected = TensorData::from([[[-1., -8.], [-27., 37.]]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[[-4., -17.], [-17., -42.]]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_sqrt() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-3);
|
||||
let expected = TensorData::from([[82.112640, 99.083275], [82.112640, 99.083275]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[30.309311, 33.120457], [34.581974, 38.769463]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_diff_sub() {
|
||||
let data_1 = TensorData::from([2.0, 5.0]);
|
||||
let data_2 = TensorData::from([4.0, 1.0]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().sub(tensor_2.clone());
|
||||
let grads = tensor_3.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([-1.0, -1.0]), false);
|
||||
|
||||
tensor_3
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([-2.0, 4.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sub_scalar() {
|
||||
let data = TensorData::from([2.0, 10.0]);
|
||||
let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();
|
||||
let tensor_out = tensor.clone().sub_scalar(5.0);
|
||||
let grads = tensor_out.backward();
|
||||
|
||||
let grad = tensor.grad(&grads).unwrap();
|
||||
|
||||
grad.to_data()
|
||||
.assert_eq(&TensorData::from([1.0, 1.0]), false);
|
||||
tensor_out
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([-3.0, 5.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_complex_1() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().sub(tensor_2.clone());
|
||||
let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0);
|
||||
let tensor_6 = tensor_1.clone().sub(tensor_5);
|
||||
|
||||
let grads = tensor_6.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 0.0], [0.0, 0.0]]), false);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1.0, 1.0], [1.0, 1.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_transpose() {
|
||||
let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose());
|
||||
let tensor_4 = tensor_3.transpose();
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[6.0, 10.0], [6.0, 10.0]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[3.0, 10.0], [3.0, 10.0]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_swap_dims() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<3>::from_floats(
|
||||
[[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_floats(
|
||||
[[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]],
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2));
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2));
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,371 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_diff_cos() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
// Metal has less precise trigonometric functions
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[26.8063, -27.7870], [26.8063, -27.7870]]),
|
||||
tolerance,
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),
|
||||
tolerance,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sin() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
// Metal has less precise trigonometric functions
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
|
||||
|
||||
let expected = TensorData::from([[8.8500, -4.9790], [8.8500, -4.9790]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[38.668987, 44.194775], [-59.97261, -80.46094]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_tanh() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(8e-3);
|
||||
let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TensorData::from([[8.00092, 8.000153], [8.000003, 7.999995]]);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_cosh() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cosh());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[7.092221, 16.696301], [7.092221, 16.696301]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[17.489855, 27.484539], [39.409813, 86.910278]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sinh() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sinh());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[4.894847, 15.887931], [4.894847, 15.887931]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[17.284000, 28.412029], [39.302979, 87.498329]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_tan() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[0.5, 1.0], [0.3, 0.8]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tan());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[2.532602, 1.596607], [2.532602, 1.596607]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[9.028598, 14.489801], [18.038082, 21.151270]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_asin() {
|
||||
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
|
||||
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asin());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[0.435841, 0.969651], [0.435841, 0.969651]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[0.475300, 0.668141], [0.701834, 1.100658]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_acos() {
|
||||
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
|
||||
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acos());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[2.077433, 1.543624], [2.077433, 1.543624]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[0.781337, 0.588496], [0.554804, 0.155979]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_atan() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atan());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[3.444365, 5.349211], [3.444365, 5.349211]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[9.904911, 11.554912], [10.199631, 11.391938]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_asinh() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asinh());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[3.806625, 6.844869], [3.806625, 6.844869]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[11.442373, 14.842072], [14.022551, 17.688538]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_acosh() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[1.5, 2.0], [2.5, 3.0]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acosh());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[10.611752, 15.178907], [10.611752, 15.178907]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[20.112753, 20.247547], [20.402235, 22.487328]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_atanh() {
|
||||
let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);
|
||||
let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atanh());
|
||||
let tensor_4 = tensor_3.matmul(tensor_2.clone());
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[0.441838, 1.037115], [0.441838, 1.037115]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[0.491723, 0.698110], [0.772763, 1.298805]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_atan2() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);
|
||||
let data_3 = TensorData::from([[1.0, 0.5], [2.0, 1.5]]);
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1
|
||||
.clone()
|
||||
.matmul(tensor_2.clone().atan2(tensor_3.clone()));
|
||||
let tensor_5 = tensor_4.matmul(tensor_2.clone());
|
||||
let grads = tensor_5.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
let grad_3 = tensor_3.grad(&grads).unwrap();
|
||||
|
||||
grad_1.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[4.570492, 4.210785], [4.570492, 4.210785]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_2.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[8.208448, 8.808449], [10.357923, 12.157923]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
|
||||
grad_3.to_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[-1.8, -8.4], [-1.8, -5.6]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn unfold_backward_accumulates_overlaps() {
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device).require_grad();
|
||||
|
||||
let y = x.clone().unfold::<3, _>(1, 2, 1);
|
||||
let loss = y.sum();
|
||||
|
||||
let grads = loss.backward();
|
||||
let grad_x = x.grad(&grads).unwrap();
|
||||
|
||||
grad_x
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1., 2., 2., 1.]]), false);
|
||||
}
|
||||
Reference in New Issue
Block a user