feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
pub use super::*; // re-export test types
|
||||
|
||||
mod ops;
|
||||
@@ -0,0 +1,34 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_all() {
|
||||
let tensor = TestTensorBool::<2>::from([[false, true, false], [true, true, true]]);
|
||||
let data_actual = tensor.all().into_data();
|
||||
let data_expected = TensorData::from([false]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
let tensor = TestTensorBool::<2>::from([[true, true, true], [true, true, true]]);
|
||||
let data_actual = tensor.all().into_data();
|
||||
let data_expected = TensorData::from([true]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_dim() {
|
||||
let tensor = TestTensorBool::<2>::from([[false, true, false], [true, true, true]]);
|
||||
let data_actual = tensor.all_dim(1).into_data();
|
||||
let data_expected = TensorData::from([[false], [true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_with_bool_from_lower_equal() {
|
||||
let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-6;
|
||||
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-6;
|
||||
|
||||
let ge = tensor1.lower_equal(tensor2);
|
||||
let all = ge.clone().all();
|
||||
|
||||
TensorData::from([true]).assert_eq(&all.clone().into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_any() {
|
||||
let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([true]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
let tensor = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([false]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_any_dim() {
|
||||
let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);
|
||||
let data_actual = tensor.any_dim(1).into_data();
|
||||
let data_expected = TensorData::from([[false], [true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
use super::*;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Shape, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_argwhere_1d() {
|
||||
let tensor = TestTensorBool::<1>::from([false, true, false, true, true]);
|
||||
let output = tensor.argwhere();
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1], [3], [4]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argwhere_2d() {
|
||||
let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]);
|
||||
let output = tensor.argwhere();
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1, 1], [2, 0], [2, 1]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argwhere_3d() {
|
||||
let tensor = TestTensorBool::<3>::from([
|
||||
[[false, false, false], [false, true, false]],
|
||||
[[true, false, true], [true, true, false]],
|
||||
]);
|
||||
let output = tensor.argwhere();
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[0, 1, 1], [1, 0, 0], [1, 0, 2], [1, 1, 0], [1, 1, 1]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonzero_1d() {
|
||||
let tensor = TestTensorBool::<1>::from([false, true, false, true, true]);
|
||||
let data_actual = tensor
|
||||
.nonzero()
|
||||
.into_iter()
|
||||
.map(|t| t.into_data())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(data_actual.len(), 1);
|
||||
data_actual[0].assert_eq(&TensorData::from([1, 3, 4]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonzero_2d() {
|
||||
// 2-D tensor
|
||||
let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]);
|
||||
let data_actual = tensor
|
||||
.nonzero()
|
||||
.into_iter()
|
||||
.map(|t| t.into_data())
|
||||
.collect::<Vec<_>>();
|
||||
let data_expected = [TensorData::from([1, 2, 2]), TensorData::from([1, 0, 1])];
|
||||
|
||||
assert_eq!(data_actual.len(), 2);
|
||||
for (idx, actual) in data_actual.iter().enumerate() {
|
||||
actual.assert_eq(&data_expected[idx], false)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonzero_3d() {
|
||||
// 3-D tensor
|
||||
let tensor = TestTensorBool::<3>::from([
|
||||
[[false, false, false], [false, true, false]],
|
||||
[[true, false, true], [true, true, false]],
|
||||
]);
|
||||
let data_actual = tensor
|
||||
.nonzero()
|
||||
.into_iter()
|
||||
.map(|t| t.into_data())
|
||||
.collect::<Vec<_>>();
|
||||
let data_expected = [
|
||||
TensorData::from([0, 1, 1, 1, 1]),
|
||||
TensorData::from([1, 0, 0, 1, 1]),
|
||||
TensorData::from([1, 0, 2, 0, 1]),
|
||||
];
|
||||
|
||||
assert_eq!(data_actual.len(), 3);
|
||||
for (idx, actual) in data_actual.iter().enumerate() {
|
||||
actual.assert_eq(&data_expected[idx], false)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonzero_empty() {
|
||||
let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
|
||||
let output = tensor.nonzero();
|
||||
|
||||
assert_eq!(output.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argwhere_empty() {
|
||||
let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
|
||||
let output = tensor.argwhere();
|
||||
|
||||
assert_eq!(output.shape(), Shape::new([0, 1]));
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_ops_bool() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);
|
||||
let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);
|
||||
|
||||
let output = Tensor::cat(vec![tensor_1, tensor_2], 0);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[false, true, true], [true, true, false]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_with_empty_tensor_bool() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensorBool::<2>::from_data([[true, false, true]], &device);
|
||||
let tensor_2: TestTensorBool<2> = TestTensorBool::empty([1, 0], &device);
|
||||
|
||||
let output = Tensor::cat(vec![tensor_1, tensor_2], 1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[true, false, true]]), false);
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_equal() {
|
||||
let data_1 = TensorData::from([[false, true, true], [true, false, true]]);
|
||||
let data_2 = TensorData::from([[false, false, true], [false, true, true]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[true, false, true], [false, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_not_equal() {
|
||||
let data_1 = TensorData::from([[false, true, true], [true, false, true]]);
|
||||
let data_2 = TensorData::from([[false, false, true], [false, true, true]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.not_equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, false], [true, true, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_not() {
|
||||
let data_1 = TensorData::from([[false, true, true], [true, true, false]]);
|
||||
let tensor_1 = TestTensorBool::<2>::from_data(data_1, &Default::default());
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().bool_not();
|
||||
let data_actual_inplace = tensor_1.bool_not();
|
||||
|
||||
let data_expected = TensorData::from([[true, false, false], [false, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_equal_elem() {
|
||||
let tensor_1 = TestTensorBool::<2>::from([[true, false, true], [false, true, false]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().equal_elem(false);
|
||||
let data_actual_inplace = tensor_1.equal_elem(false);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_not_equal_elem() {
|
||||
let tensor_1 = TestTensorBool::<2>::from([[true, false, true], [false, true, false]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().not_equal_elem(true);
|
||||
let data_actual_inplace = tensor_1.not_equal_elem(true);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_zeros_like() {
|
||||
let tensor = TestTensorBool::<3>::from([
|
||||
[[false, true, false], [true, true, true]],
|
||||
[[false, false, false], [true, true, false]],
|
||||
]);
|
||||
|
||||
let tensor = tensor.zeros_like();
|
||||
let expected = TensorData::from([
|
||||
[[false, false, false], [false, false, false]],
|
||||
[[false, false, false], [false, false, false]],
|
||||
]);
|
||||
|
||||
tensor.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_ones_like() {
|
||||
let tensor = TestTensorBool::<3>::from([
|
||||
[[false, true, false], [true, true, true]],
|
||||
[[false, false, false], [true, true, false]],
|
||||
]);
|
||||
|
||||
let tensor = tensor.ones_like();
|
||||
let expected = TensorData::from([
|
||||
[[true, true, true], [true, true, true]],
|
||||
[[true, true, true], [true, true, true]],
|
||||
]);
|
||||
|
||||
tensor.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn expand_2d_bool() {
|
||||
let tensor = TestTensorBool::<1>::from([false, true, false]);
|
||||
let expanded_tensor = tensor.expand([3, 3]);
|
||||
|
||||
let expected_data = TensorData::from([
|
||||
[false, true, false],
|
||||
[false, true, false],
|
||||
[false, true, false],
|
||||
]);
|
||||
|
||||
expanded_tensor.into_data().assert_eq(&expected_data, false);
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn flip_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorInt::<1>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let flipped = tensor.clone().flip([0, 2]);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).gt(10)
|
||||
let data_expected = TensorData::from([
|
||||
[
|
||||
[true, true, true, true],
|
||||
[true, true, true, true],
|
||||
[true, true, true, true],
|
||||
],
|
||||
[
|
||||
[false, false, false, false],
|
||||
[false, false, false, false],
|
||||
[true, false, false, false],
|
||||
],
|
||||
]);
|
||||
|
||||
flipped.into_data().assert_eq(&data_expected, false);
|
||||
|
||||
// Test with no flip
|
||||
let flipped = tensor.clone().flip([]);
|
||||
tensor.into_data().assert_eq(&flipped.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_tensor_full() {
|
||||
let device = Default::default();
|
||||
let bool_tensor = TestTensorBool::<2>::full([2, 2], true, &device);
|
||||
bool_tensor
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[true, true], [true, true]]), false);
|
||||
|
||||
let bool_tensor = TestTensorBool::<2>::full([2, 2], false, &device);
|
||||
bool_tensor
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[false, false], [false, false]]), false);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use super::*;
|
||||
use burn_tensor::{IndexingUpdateOp, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_scatter_1d_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false, false], &device);
|
||||
let values = TestTensorBool::from_data([false, true, true], &device);
|
||||
let indices = TestTensorInt::from_ints([1, 0, 2], &device);
|
||||
|
||||
let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([true, false, true]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_gather_1d_dim0_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false, false], &device);
|
||||
let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device);
|
||||
|
||||
let output = tensor.gather(0, indices);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([false, false, true, false, false]), false);
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_empty() {
|
||||
let shape = [2, 2];
|
||||
let tensor = TestTensorBool::<2>::empty(shape, &Default::default());
|
||||
assert_eq!(tensor.shape(), shape.into())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_zeros() {
|
||||
let shape = [2, 2];
|
||||
let tensor = TestTensorBool::<2>::zeros(shape, &Default::default());
|
||||
assert_eq!(tensor.shape(), shape.into());
|
||||
|
||||
tensor
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[false, false], [false, false]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_ones() {
|
||||
let shape = [2, 2];
|
||||
let tensor = TestTensorBool::<2>::ones(shape, &Default::default());
|
||||
assert_eq!(tensor.shape(), shape.into());
|
||||
|
||||
tensor
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[true, true], [true, true]]), false);
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_bool_and() {
|
||||
let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]);
|
||||
let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]);
|
||||
let data_actual = tensor1.bool_and(tensor2).into_data();
|
||||
let data_expected = TensorData::from([[false, true, false], [false, false, true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_or() {
|
||||
let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]);
|
||||
let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]);
|
||||
let data_actual = tensor1.bool_or(tensor2).into_data();
|
||||
let data_expected = TensorData::from([[true, true, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_xor() {
|
||||
let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]);
|
||||
let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]);
|
||||
let data_actual = tensor1.bool_xor(tensor2).into_data();
|
||||
let data_expected = TensorData::from([[true, false, false], [true, false, false]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_or_vec() {
|
||||
let device = Default::default();
|
||||
let tensor1 = TestTensorBool::<1>::full([256], 0, &device);
|
||||
let tensor2 = TestTensorBool::<1>::full([256], 1, &device);
|
||||
let data_actual = tensor1.bool_or(tensor2).into_data();
|
||||
let data_expected = TensorData::from([true; 256]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bool_and_vec() {
|
||||
let device = Default::default();
|
||||
let tensor1 = TestTensorBool::<1>::full([256], 0, &device);
|
||||
let tensor2 = TestTensorBool::<1>::full([256], 1, &device);
|
||||
let data_actual = tensor1.bool_and(tensor2).into_data();
|
||||
let data_expected = TensorData::from([false; 256]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_mask_where_ops() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<2>::from_data([[true, false], [false, false]], &device);
|
||||
let mask =
|
||||
TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);
|
||||
let value =
|
||||
TestTensorBool::<2>::from_data(TensorData::from([[false, true], [true, false]]), &device);
|
||||
|
||||
let output = tensor.mask_where(mask, value);
|
||||
let expected = TensorData::from([[false, false], [false, false]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_mask_fill_ops() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<2>::from_data([[false, true], [false, false]], &device);
|
||||
let mask =
|
||||
TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);
|
||||
|
||||
let output = tensor.mask_fill(mask, true);
|
||||
let expected = TensorData::from([[true, true], [false, true]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
pub use super::*; // re-export test types
|
||||
|
||||
mod all;
|
||||
mod any;
|
||||
mod argwhere_nonzero;
|
||||
mod cat;
|
||||
mod comparison;
|
||||
mod create_like;
|
||||
mod expand;
|
||||
mod flip;
|
||||
mod full;
|
||||
mod gather_scatter;
|
||||
mod init;
|
||||
mod logical;
|
||||
mod mask;
|
||||
mod movedim;
|
||||
mod permute;
|
||||
mod repeat;
|
||||
mod repeat_dim;
|
||||
mod reshape;
|
||||
mod select;
|
||||
mod stack;
|
||||
mod take;
|
||||
mod transpose;
|
||||
mod tri_mask;
|
||||
mod unfold;
|
||||
@@ -0,0 +1,56 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn movedim_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorInt::<1>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let permuted = tensor.clone().movedim(0, 2);
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).gt(10)
|
||||
let expected = TensorData::from([
|
||||
[[false, true], [false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [false, true], [true, true]],
|
||||
]);
|
||||
|
||||
permuted.into_data().assert_eq(&expected, false);
|
||||
|
||||
// Test with negative axis
|
||||
let permuted = tensor.clone().movedim(0, -1);
|
||||
permuted.into_data().assert_eq(&expected, false);
|
||||
|
||||
// Test with the same axis
|
||||
let permuted = tensor.clone().movedim(0, 0);
|
||||
permuted.into_data().assert_eq(&tensor.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_input_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorInt::<1>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
|
||||
// from pytorch
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).gt(10)
|
||||
let expected = TensorData::from([
|
||||
[[false, false, false, false], [true, true, true, true]],
|
||||
[[false, false, false, false], [true, true, true, true]],
|
||||
[[false, false, false, true], [true, true, true, true]],
|
||||
]);
|
||||
|
||||
permuted.into_data().assert_eq(&expected, false);
|
||||
|
||||
// Test with negative axes
|
||||
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
|
||||
permuted.into_data().assert_eq(&expected, false);
|
||||
|
||||
// Test with the same axes
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
|
||||
permuted.into_data().assert_eq(&tensor.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn permute_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorInt::<1>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let permuted = tensor.clone().permute([2, 1, 0]);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0).gt(10)
|
||||
let expected = TensorData::from([
|
||||
[[false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [true, true]],
|
||||
]);
|
||||
|
||||
permuted.into_data().assert_eq(&expected, false);
|
||||
|
||||
// Test with negative axis
|
||||
let permuted = tensor.clone().permute([-1, 1, 0]);
|
||||
permuted.into_data().assert_eq(&expected, false);
|
||||
|
||||
// Test with the same axis
|
||||
let permuted = tensor.clone().permute([0, 1, 2]);
|
||||
permuted.into_data().assert_eq(&tensor.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_ops_one_dimension() {
|
||||
let data = TensorData::from([[true, false, false]]);
|
||||
let tensor = TestTensorBool::<2>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(&[4, 1, 1]);
|
||||
let expected = TensorData::from([
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
]);
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_on_many_dimension() {
|
||||
let data = TensorData::from([
|
||||
[[false, true], [true, false]],
|
||||
[[true, true], [false, false]],
|
||||
]);
|
||||
let tensor = TestTensorBool::<3>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(&[2, 3, 2]);
|
||||
let expected = TensorData::from([
|
||||
[
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
],
|
||||
[
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
],
|
||||
[
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
],
|
||||
[
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_ops() {
|
||||
let data = TensorData::from([[true, false, false]]);
|
||||
let tensor = TestTensorBool::<2>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(0, 4);
|
||||
let expected = TensorData::from([
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
]);
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_on_dims_larger_than_1() {
|
||||
let data = TensorData::from([
|
||||
[[false, true], [true, false]],
|
||||
[[true, true], [false, false]],
|
||||
]);
|
||||
let tensor = TestTensorBool::<3>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(1, 2);
|
||||
let expected = TensorData::from([
|
||||
[[false, true], [true, false], [false, true], [true, false]],
|
||||
[[true, true], [false, false], [true, true], [false, false]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_reshape_bool() {
|
||||
let data = TensorData::from([false, true, false]);
|
||||
let tensor = TestTensorBool::<1>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.clone().reshape([1, 3]);
|
||||
let expected = TensorData::from([[false, true, false]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
use super::*;
|
||||
use burn_tensor::{IndexingUpdateOp, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_select_bool_tensor_1d() {
|
||||
// Test that select works for boolean tensors
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
|
||||
let indices = TestTensorInt::from_data([0, 2, 1, 0], &device);
|
||||
|
||||
let output = tensor.select(0, indices);
|
||||
let expected = TensorData::from([true, true, false, true]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_bool_tensor_2d() {
|
||||
// Test that select works for boolean 2D tensors
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestTensorBool::<2>::from_data([[true, false, true], [false, true, false]], &device);
|
||||
let indices = TestTensorInt::from_data([1, 0], &device);
|
||||
|
||||
let output = tensor.select(0, indices);
|
||||
let expected = TensorData::from([[false, true, false], [true, false, true]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_tensor() {
|
||||
// Test that select_add works for boolean tensors
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
|
||||
let values = TestTensorBool::<1>::from_data([false, true], &device);
|
||||
let indices = TestTensorInt::from_data([0, 2], &device);
|
||||
|
||||
let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
|
||||
// Note: select_add uses sum reduction, so:
|
||||
// index 0: true OR false = true
|
||||
// index 2: true OR true = true
|
||||
// index 1: false (unchanged)
|
||||
let expected = TensorData::from([true, false, true]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_overlapping_indices() {
|
||||
// Test accumulation behavior with overlapping indices
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([false, true], &device);
|
||||
let indices = TestTensorInt::from_data([0, 0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([true, false], &device);
|
||||
|
||||
let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
|
||||
// Index 0: false OR true OR false = true
|
||||
let expected = TensorData::from([true, true]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_false_to_true_case() {
|
||||
// Test false OR true = true
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([false], &device);
|
||||
let indices = TestTensorInt::from_data([0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([true], &device);
|
||||
|
||||
let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
|
||||
let expected = TensorData::from([true]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_true_or_true_accumulation() {
|
||||
// Test multiple true accumulations
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false], &device);
|
||||
let indices = TestTensorInt::from_data([0, 0, 0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([true, true, true], &device);
|
||||
|
||||
let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
|
||||
let expected = TensorData::from([true, false]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_match_default_implementation_behavior() {
|
||||
// Verify optimized implementation matches original default logic
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
|
||||
let indices = TestTensorInt::from_data([0, 1, 0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([false, true, true], &device);
|
||||
|
||||
let optimized_result =
|
||||
tensor
|
||||
.clone()
|
||||
.select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);
|
||||
|
||||
// Manual default implementation logic
|
||||
let int_tensor = tensor.int();
|
||||
let int_values = values.int();
|
||||
let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);
|
||||
let default_result = assigned.greater_elem(0);
|
||||
|
||||
optimized_result
|
||||
.into_data()
|
||||
.assert_eq(&default_result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_overlapping_indices_vs_default() {
|
||||
// Test overlapping indices against default implementation
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([false, true], &device);
|
||||
let indices = TestTensorInt::from_data([0, 0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([true, false], &device);
|
||||
|
||||
let optimized_result =
|
||||
tensor
|
||||
.clone()
|
||||
.select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);
|
||||
|
||||
let int_tensor = tensor.int();
|
||||
let int_values = values.int();
|
||||
let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);
|
||||
let default_result = assigned.greater_elem(0);
|
||||
|
||||
optimized_result
|
||||
.into_data()
|
||||
.assert_eq(&default_result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_true_or_true_accumulation_vs_default() {
|
||||
// Test multiple true accumulations against default implementation
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false], &device);
|
||||
let indices = TestTensorInt::from_data([0, 0, 0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([true, true, true], &device);
|
||||
|
||||
let optimized_result =
|
||||
tensor
|
||||
.clone()
|
||||
.select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);
|
||||
|
||||
let int_tensor = tensor.int();
|
||||
let int_values = values.int();
|
||||
let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);
|
||||
let default_result = assigned.greater_elem(0);
|
||||
|
||||
optimized_result
|
||||
.into_data()
|
||||
.assert_eq(&default_result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_false_to_true_case_vs_default() {
|
||||
// Test false OR true case against default implementation
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([false], &device);
|
||||
let indices = TestTensorInt::from_data([0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([true], &device);
|
||||
|
||||
let optimized_result =
|
||||
tensor
|
||||
.clone()
|
||||
.select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);
|
||||
|
||||
let int_tensor = tensor.int();
|
||||
let int_values = values.int();
|
||||
let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);
|
||||
let default_result = assigned.greater_elem(0);
|
||||
|
||||
optimized_result
|
||||
.into_data()
|
||||
.assert_eq(&default_result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_add_bool_tensor_vs_default() {
|
||||
// Test existing basic case against default implementation
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
|
||||
let indices = TestTensorInt::from_data([0, 2], &device);
|
||||
let values = TestTensorBool::<1>::from_data([false, false], &device);
|
||||
|
||||
let optimized_result =
|
||||
tensor
|
||||
.clone()
|
||||
.select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);
|
||||
|
||||
let int_tensor = tensor.int();
|
||||
let int_values = values.int();
|
||||
let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);
|
||||
let default_result = assigned.greater_elem(0);
|
||||
|
||||
optimized_result
|
||||
.into_data()
|
||||
.assert_eq(&default_result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Tensors are not eq")]
|
||||
fn should_fail_if_replacement_semantics_were_used() {
|
||||
// Test that framework uses accumulation, not replacement
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true], &device);
|
||||
let indices = TestTensorInt::from_data([0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([false], &device);
|
||||
|
||||
let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
|
||||
let replacement_expected = TensorData::from([false]);
|
||||
|
||||
output.into_data().assert_eq(&replacement_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Tensors are not eq")]
|
||||
fn should_fail_if_replacement_semantics_were_used_vs_default() {
|
||||
// Test that default implementation also uses accumulation, not replacement
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<1>::from_data([true], &device);
|
||||
let indices = TestTensorInt::from_data([0], &device);
|
||||
let values = TestTensorBool::<1>::from_data([false], &device);
|
||||
|
||||
let int_tensor = tensor.int();
|
||||
let int_values = values.int();
|
||||
let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);
|
||||
let default_result = assigned.greater_elem(0);
|
||||
let replacement_expected = TensorData::from([false]);
|
||||
|
||||
default_result
|
||||
.into_data()
|
||||
.assert_eq(&replacement_expected, false);
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
use super::*;
|
||||
use alloc::vec;
|
||||
use burn_tensor::{Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_support_stack_ops_bool() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);
|
||||
let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);
|
||||
|
||||
let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);
|
||||
let expected = TensorData::from([[[false, true, true]], [[true, true, false]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_take_bool_tensor() {
|
||||
// Test take with boolean tensors
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<2>::from_data([[true, false], [false, true]], &device);
|
||||
let indices = TestTensorInt::<1>::from_data([1, 0], &device);
|
||||
|
||||
let output = tensor.take::<1, 2>(0, indices);
|
||||
let expected = TensorData::from([[false, true], [true, false]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_take_bool_tensor_with_2d_indices() {
|
||||
// Test take with boolean tensors - output will be 3D
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorBool::<2>::from_data(
|
||||
[
|
||||
[true, false, true],
|
||||
[false, true, false],
|
||||
[true, true, false],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
|
||||
// 2D indices - shape [2, 2]
|
||||
let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device);
|
||||
let output = tensor.take::<2, 3>(0, indices);
|
||||
|
||||
// Expected: shape [2, 2, 3]
|
||||
let expected = TensorData::from([
|
||||
[[true, false, true], [true, true, false]],
|
||||
[[false, true, false], [true, false, true]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_transpose_bool() {
|
||||
let tensor = TestTensorBool::<3>::from_data(
|
||||
[
|
||||
[[false, true, false], [false, false, false]],
|
||||
[[false, false, true], [false, false, true]],
|
||||
],
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.transpose();
|
||||
let expected = TensorData::from([
|
||||
[[false, false], [true, false], [false, false]],
|
||||
[[false, false], [false, false], [true, true]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_swap_dims_bool() {
|
||||
let tensor = TestTensorBool::<3>::from_data(
|
||||
[
|
||||
[[false, true, false], [false, false, false]],
|
||||
[[false, false, true], [false, false, true]],
|
||||
],
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.swap_dims(0, 2);
|
||||
let expected = TensorData::from([
|
||||
[[false, false], [false, false]],
|
||||
[[true, false], [false, false]],
|
||||
[[false, true], [false, true]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn square_diag() {
|
||||
let device = Default::default();
|
||||
let data_expected = TensorData::from([
|
||||
[false, true, true],
|
||||
[true, false, true],
|
||||
[true, true, false],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::diag_mask([3, 3], 0, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn square_diag_offset() {
|
||||
let device = Default::default();
|
||||
let data_expected =
|
||||
TensorData::from([[true, false, true], [true, true, false], [true, true, true]]);
|
||||
let tensor = TestTensorBool::<2>::diag_mask([3, 3], 1, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn square_tri_upper() {
|
||||
let device = Default::default();
|
||||
let data_expected = TensorData::from([
|
||||
[false, false, false],
|
||||
[true, false, false],
|
||||
[true, true, false],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::triu_mask([3, 3], 0, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn square_tri_upper_offset() {
|
||||
let device = Default::default();
|
||||
let data_expected = TensorData::from([
|
||||
[true, false, false],
|
||||
[true, true, false],
|
||||
[true, true, true],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::triu_mask([3, 3], 1, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn square_tri_lower() {
|
||||
let device = Default::default();
|
||||
|
||||
let data_expected = TensorData::from([
|
||||
[false, true, true],
|
||||
[false, false, true],
|
||||
[false, false, false],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::tril_mask([3, 3], 0, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn square_tri_lower_offset() {
|
||||
let device = Default::default();
|
||||
|
||||
let data_expected = TensorData::from([
|
||||
[true, true, true],
|
||||
[false, true, true],
|
||||
[false, false, true],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::tril_mask([3, 3], -1, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rect_diag() {
|
||||
let device = Default::default();
|
||||
let data_expected = TensorData::from([
|
||||
[false, true, true, true],
|
||||
[true, false, true, true],
|
||||
[true, true, false, true],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::diag_mask([3, 4], 0, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
|
||||
let data_expected = TensorData::from([
|
||||
[false, true, true],
|
||||
[true, false, true],
|
||||
[true, true, false],
|
||||
[true, true, true],
|
||||
]);
|
||||
let tensor = TestTensorBool::<2>::diag_mask([4, 3], 0, &device);
|
||||
tensor.into_data().assert_eq(&data_expected, false);
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
use super::*;
|
||||
use burn_tensor::Distribution;
|
||||
use burn_tensor::s;
|
||||
|
||||
#[test]
|
||||
fn test_unfold_bool() {
|
||||
let device = Default::default();
|
||||
|
||||
let input =
|
||||
TestTensor::<3>::random([2, 6, 6], Distribution::Default, &device).greater_elem(0.5);
|
||||
|
||||
let dim = 1;
|
||||
let size = 3;
|
||||
let step = 2;
|
||||
let actual: TestTensorBool<4> = input.clone().unfold(dim, size, step);
|
||||
|
||||
let expected = TestTensorBool::<4>::empty([2, 2, 6, 3], &device)
|
||||
.slice_assign(
|
||||
s![.., 0, .., ..],
|
||||
input
|
||||
.clone()
|
||||
.slice(s![.., 0..3, ..])
|
||||
.swap_dims(1, 2)
|
||||
.unsqueeze_dim::<4>(1),
|
||||
)
|
||||
.slice_assign(
|
||||
s![.., 1, .., ..],
|
||||
input
|
||||
.clone()
|
||||
.slice(s![.., 2..5, ..])
|
||||
.swap_dims(1, 2)
|
||||
.unsqueeze_dim::<4>(1),
|
||||
);
|
||||
|
||||
actual.to_data().assert_eq(&expected.to_data(), true);
|
||||
}
|
||||
@@ -0,0 +1,761 @@
|
||||
/// This module tests whether basic tensor operations remain invariant when performed on clones,
|
||||
/// meaning that cloning input tensors won't affect the results.
|
||||
///
|
||||
/// Those are relevant tests because backends may employ unsafe optimizations to reuse tensor data
|
||||
/// and use different kernels in such cases. We ensure that the results are consistent regardless
|
||||
/// of the approach and that the input tensors are not modified when cloned.
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::activation::{
|
||||
gelu, log_sigmoid, log_softmax, mish, relu, sigmoid, silu, softmax, softplus, tanh,
|
||||
};
|
||||
use burn_tensor::{Distribution, IndexingUpdateOp, TensorData};
|
||||
|
||||
pub trait CloneInvarianceTest<const D: usize> {
|
||||
type Args;
|
||||
|
||||
fn args(&self) -> Self::Args;
|
||||
|
||||
fn run(&self, args: &Self::Args, inplace: bool) -> TensorData;
|
||||
|
||||
fn check(&self) {
|
||||
let args = self.args();
|
||||
let out = self.run(&args, false);
|
||||
let out_inplace = self.run(&args, true);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&out_inplace, Tolerance::default());
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! clone_invariance_test {
|
||||
(unary: $name:ident, ops_float: $ops:expr) => {
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn $name() {
|
||||
struct $name;
|
||||
|
||||
impl CloneInvarianceTest<2> for $name {
|
||||
type Args = TensorData;
|
||||
|
||||
fn args(&self) -> Self::Args {
|
||||
TestTensor::<2>::random([32, 32], Distribution::Default, &Default::default())
|
||||
.into_data()
|
||||
.convert::<f32>()
|
||||
}
|
||||
|
||||
fn run(&self, args: &Self::Args, inplace: bool) -> TensorData {
|
||||
let lhs = TestTensor::from_data(args.clone(), &Default::default());
|
||||
|
||||
if inplace {
|
||||
$ops(lhs).into_data().convert::<f32>()
|
||||
} else {
|
||||
let out = $ops(lhs.clone()).into_data().convert::<f32>();
|
||||
lhs.into_data()
|
||||
.assert_approx_eq::<FloatElem>(args, Tolerance::default());
|
||||
out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CloneInvarianceTest::<2>::check(&$name);
|
||||
}
|
||||
};
|
||||
|
||||
(binary: $name:ident, ops_float: $ops:expr) => {
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn $name() {
|
||||
struct $name;
|
||||
|
||||
impl CloneInvarianceTest<2> for $name {
|
||||
type Args = (TensorData, TensorData);
|
||||
|
||||
fn args(&self) -> Self::Args {
|
||||
let device = Default::default();
|
||||
(
|
||||
TestTensor::<2>::ones([32, 32], &device)
|
||||
.into_data()
|
||||
.convert::<f32>(),
|
||||
// Avoid div by zero.
|
||||
TestTensor::<2>::ones([32, 32], &device)
|
||||
.into_data()
|
||||
.convert::<f32>(),
|
||||
)
|
||||
}
|
||||
|
||||
fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> TensorData {
|
||||
let device = Default::default();
|
||||
let lhs = TestTensor::from_data(lhs_arg.clone(), &device);
|
||||
let rhs = TestTensor::from_data(rhs_arg.clone(), &device);
|
||||
|
||||
if inplace {
|
||||
$ops(lhs, rhs).into_data().convert::<f32>()
|
||||
} else {
|
||||
let out = $ops(lhs.clone(), rhs.clone()).into_data().convert::<f32>();
|
||||
|
||||
lhs.into_data()
|
||||
.assert_approx_eq::<FloatElem>(lhs_arg, Tolerance::default());
|
||||
rhs.into_data()
|
||||
.assert_approx_eq::<FloatElem>(rhs_arg, Tolerance::default());
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CloneInvarianceTest::<2>::check(&$name);
|
||||
}
|
||||
};
|
||||
|
||||
(unary: $name:ident, ops_int: $ops:expr) => {
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn $name() {
|
||||
struct $name;
|
||||
|
||||
impl CloneInvarianceTest<2> for $name {
|
||||
type Args = TensorData;
|
||||
|
||||
fn args(&self) -> Self::Args {
|
||||
TestTensor::<2>::random(
|
||||
[32, 32],
|
||||
Distribution::Uniform(0.0, 50.0),
|
||||
&Default::default(),
|
||||
)
|
||||
.into_data()
|
||||
.convert::<i32>()
|
||||
}
|
||||
|
||||
fn run(&self, args: &Self::Args, inplace: bool) -> TensorData {
|
||||
let lhs = TestTensorInt::from_data(args.clone(), &Default::default());
|
||||
|
||||
if inplace {
|
||||
$ops(lhs).into_data().convert::<f32>()
|
||||
} else {
|
||||
let out = $ops(lhs.clone()).into_data().convert::<f32>();
|
||||
lhs.into_data()
|
||||
.convert::<i32>()
|
||||
.assert_approx_eq::<FloatElem>(args, Tolerance::default());
|
||||
out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CloneInvarianceTest::<2>::check(&$name);
|
||||
}
|
||||
};
|
||||
|
||||
(binary: $name:ident, ops_int: $ops:expr) => {
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn $name() {
|
||||
struct $name;
|
||||
|
||||
impl CloneInvarianceTest<2> for $name {
|
||||
type Args = (TensorData, TensorData);
|
||||
|
||||
fn args(&self) -> Self::Args {
|
||||
let device = Default::default();
|
||||
(
|
||||
TestTensor::<2>::random([32, 32], Distribution::Uniform(0., 50.), &device)
|
||||
.into_data()
|
||||
.convert::<i32>(),
|
||||
// Avoid div by zero.
|
||||
TestTensor::<2>::random([32, 32], Distribution::Uniform(1., 51.), &device)
|
||||
.into_data()
|
||||
.convert::<i32>(),
|
||||
)
|
||||
}
|
||||
|
||||
fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> TensorData {
|
||||
let device = Default::default();
|
||||
let lhs = TestTensorInt::from_data(lhs_arg.clone(), &device);
|
||||
let rhs = TestTensorInt::from_data(rhs_arg.clone(), &device);
|
||||
|
||||
if inplace {
|
||||
$ops(lhs, rhs).into_data().convert::<f32>()
|
||||
} else {
|
||||
let out = $ops(lhs.clone(), rhs.clone()).into_data().convert::<f32>();
|
||||
|
||||
lhs.into_data()
|
||||
.convert::<i32>()
|
||||
.assert_approx_eq::<FloatElem>(lhs_arg, Tolerance::default());
|
||||
rhs.into_data()
|
||||
.convert::<i32>()
|
||||
.assert_approx_eq::<FloatElem>(rhs_arg, Tolerance::default());
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CloneInvarianceTest::<2>::check(&$name);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mod float {
|
||||
use super::*;
|
||||
|
||||
// Unary ops
|
||||
clone_invariance_test!(
|
||||
unary: AddScalar,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: SubScalar,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: DivScalar,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MulScalar,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: PowScalar,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.powf_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Square,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.square()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Sqrt,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.sqrt()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Exp,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.exp()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Neg,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.neg()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MeanDim,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: SumDim,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Sum,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Mean,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Clamp,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: ClampMin,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: ClampMax,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Abs,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.abs()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Cos,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.cos()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Sin,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.sin()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Tan,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.tan()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Log,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.log()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Log1P,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.log1p()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: SwapDims,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Transpose,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.transpose()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Slice,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24])
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Erf,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.erf()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: EqualElem,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: NotEqualElem,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.not_equal_elem(0.5)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: GreaterElem,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: GreaterEqualElem,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: LowerElem,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: LowerEqualElem,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Argmax,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.argmax(0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Argmin,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.argmin(0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Max,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Min,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MaxDim,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.max_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MaxDimWithIndices,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MinDimWithIndices,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MinDim,
|
||||
ops_float: |tensor: TestTensor<2>| tensor.min_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Repeat,
|
||||
ops_float: |tensor: TestTensor<2>| {
|
||||
tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Reshape,
|
||||
ops_float: |tensor: TestTensor<2>| {
|
||||
let shape = tensor.shape();
|
||||
let new_shape = [shape.num_elements(), 1];
|
||||
tensor.reshape(new_shape)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Gatter,
|
||||
ops_float: |tensor: TestTensor<2>| {
|
||||
let shape = tensor.shape();
|
||||
let indices = TestTensorInt::ones(shape, &Default::default());
|
||||
tensor.gather(0, indices)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Select,
|
||||
ops_float: |tensor: TestTensor<2>| {
|
||||
let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());
|
||||
tensor.select(0, indices)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MaskFill,
|
||||
ops_float: |tensor: TestTensor<2>| {
|
||||
let mask = tensor.clone().greater_elem(0.5);
|
||||
tensor.mask_fill(mask, 77.0)
|
||||
}
|
||||
);
|
||||
|
||||
// Activation
|
||||
clone_invariance_test!(
|
||||
unary: Softmax,
|
||||
ops_float: |tensor: TestTensor<2>| softmax(tensor, 1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: LogSoftmax,
|
||||
ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Sigmoid,
|
||||
ops_float: |tensor: TestTensor<2>| sigmoid(tensor)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: LogSigmoid,
|
||||
ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Relu,
|
||||
ops_float: |tensor: TestTensor<2>| relu(tensor)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Gelu,
|
||||
ops_float: |tensor: TestTensor<2>| gelu(tensor)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Mish,
|
||||
ops_float: |tensor: TestTensor<2>| mish(tensor)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Silu,
|
||||
ops_float: |tensor: TestTensor<2>| silu(tensor)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Softplus,
|
||||
ops_float: |tensor: TestTensor<2>| softplus(tensor, 1.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Tanh,
|
||||
ops_float: |tensor: TestTensor<2>| tanh(tensor)
|
||||
);
|
||||
|
||||
// Binary ops
|
||||
clone_invariance_test!(
|
||||
binary: Add,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Sub,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Div,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Mul,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Matmul,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Equal,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Greater,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: GreaterEqual,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Lower,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: LowerEqual,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Cat,
|
||||
ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| {
|
||||
let lhs = lhs.reshape([1usize, 32, 32]);
|
||||
let rhs = rhs.reshape([1usize, 32, 32]);
|
||||
|
||||
TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32])
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Scatter,
|
||||
ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {
|
||||
let shape = tensor.shape();
|
||||
let indices = TestTensorInt::ones(shape, &Default::default());
|
||||
tensor.scatter(0, indices, values, IndexingUpdateOp::Add)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: SliceAssign,
|
||||
ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {
|
||||
tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12]))
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: MaskWhere,
|
||||
ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {
|
||||
let mask = tensor.clone().greater_elem(0.5);
|
||||
tensor.mask_where(mask, values)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: SelectAssign,
|
||||
ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {
|
||||
let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());
|
||||
let values = values.select(0, indices.clone());
|
||||
tensor.select_assign(0, indices, values, IndexingUpdateOp::Add)
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
mod int {
|
||||
use super::*;
|
||||
|
||||
// Unary ops
|
||||
clone_invariance_test!(
|
||||
unary: AddScalar,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: SubScalar,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: DivScalar,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MulScalar,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Neg,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.neg()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MeanDim,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.mean_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: SumDim,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Sum,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Mean,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Clamp,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: ClampMin,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: ClampMax,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Abs,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.abs()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: SwapDims,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Transpose,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.transpose()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Slice,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24])
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: EqualElem,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: NotEqualElem,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.not_equal_elem(25)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: GreaterElem,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: GreaterEqualElem,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: LowerElem,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: LowerEqualElem,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Argmax,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Argmin,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Max,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Min,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze::<2>()
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MaxDim,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MaxDimWithIndices,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MinDimWithIndices,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MinDim,
|
||||
ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Repeat,
|
||||
ops_int: |tensor: TestTensorInt<2>| {
|
||||
tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Reshape,
|
||||
ops_int: |tensor: TestTensorInt<2>| {
|
||||
let shape = tensor.shape();
|
||||
let new_shape = [shape.num_elements(), 1];
|
||||
tensor.reshape(new_shape)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Gatter,
|
||||
ops_int: |tensor: TestTensorInt<2>| {
|
||||
let shape = tensor.shape();
|
||||
let indices = TestTensorInt::ones(shape, &Default::default());
|
||||
tensor.gather(0, indices)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: Select,
|
||||
ops_int: |tensor: TestTensorInt<2>| {
|
||||
let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());
|
||||
tensor.select(0, indices)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
unary: MaskFill,
|
||||
ops_int: |tensor: TestTensorInt<2>| {
|
||||
let mask = tensor.clone().greater_elem(0.5);
|
||||
tensor.mask_fill(mask, 77.0)
|
||||
}
|
||||
);
|
||||
|
||||
// Binary ops
|
||||
clone_invariance_test!(
|
||||
binary: Add,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.add(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Sub,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Div,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Mul,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Equal,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: NotEqual,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.not_equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Greater,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: GreaterEqual,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Lower,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: LowerEqual,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs)
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Cat,
|
||||
ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| {
|
||||
let lhs = lhs.reshape([1usize, 32, 32]);
|
||||
let rhs = rhs.reshape([1usize, 32, 32]);
|
||||
|
||||
TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32])
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: Scatter,
|
||||
ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {
|
||||
let shape = tensor.shape();
|
||||
let indices = TestTensorInt::ones(shape, &Default::default());
|
||||
tensor.scatter(0, indices, values, IndexingUpdateOp::Add)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: SliceAssign,
|
||||
ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {
|
||||
tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12]))
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: MaskWhere,
|
||||
ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {
|
||||
let mask = tensor.clone().greater_elem(0.5);
|
||||
tensor.mask_where(mask, values)
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
binary: SelectAssign,
|
||||
ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {
|
||||
let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());
|
||||
let values = values.select(0, indices.clone());
|
||||
tensor.select_assign(0, indices, values, IndexingUpdateOp::Add)
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_celu_d2() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [-3.0, 0.5]]);
|
||||
|
||||
let output = activation::celu(tensor, 1.0);
|
||||
// celu(1, 1) = 1
|
||||
// celu(7, 1) = 7
|
||||
// celu(-3, 1) = 1 * (exp(-3) - 1) = -0.950213
|
||||
// celu(0.5, 1) = 0.5
|
||||
let expected = TensorData::from([[1.0, 7.0], [-0.950213, 0.5]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_celu_with_alpha() {
|
||||
let tensor = TestTensor::<1>::from([0.0, -1.0, -2.0]);
|
||||
|
||||
let output = activation::celu(tensor, 2.0);
|
||||
// celu(0, 2) = 0
|
||||
// celu(-1, 2) = 2 * (exp(-0.5) - 1) = -0.786939
|
||||
// celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241
|
||||
let expected = TensorData::from([0.0, -0.786939, -1.264241]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_elu() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::elu(tensor, 1.0);
|
||||
// elu(1, 1) = 1, elu(7, 1) = 7, elu(13, 1) = 13
|
||||
// elu(-3, 1) = 1 * (exp(-3) - 1) = -0.950213
|
||||
let expected = TensorData::from([[1.0, 7.0], [13.0, -0.950213]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elu_alpha() {
|
||||
let tensor = TestTensor::<1>::from([0.0, -1.0, -2.0]);
|
||||
|
||||
let output = activation::elu(tensor, 2.0);
|
||||
// elu(0, 2) = 2*(exp(0)-1) = 0
|
||||
// elu(-1, 2) = 2*(exp(-1)-1) = 2*(-0.632121) = -1.264241
|
||||
// elu(-2, 2) = 2*(exp(-2)-1) = 2*(-0.864665) = -1.729329
|
||||
let expected = TensorData::from([0.0, -1.264241, -1.729329]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_gelu() {
|
||||
let tensor = TestTensor::<2>::from([[
|
||||
0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
|
||||
]]);
|
||||
let output = activation::gelu(tensor);
|
||||
let expected = TensorData::from([[
|
||||
0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
|
||||
]]);
|
||||
|
||||
// Low precision to allow approximation implementation using tanh
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected,
|
||||
Tolerance::default().set_half_precision_absolute(2e-3),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_glu_d3() {
|
||||
let tensor = TestTensor::<3>::from([[
|
||||
[
|
||||
-0.5710, -1.3416, 1.9128, -0.8257, -0.1331, -1.4804, -0.6281, -0.6115,
|
||||
],
|
||||
[
|
||||
0.0267, -1.3834, 0.2752, 0.7844, -0.3549, -0.4274, 0.3290, -0.5459,
|
||||
],
|
||||
[
|
||||
-1.6347, -2.0908, 1.8801, 0.3541, 0.2237, 1.0377, 2.4850, 0.3490,
|
||||
],
|
||||
]]);
|
||||
|
||||
let output = activation::glu(tensor, 2);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[
|
||||
[-0.2665, -0.2487, 0.6656, -0.2904],
|
||||
[0.0110, -0.5461, 0.1601, 0.2877],
|
||||
[-0.9084, -1.5439, 1.7355, 0.2077],
|
||||
]]),
|
||||
Default::default(),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_hard_sigmoid() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::hard_sigmoid(tensor, 0.2, 0.5);
|
||||
let expected = TensorData::from([[0.7, 1.0], [1.0, 0.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hard_sigmoid_overflow() {
|
||||
let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]);
|
||||
|
||||
let output = activation::hard_sigmoid(tensor, 0.2, 0.5);
|
||||
let expected = TensorData::from([1.0, 0.0]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_leaky_relu_d2() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);
|
||||
|
||||
let output = activation::leaky_relu(tensor, 0.01);
|
||||
|
||||
// Account for conversion errors if `FloatType != f32`
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]),
|
||||
Tolerance::default(),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{ElementConversion, TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_log_sigmoid() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::log_sigmoid(tensor);
|
||||
let expected = TensorData::from([[-3.132617e-1, -9.114665e-4], [-2.260327e-6, -3.0485873]]);
|
||||
|
||||
let tolerance = Tolerance::rel_abs(0.01, 0.0001);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log_sigmoid_numerical_stability() {
|
||||
let tensor = TestTensor::<1>::from([300.0, -300.0]);
|
||||
|
||||
let output = activation::log_sigmoid(tensor);
|
||||
|
||||
// For large negative values, the previous implementation −log(1 + exp(−x)) would give -inf
|
||||
let expected = TensorData::from([0.0, -300.0]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]);
|
||||
let output = activation::log_sigmoid(tensor);
|
||||
let expected = TensorData::from([0.elem(), FloatElem::MIN]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_mish() {
|
||||
let tensor = TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);
|
||||
|
||||
let output = activation::mish(tensor);
|
||||
let expected = TensorData::from([
|
||||
[-0.19709, -0.30056, -0.11714],
|
||||
[-0.24132, 0.58235, -0.08877],
|
||||
]);
|
||||
|
||||
// Metal has less precise trigonometric functions (tanh inside mish)
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(1e-2);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
use super::*;
|
||||
|
||||
mod celu;
|
||||
mod elu;
|
||||
mod gelu;
|
||||
mod glu;
|
||||
mod hard_sigmoid;
|
||||
mod leaky_relu;
|
||||
mod log_sigmoid;
|
||||
mod mish;
|
||||
mod prelu;
|
||||
mod quiet_softmax;
|
||||
mod relu;
|
||||
mod selu;
|
||||
mod sigmoid;
|
||||
mod silu;
|
||||
mod softmax;
|
||||
mod softmin;
|
||||
mod softplus;
|
||||
mod softsign;
|
||||
mod tanh_activation;
|
||||
mod thresholded_relu;
|
||||
@@ -0,0 +1,101 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_prelu_2_dimension() {
|
||||
let data = [
|
||||
[-1.1, 0.0, 1.2, 0.25, -5.4],
|
||||
[-4.567, 0.56, -1.55, 99.9, 0.0],
|
||||
];
|
||||
let tensor = TestTensor::<2>::from(data);
|
||||
let output = activation::prelu(tensor, TestTensor::from([0.5, 0.25, 0.0, -0.8, -0.4]));
|
||||
let expected = TensorData::from([
|
||||
[-0.5500, 0.0000, 1.2000, 0.2500, 2.1600],
|
||||
[-2.2835, 0.5600, -0.0000, 99.9000, -0.0000],
|
||||
]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
#[test]
|
||||
fn test_prelu_2_dimension_scalar_weight() {
|
||||
let data = [
|
||||
[-1.1, 0.0, 1.2, 0.25, -5.4],
|
||||
[-4.567, 0.56, -1.55, 99.9, 0.0],
|
||||
];
|
||||
let tensor = TestTensor::<2>::from(data);
|
||||
let output = activation::prelu(tensor, TestTensor::from([-0.8]));
|
||||
let expected = TensorData::from([
|
||||
[0.8800, -0.0000, 1.2000, 0.2500, 4.3200],
|
||||
[3.6536, 0.5600, 1.2400, 99.9000, -0.0000],
|
||||
]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prelu_positives() {
|
||||
// Check that positives are untouched
|
||||
let data = [[
|
||||
0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
|
||||
]];
|
||||
let tensor = TestTensor::<2>::from(data);
|
||||
let output = activation::prelu(tensor, TestTensor::from([0.25]));
|
||||
let expected = TensorData::from(data);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prelu_zero_weight() {
|
||||
// test that with weight 0 it behaves as relu
|
||||
let data = [-1.1, 0.0, 1.2, 0.25, -5.4];
|
||||
let tensor = TestTensor::<1>::from(data);
|
||||
let output = activation::prelu(tensor, TestTensor::from([0.0]));
|
||||
let expected = TensorData::from([0.0, 0.0, 1.2, 0.25, 0.0]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prelu_some_weight() {
|
||||
// test that with some non zero weight it works like leaky relu
|
||||
let data = [-1.1, 0.0, 1.2, 0.25, -5.4];
|
||||
let tensor = TestTensor::<1>::from(data);
|
||||
let output = activation::prelu(tensor, TestTensor::from([0.5]));
|
||||
let expected = TensorData::from([-0.550, 0.0, 1.20, 0.250, -2.70]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_prelu_single_dim_multi_weight() {
|
||||
// should panic because the data has only 1 channel
|
||||
let data = [-1.1, 2.0, 1.2, 0.25, -5.4];
|
||||
let tensor = TestTensor::<1>::from(data);
|
||||
let data_actual =
|
||||
activation::prelu(tensor, TestTensor::from([0.5, -0.25, 0.0, 0.5, -1.0])).into_data();
|
||||
let data_expected = TensorData::from([-0.550, 0.0, 1.20, 0.250, -2.70]);
|
||||
data_expected.assert_approx_eq::<FloatElem>(&data_actual, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_prelu_multi_dim_wrong_weights() {
|
||||
let data = [
|
||||
[-1.1, 0.0, 1.2, 0.25, -5.4],
|
||||
[-4.567, 0.56, -1.55, 99.9, 0.0],
|
||||
];
|
||||
let tensor = TestTensor::<2>::from(data);
|
||||
let _ = activation::prelu(tensor, TestTensor::from([-0.8, 0.1]));
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_quiet_softmax_d2() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::quiet_softmax(tensor, 1);
|
||||
let expected = TensorData::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_relu_d2() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);
|
||||
|
||||
let output = activation::relu(tensor);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_selu() {
|
||||
// selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0
|
||||
// alpha = 1.6733, gamma = 1.0507
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, -1.0], [2.0, -2.0, 0.5]]);
|
||||
|
||||
let output = activation::selu(tensor);
|
||||
|
||||
// Expected values computed from the formula:
|
||||
// selu(0.0) = 1.0507 * 1.6733 * (exp(0) - 1) = 0.0
|
||||
// selu(1.0) = 1.0507 * 1.0 = 1.0507
|
||||
// selu(-1.0) = 1.0507 * 1.6733 * (exp(-1) - 1) = 1.7581 * (0.3679 - 1) = -1.1113
|
||||
// selu(2.0) = 1.0507 * 2.0 = 2.1014
|
||||
// selu(-2.0) = 1.0507 * 1.6733 * (exp(-2) - 1) = 1.7581 * (0.1353 - 1) = -1.5202
|
||||
// selu(0.5) = 1.0507 * 0.5 = 0.5254
|
||||
let expected = TensorData::from([[0.0, 1.0507, -1.1113], [2.1014, -1.5202, 0.5254]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_selu_zero() {
|
||||
let tensor = TestTensor::<1>::from([0.0]);
|
||||
|
||||
let output = activation::selu(tensor);
|
||||
let expected = TensorData::from([0.0]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_sigmoid() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::sigmoid(tensor);
|
||||
let expected = TensorData::from([[0.731059, 0.999089], [0.999998, 0.047426]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sigmoid_overflow() {
|
||||
let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]);
|
||||
|
||||
let output = activation::sigmoid(tensor);
|
||||
let expected = TensorData::from([1.0, 0.0]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_silu() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
let output = activation::silu(tensor);
|
||||
let expected = TensorData::from([[0.73106, 1.76159], [2.85772, 3.92806]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_softmax_d2() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::softmax(tensor, 1);
|
||||
let expected = TensorData::from([[2.472623e-03, 9.975274e-01], [1.0, 1.125352e-07]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_softmin_d2() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::softmin(tensor, 1);
|
||||
let expected = TensorData::from([[9.975274e-01, 2.472623e-03], [1.125352e-07, 1.0000]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_softplus_d2() {
|
||||
let tensor = TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);
|
||||
|
||||
let output = activation::softplus(tensor.clone(), 1.0);
|
||||
let expected = TensorData::from([
|
||||
[0.503453, 0.324898, 0.588517],
|
||||
[0.445806, 1.117805, 0.615424],
|
||||
]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let output = activation::softplus(tensor, 2.0);
|
||||
let expected = TensorData::from([
|
||||
[0.178232, 0.068737, 0.247990],
|
||||
[0.137132, 0.827771, 0.272106],
|
||||
]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_softsign() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
|
||||
let output = activation::softsign(tensor);
|
||||
let expected = TensorData::from([[0.5, 0.875], [0.928571, -0.75]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softsign_zero() {
|
||||
let tensor = TestTensor::<1>::from([0.0]);
|
||||
|
||||
let output = activation::softsign(tensor);
|
||||
let expected = TensorData::from([0.0]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_tanh() {
|
||||
let tensor = TestTensor::<2>::from([[1., 2.], [3., 4.]]);
|
||||
|
||||
let output = activation::tanh(tensor);
|
||||
let expected = TensorData::from([[0.761594, 0.964028], [0.995055, 0.999329]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, activation};
|
||||
|
||||
#[test]
|
||||
fn test_thresholded_relu_d2() {
|
||||
// alpha = 1.0 (ONNX default): x if x > 1.0, else 0
|
||||
let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 1.0, 0.5]]);
|
||||
|
||||
let output = activation::thresholded_relu(tensor, 1.0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 0.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thresholded_relu_d2_alpha() {
|
||||
// alpha = 0.5: x if x > 0.5, else 0
|
||||
let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 0.5, 0.6]]);
|
||||
|
||||
let output = activation::thresholded_relu(tensor, 0.5);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 0.6]]), false);
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::*;
|
||||
use burn_tensor::grid::affine_grid_2d;
|
||||
|
||||
fn create_identity_transform(batch_size: usize) -> TestTensor<3> {
|
||||
// Identity affine transform (batch_size, 2, 3)
|
||||
TestTensor::<3>::from([[[1f32, 0., 0.], [0., 1., 0.]]]).expand([batch_size, 2, 3])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_affine_grid_identity() {
|
||||
let batch_size = 1;
|
||||
let channels = 1;
|
||||
let height = 2;
|
||||
let width = 2;
|
||||
|
||||
let transform = create_identity_transform(batch_size);
|
||||
|
||||
let output = affine_grid_2d(transform, [batch_size, channels, height, width]);
|
||||
|
||||
// Expected normalized coords:
|
||||
// [-1, -1], [ 1,-1]
|
||||
// [-1, 1], [ 1, 1]
|
||||
let expected = TestTensor::<4>::from([[
|
||||
[[-1f32, -1f32], [1f32, -1f32]],
|
||||
[[-1f32, 1f32], [1f32, 1f32]],
|
||||
]]);
|
||||
|
||||
output.into_data().assert_eq(&expected.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_affine_grid_scaling() {
|
||||
let batch_size = 1;
|
||||
let channels = 1;
|
||||
let height = 2;
|
||||
let width = 2;
|
||||
|
||||
let scale = 2.0f32;
|
||||
let transform = TestTensor::<3>::from([[[scale, 0., 0.], [0., scale, 0.]]]);
|
||||
|
||||
let output = affine_grid_2d(transform, [batch_size, channels, height, width]);
|
||||
|
||||
// Expect scaled coordinates from normalized grid, so coords * 2
|
||||
let expected = TestTensor::<4>::from([[
|
||||
[[-2f32, -2f32], [2f32, -2f32]],
|
||||
[[-2f32, 2f32], [2f32, 2f32]],
|
||||
]]);
|
||||
|
||||
output.into_data().assert_eq(&expected.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_affine_grid_translation() {
|
||||
let batch_size = 1;
|
||||
let channels = 1;
|
||||
let height = 2;
|
||||
let width = 2;
|
||||
|
||||
// Translate by 0.5 in x and -0.5 in y (normalized coords)
|
||||
let tx = 0.5f32;
|
||||
let ty = -0.5f32;
|
||||
|
||||
let transform = TestTensor::<3>::from([[[1.0, 0.0, tx], [0.0, 1.0, ty]]]);
|
||||
|
||||
let output = affine_grid_2d(transform, [batch_size, channels, height, width]);
|
||||
|
||||
// Expected coordinates:
|
||||
// Original normalized coords are [-1,1] in x and y
|
||||
// After translation, each coordinate shifts by tx and ty
|
||||
// So points become:
|
||||
// [-1 + 0.5, -1 - 0.5] = [-0.5, -1.5]
|
||||
// [ 1 + 0.5, -1 - 0.5] = [1.5, -1.5]
|
||||
// [-1 + 0.5, 1 - 0.5] = [-0.5, 0.5]
|
||||
// [ 1 + 0.5, 1 - 0.5] = [1.5, 0.5]
|
||||
|
||||
let expected = TestTensor::<4>::from([[
|
||||
[[-0.5f32, -1.5f32], [1.5f32, -1.5f32]],
|
||||
[[-0.5f32, 0.5f32], [1.5f32, 0.5f32]],
|
||||
]]);
|
||||
|
||||
output.into_data().assert_eq(&expected.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
use super::*;
|
||||
use burn_tensor::BasicOps;
|
||||
use burn_tensor::Tensor;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::grid::{
|
||||
GridIndexing, GridOptions, GridSparsity, IndexPos, meshgrid, meshgrid_stack,
|
||||
};
|
||||
|
||||
fn assert_tensors_equal<const N: usize, B: Backend, K>(
|
||||
actual: &[Tensor<B, N, K>; N],
|
||||
expected: &[Tensor<B, N, K>; N],
|
||||
) where
|
||||
K: BasicOps<B>,
|
||||
{
|
||||
for (a, e) in actual.iter().zip(expected.iter()) {
|
||||
a.clone()
|
||||
.into_data()
|
||||
.assert_eq(&e.clone().into_data(), true);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meshgrid() {
|
||||
let x = TestTensor::<1>::from([1, 2, 3, 4]);
|
||||
let y = TestTensor::<1>::from([5, 6]);
|
||||
let z = TestTensor::<1>::from([7, 8]);
|
||||
|
||||
let grid_shape = [x.dims()[0], y.dims()[0], z.dims()[0]];
|
||||
|
||||
// 3D, Dense, Matrix
|
||||
assert_tensors_equal(
|
||||
&meshgrid(&[x.clone(), y.clone(), z.clone()], GridOptions::default()),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]).expand(grid_shape),
|
||||
y.clone().reshape([1, 2, 1]).expand(grid_shape),
|
||||
z.clone().reshape([1, 1, 2]).expand(grid_shape),
|
||||
],
|
||||
);
|
||||
assert_tensors_equal(
|
||||
&meshgrid(&[x.clone(), y.clone(), z.clone()], GridSparsity::Dense),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]).expand(grid_shape),
|
||||
y.clone().reshape([1, 2, 1]).expand(grid_shape),
|
||||
z.clone().reshape([1, 1, 2]).expand(grid_shape),
|
||||
],
|
||||
);
|
||||
assert_tensors_equal(
|
||||
&meshgrid(&[x.clone(), y.clone(), z.clone()], GridIndexing::Matrix),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]).expand(grid_shape),
|
||||
y.clone().reshape([1, 2, 1]).expand(grid_shape),
|
||||
z.clone().reshape([1, 1, 2]).expand(grid_shape),
|
||||
],
|
||||
);
|
||||
|
||||
// 3D, Sparse, Matrix
|
||||
assert_tensors_equal(
|
||||
&meshgrid(
|
||||
&[x.clone(), y.clone(), z.clone()],
|
||||
GridOptions {
|
||||
indexing: GridIndexing::Matrix,
|
||||
sparsity: GridSparsity::Sparse,
|
||||
},
|
||||
),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]),
|
||||
y.clone().reshape([1, 2, 1]),
|
||||
z.clone().reshape([1, 1, 2]),
|
||||
],
|
||||
);
|
||||
assert_tensors_equal(
|
||||
&meshgrid(&[x.clone(), y.clone(), z.clone()], GridSparsity::Sparse),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]),
|
||||
y.clone().reshape([1, 2, 1]),
|
||||
z.clone().reshape([1, 1, 2]),
|
||||
],
|
||||
);
|
||||
|
||||
// 3D, Dense, Cartesian
|
||||
assert_tensors_equal(
|
||||
&meshgrid(&[x.clone(), y.clone(), z.clone()], GridIndexing::Cartesian),
|
||||
&[
|
||||
x.clone()
|
||||
.reshape([4, 1, 1])
|
||||
.expand(grid_shape)
|
||||
.swap_dims(0, 1),
|
||||
y.clone()
|
||||
.reshape([1, 2, 1])
|
||||
.expand(grid_shape)
|
||||
.swap_dims(0, 1),
|
||||
z.clone()
|
||||
.reshape([1, 1, 2])
|
||||
.expand(grid_shape)
|
||||
.swap_dims(0, 1),
|
||||
],
|
||||
);
|
||||
|
||||
// 3D, Sparse, Cartesian
|
||||
assert_tensors_equal(
|
||||
&meshgrid(
|
||||
&[x.clone(), y.clone(), z.clone()],
|
||||
GridOptions::new(GridIndexing::Cartesian, GridSparsity::Sparse),
|
||||
),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]).swap_dims(0, 1),
|
||||
y.clone().reshape([1, 2, 1]).swap_dims(0, 1),
|
||||
z.clone().reshape([1, 1, 2]).swap_dims(0, 1),
|
||||
],
|
||||
);
|
||||
assert_tensors_equal(
|
||||
&meshgrid(
|
||||
&[x.clone(), y.clone(), z.clone()],
|
||||
GridOptions {
|
||||
indexing: GridIndexing::Cartesian,
|
||||
sparsity: GridSparsity::Sparse,
|
||||
},
|
||||
),
|
||||
&[
|
||||
x.clone().reshape([4, 1, 1]).swap_dims(0, 1),
|
||||
y.clone().reshape([1, 2, 1]).swap_dims(0, 1),
|
||||
z.clone().reshape([1, 1, 2]).swap_dims(0, 1),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meshgrid_stack() {
|
||||
let tensors = [
|
||||
TestTensor::from([0.5, 1.0, 2.5]),
|
||||
TestTensor::from([0.5, 1.0]),
|
||||
];
|
||||
|
||||
let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::First);
|
||||
result.to_data().assert_eq(
|
||||
&TensorData::from([
|
||||
[[0.5, 0.5], [1.0, 1.0], [2.5, 2.5]],
|
||||
[[0.5, 1.0], [0.5, 1.0], [0.5, 1.0]],
|
||||
]),
|
||||
false,
|
||||
);
|
||||
|
||||
let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::Last);
|
||||
result.to_data().assert_eq(
|
||||
&TensorData::from([
|
||||
[[0.5, 0.5], [0.5, 1.0]],
|
||||
[[1.0, 0.5], [1.0, 1.0]],
|
||||
[[2.5, 0.5], [2.5, 1.0]],
|
||||
]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
use super::*;
|
||||
|
||||
pub(crate) mod affine_grid;
|
||||
pub(crate) mod meshgrid;
|
||||
@@ -0,0 +1,100 @@
|
||||
use super::*;
|
||||
use burn_tensor::{ElementConversion, Tolerance};
|
||||
use burn_tensor::{TensorData, linalg};
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_basic() {
|
||||
// Create test tensors
|
||||
let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]]);
|
||||
let x2 = TestTensor::<2>::from([[1.5, 2.5, 3.5], [0.7, 1.7, 2.7]]);
|
||||
|
||||
// Test cosine similarity along dimension 1
|
||||
let expected = TensorData::from([[0.99983203], [0.99987257]]);
|
||||
linalg::cosine_similarity(x1.clone(), x2.clone(), 1, None)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
// Test with explicit epsilon
|
||||
linalg::cosine_similarity(x1.clone(), x2.clone(), 1, Some(1e-8.elem::<FloatElem>()))
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_orthogonal() {
|
||||
// Create orthogonal vectors
|
||||
let x1 = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);
|
||||
let x2 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
|
||||
|
||||
// Orthogonal vectors should have cosine similarity of 0
|
||||
let expected = TensorData::from([[0.0], [0.0]]);
|
||||
linalg::cosine_similarity(x1, x2, 1, None)
|
||||
.into_data()
|
||||
.assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_parallel() {
|
||||
// Create parallel vectors
|
||||
let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
let x2 = TestTensor::<2>::from([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]]);
|
||||
|
||||
// Parallel vectors should have cosine similarity of 1
|
||||
let expected = TensorData::from([[1.0], [1.0]]);
|
||||
linalg::cosine_similarity(x1, x2, 1, None)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_opposite() {
|
||||
// Create opposite direction vectors
|
||||
let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
let x2 = TestTensor::<2>::from([[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]]);
|
||||
|
||||
// Opposite vectors should have cosine similarity of -1
|
||||
let expected = TensorData::from([[-1.0], [-1.0]]);
|
||||
linalg::cosine_similarity(x1, x2, 1, None)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_different_dimension() {
|
||||
// Test with a 3D tensor
|
||||
let x1 = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
|
||||
let x2 = TestTensor::<3>::from([[[2.0, 3.0], [4.0, 5.0]], [[6.0, 7.0], [8.0, 9.0]]]);
|
||||
|
||||
// Test along dimension 2
|
||||
let expected = TensorData::from([[[0.9959688], [0.9958376]], [[0.9955946], [0.9955169]]]);
|
||||
|
||||
// sensitive to rounding in dot/norm; loosen f16 tolerance
|
||||
let tolerance = Tolerance::default().set_half_precision_relative(7e-3);
|
||||
|
||||
linalg::cosine_similarity(x1.clone(), x2.clone(), 2, None)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
// Test with negative dimension (-1 is the last dimension, which is 2 in this case)
|
||||
linalg::cosine_similarity(x1.clone(), x2.clone(), -1, None)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity_near_zero() {
|
||||
// Test with near-zero vectors
|
||||
let x1 = TestTensor::<2>::from([[1e-10, 2e-10, 3e-10], [4e-10, 5e-10, 6e-10]]);
|
||||
let x2 = TestTensor::<2>::from([[2e-10, 4e-10, 6e-10], [8e-10, 10e-10, 12e-10]]);
|
||||
|
||||
// Update the expected values based on the actual implementation behavior
|
||||
let expected = TensorData::from([[0.0028], [0.0154]]);
|
||||
|
||||
// Smaller values result in NaN on metal f16
|
||||
let epsilon = Some(FloatElem::from_elem(1e-2));
|
||||
let tolerance = Tolerance::absolute(0.2);
|
||||
|
||||
linalg::cosine_similarity(x1, x2, 1, epsilon)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, linalg::diag};
|
||||
|
||||
#[test]
|
||||
fn test_diag_2d_square() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
let expected = TensorData::from([1.0, 4.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_2d_tall() {
|
||||
let device = Default::default();
|
||||
// 4x2 matrix (tall) - min(4,2) = 2 diagonal elements
|
||||
let tensor =
|
||||
TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// Result should have shape [2] with values [1.0, 4.0]
|
||||
let expected = TensorData::from([1.0, 4.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_2d_wide() {
|
||||
let device = Default::default();
|
||||
// 2x4 matrix (wide) - min(2,4) = 2 diagonal elements
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// Result should have shape [2] with values [1.0, 6.0]
|
||||
let expected = TensorData::from([1.0, 6.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_3d_batch_square() {
|
||||
let device = Default::default();
|
||||
// Batch of 2 matrices, each 2x2
|
||||
let tensor = TestTensor::<3>::from_data(
|
||||
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 3, 2, _>(tensor);
|
||||
// Result should have shape [2, 2]
|
||||
let expected = TensorData::from([[1.0, 4.0], [5.0, 8.0]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_3d_batch_tall() {
|
||||
let device = Default::default();
|
||||
// Batch of 2 matrices, each 3x2 (tall)
|
||||
let tensor = TestTensor::<3>::from_data(
|
||||
[
|
||||
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
||||
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 3, 2, _>(tensor);
|
||||
// Result should have shape [2, 2] - min(3,2) = 2 diagonal elements each
|
||||
let expected = TensorData::from([[1.0, 4.0], [7.0, 10.0]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_3d_batch_wide() {
|
||||
let device = Default::default();
|
||||
// Batch of 2 matrices, each 2x3 (wide)
|
||||
let tensor = TestTensor::<3>::from_data(
|
||||
[
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
|
||||
[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 3, 2, _>(tensor);
|
||||
// Result should have shape [2, 2] - min(2,3) = 2 diagonal elements each
|
||||
let expected = TensorData::from([[1.0, 5.0], [7.0, 11.0]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_4d_batch_channel_square() {
|
||||
let device = Default::default();
|
||||
// [batch=2, channel=2, rows=2, cols=2]
|
||||
let tensor = TestTensor::<4>::from_data(
|
||||
[
|
||||
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
|
||||
[[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 4, 3, _>(tensor);
|
||||
// Result should have shape [2, 2, 2]
|
||||
let expected = TensorData::from([[[1.0, 4.0], [5.0, 8.0]], [[9.0, 12.0], [13.0, 16.0]]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_4d_batch_channel_tall() {
|
||||
let device = Default::default();
|
||||
// [batch=2, channel=1, rows=3, cols=2]
|
||||
let tensor = TestTensor::<4>::from_data(
|
||||
[
|
||||
[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]],
|
||||
[[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 4, 3, _>(tensor);
|
||||
// Result should have shape [2, 1, 2] - min(3,2) = 2 diagonal elements each
|
||||
let expected = TensorData::from([[[1.0, 4.0]], [[7.0, 10.0]]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_4d_batch_channel_wide() {
|
||||
let device = Default::default();
|
||||
// [batch=1, channel=2, rows=2, cols=4]
|
||||
let tensor = TestTensor::<4>::from_data(
|
||||
[[
|
||||
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
|
||||
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
|
||||
]],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 4, 3, _>(tensor);
|
||||
// Result should have shape [1, 2, 2] - min(2,4) = 2 diagonal elements each
|
||||
let expected = TensorData::from([[[1.0, 6.0], [9.0, 14.0]]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_1x1() {
|
||||
let device = Default::default();
|
||||
// Single element matrix
|
||||
let tensor = TestTensor::<2>::from_data([[5.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// Should return [5.0] with shape [1]
|
||||
let expected = TensorData::from([5.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_single_row() {
|
||||
let device = Default::default();
|
||||
// Single row matrix
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// min(1,3) = 1, should return [1.0] with shape [1]
|
||||
let expected = TensorData::from([1.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_single_column() {
|
||||
let device = Default::default();
|
||||
// Single column matrix
|
||||
let tensor = TestTensor::<2>::from_data([[1.0], [2.0], [3.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// min(3,1) = 1, should return [1.0] with shape [1]
|
||||
let expected = TensorData::from([1.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_zeros() {
|
||||
let device = Default::default();
|
||||
// Matrix with zeros on diagonal
|
||||
let tensor = TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 0.0]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// Should extract diagonal: [0.0, 0.0]
|
||||
let expected = TensorData::from([0.0, 0.0]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_batch_single_element() {
|
||||
let device = Default::default();
|
||||
// Batch with single element matrices
|
||||
let tensor = TestTensor::<3>::from_data([[[5.0]], [[7.0]]], &device);
|
||||
let result = diag::<_, 3, 2, _>(tensor);
|
||||
// Should return [[5.0], [7.0]] with shape [2, 1]
|
||||
let expected = TensorData::from([[5.0], [7.0]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_batch_mixed_zeros() {
|
||||
let device = Default::default();
|
||||
// Batch with mixed zero and non-zero diagonal elements
|
||||
let tensor = TestTensor::<3>::from_data(
|
||||
[[[1.0, 2.0], [3.0, 0.0]], [[0.0, 5.0], [6.0, 7.0]]],
|
||||
&device,
|
||||
);
|
||||
let result = diag::<_, 3, 2, _>(tensor);
|
||||
// Should return [[1.0, 0.0], [0.0, 7.0]] with shape [2, 2]
|
||||
let expected = TensorData::from([[1.0, 0.0], [0.0, 7.0]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_int_tensor() {
|
||||
let device = Default::default();
|
||||
// Test with integer tensor
|
||||
let tensor = TestTensorInt::<2>::from_data([[1, 2], [3, 4]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// Result should have shape [2] with values [1, 4]
|
||||
let expected = TensorData::from([1, 4]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diag_int_3x3() {
|
||||
let device = Default::default();
|
||||
// Test with 3x3 integer matrix
|
||||
let tensor = TestTensorInt::<2>::from_data([[1, 2, 3], [4, 5, 6], [7, 8, 9]], &device);
|
||||
let result = diag::<_, 2, 1, _>(tensor);
|
||||
// Result should have shape [3] with values [1, 5, 9]
|
||||
let expected = TensorData::from([1, 5, 9]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_diag_1d_should_panic() {
|
||||
let device = Default::default();
|
||||
// 1D tensor should panic - diagonal requires at least 2 dimensions
|
||||
let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device);
|
||||
let _result = diag::<_, 1, 0, _>(tensor);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_diag_wrong_output_rank_should_panic() {
|
||||
let device = Default::default();
|
||||
// Providing wrong output rank should panic
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
|
||||
let _result = diag::<_, 2, 2, _>(tensor); // Should be 2,1 not 2,2
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
use super::*;
|
||||
use burn_tensor::{
|
||||
Distribution, Shape, TensorData, Tolerance, cast::ToElement, linalg::lu_decomposition, s,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_lu_2x2_decomposition() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[4.0, 3.0], [6.0, 3.0]], &device);
|
||||
let (result, _permutations) = lu_decomposition(tensor);
|
||||
let expected = TensorData::from([[6.0, 3.0], [2.0 / 3.0, 1.0]]);
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_3x3_decomposition() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data(
|
||||
[[0.0, 5.0, 22.0 / 3.0], [4.0, 2.0, 1.0], [2.0, 7.0, 9.0]],
|
||||
&device,
|
||||
);
|
||||
let (result, permutations) = lu_decomposition(tensor);
|
||||
let expected = TestTensor::<2>::from_data(
|
||||
[
|
||||
[4.0, 2.0, 1.0],
|
||||
[0.5, 6.0, 8.5],
|
||||
[0.0, 0.8333333, 0.25000048],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let expected_permutations = TensorData::from([1, 2, 0]);
|
||||
permutations
|
||||
.into_data()
|
||||
.assert_eq(&expected_permutations, false);
|
||||
|
||||
let tolerance = Tolerance::default().set_half_precision_absolute(5e-3);
|
||||
result
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected.into_data(), tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_lu_singular_matrix() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [2.0, 4.0]], &device);
|
||||
let _result = lu_decomposition(tensor);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_lu_non_square_matrix() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
|
||||
let _result = lu_decomposition(tensor);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_1x1_element_matrix() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[5.0]], &device);
|
||||
let (result, _permutations) = lu_decomposition(tensor);
|
||||
let expected = TensorData::from([[5.0]]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_identity_matrix() {
|
||||
let device = Default::default();
|
||||
|
||||
let tensor = TestTensor::<2>::eye(4, &device);
|
||||
let (result, _permutations) = lu_decomposition(tensor);
|
||||
let expected = TestTensor::<2>::eye(4, &device);
|
||||
result.into_data().assert_eq(&expected.into_data(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_50x50_random_matrix() {
|
||||
let device = Default::default();
|
||||
let size = 50;
|
||||
let distribution = Distribution::Uniform(0.0, 1.0);
|
||||
let tensor = TestTensor::<2>::random(Shape::new([size, size]), distribution, &device);
|
||||
let (result, permutations) = lu_decomposition(tensor.clone());
|
||||
// Reconstruct the original matrix from L and U
|
||||
let mut l = TestTensor::<2>::eye(size, &device);
|
||||
let mut u = TestTensor::<2>::zeros(Shape::new([size, size]), &device);
|
||||
|
||||
for i in 0..size {
|
||||
for j in 0..size {
|
||||
if i > j {
|
||||
l = l.slice_assign(s![i, j], result.clone().slice(s![i, j]));
|
||||
} else {
|
||||
u = u.slice_assign(s![i, j], result.clone().slice(s![i, j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Construct the permutation matrix P from the permutation vector
|
||||
let mut p = TestTensor::<2>::zeros(Shape::new([size, size]), &device);
|
||||
for i in 0..size {
|
||||
let perm_index = permutations.clone().slice(s![i]).into_scalar().to_usize();
|
||||
p = p.slice_assign(
|
||||
s![perm_index, i],
|
||||
TestTensor::<2>::from_data([[1.0]], &device),
|
||||
);
|
||||
}
|
||||
|
||||
// Verify that P * L * U reconstructs the original matrix
|
||||
let reconstructed = p.matmul(l).matmul(u);
|
||||
reconstructed
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&tensor.into_data(), Tolerance::permissive());
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, Tolerance, linalg};
|
||||
|
||||
#[test]
|
||||
fn test_matvec_basic_float() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
|
||||
let vector = TestTensor::<1>::from_floats([5.0, 6.0], &device);
|
||||
|
||||
let result = linalg::matvec::<TestBackend, 2, 1, _>(matrix, vector);
|
||||
let expected = TensorData::from([17.0, 39.0]);
|
||||
|
||||
result
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matvec_basic_int() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensorInt::<2>::from_ints([[2, 0, -1], [1, 3, 2]], &device);
|
||||
let vector = TestTensorInt::<1>::from_ints([3, -2, 4], &device);
|
||||
|
||||
let result = linalg::matvec::<TestBackend, 2, 1, _>(matrix, vector);
|
||||
let expected = TensorData::from([2, 5]);
|
||||
|
||||
result.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matvec_batched() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensor::<3>::from_floats(
|
||||
[
|
||||
[[1.0, 0.0, 2.0], [3.0, 1.0, -1.0]],
|
||||
[[-2.0, 1.0, 0.0], [0.5, -1.5, 2.0]],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let vector = TestTensor::<2>::from_floats([[1.0, -1.0, 0.5], [2.0, 0.0, -1.0]], &device);
|
||||
|
||||
let result = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);
|
||||
let expected = TensorData::from([[2.0, 1.5], [-4.0, -1.0]]);
|
||||
|
||||
result
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matvec_vector_broadcasts_over_batches() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensor::<3>::from_floats(
|
||||
[
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
|
||||
[[-1.0, 0.0, 2.0], [3.0, 1.0, -2.0]],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let vector = TestTensor::<2>::from_floats([[1.0, 0.0, -1.0]], &device);
|
||||
|
||||
let result = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);
|
||||
let expected = TensorData::from([[-2.0, -2.0], [-3.0, 5.0]]);
|
||||
|
||||
result
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matvec_matrix_broadcasts_over_vector_batches() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensor::<3>::from_floats([[[1.0, 0.0, 2.0], [3.0, -1.0, 1.0]]], &device);
|
||||
let vector = TestTensor::<2>::from_floats([[2.0, 1.0, 0.0], [1.0, -1.0, 3.0]], &device);
|
||||
|
||||
let result = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);
|
||||
let expected = TensorData::from([[2.0, 5.0], [7.0, 7.0]]);
|
||||
|
||||
result
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_matvec_invalid_inner_dim_panics() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensor::<2>::zeros([2, 3], &device);
|
||||
let vector = TestTensor::<1>::zeros([4], &device);
|
||||
|
||||
let _ = linalg::matvec::<TestBackend, 2, 1, _>(matrix, vector);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_matvec_mismatched_batches_panics() {
|
||||
let device = Default::default();
|
||||
let matrix = TestTensor::<3>::zeros([2, 3, 4], &device);
|
||||
let vector = TestTensor::<2>::zeros([3, 4], &device);
|
||||
|
||||
let _ = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
use super::*;
|
||||
|
||||
pub(crate) mod cosine_similarity;
|
||||
pub(crate) mod diag;
|
||||
pub(crate) mod lu_decomposition;
|
||||
pub(crate) mod matvec;
|
||||
pub(crate) mod outer;
|
||||
pub(crate) mod trace;
|
||||
pub(crate) mod vector_norm;
|
||||
@@ -0,0 +1,262 @@
|
||||
use super::*;
|
||||
use burn_tensor::{ElementConversion, Tolerance};
|
||||
use burn_tensor::{TensorData, linalg};
|
||||
|
||||
// ---------- Vector (D=1, R=2) tests ----------
|
||||
|
||||
#[test]
|
||||
fn test_outer_basic() {
|
||||
let u = TestTensor::<1>::from([1.0, 2.0, 3.0]);
|
||||
let v = TestTensor::<1>::from([4.0, 5.0]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::from([[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_shapes_only() {
|
||||
let device = Default::default();
|
||||
let u = TestTensor::<1>::zeros([3], &device);
|
||||
let v = TestTensor::<1>::zeros([5], &device);
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v);
|
||||
assert_eq!(out.shape().dims(), [3, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_asymmetry_and_shapes() {
|
||||
let u = TestTensor::<1>::from([1.0, 2.0]);
|
||||
let v = TestTensor::<1>::from([3.0, 4.0, 5.0]);
|
||||
|
||||
let uv = linalg::outer::<TestBackend, 1, 2, _>(u.clone(), v.clone());
|
||||
let vu = linalg::outer::<TestBackend, 1, 2, _>(v, u);
|
||||
|
||||
assert_eq!(uv.shape().dims(), [2, 3]);
|
||||
assert_eq!(vu.shape().dims(), [3, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_zero_left() {
|
||||
let device = Default::default();
|
||||
let u = TestTensor::<1>::zeros([3], &device);
|
||||
let v = TestTensor::<1>::from([7.0, 8.0]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::zeros::<FloatElem, _>([3, 2]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_zero_right() {
|
||||
let device = Default::default();
|
||||
let u = TestTensor::<1>::from([1.0, -2.0, 3.0]);
|
||||
let v = TestTensor::<1>::zeros([4], &device);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::zeros::<FloatElem, _>([3, 4]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_signs() {
|
||||
let u = TestTensor::<1>::from([-1.0, 2.0]);
|
||||
let v = TestTensor::<1>::from([3.0, -4.0]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::from([[-3.0, 4.0], [6.0, -8.0]]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_integer_inputs() {
|
||||
let u = TestTensorInt::<1>::from([1, 2, 3]);
|
||||
let v = TestTensorInt::<1>::from([4, 5]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::from([[4, 5], [8, 10], [12, 15]]);
|
||||
|
||||
out.assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_equivalence_to_matmul() {
|
||||
let u = TestTensor::<1>::from([1.0, 2.0, 3.0]);
|
||||
let v = TestTensor::<1>::from([4.0, 5.0]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u.clone(), v.clone()).into_data();
|
||||
|
||||
let u2 = u.reshape([3, 1]);
|
||||
let v2 = v.reshape([1, 2]);
|
||||
let out_matmul = u2.matmul(v2).into_data();
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&out_matmul, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_vector_identity_right_mult() {
|
||||
let u = TestTensor::<1>::from([2.0, -1.0]);
|
||||
let v = TestTensor::<1>::from([3.0, 4.0]);
|
||||
let w = TestTensor::<1>::from([5.0, 6.0]);
|
||||
|
||||
let uv = linalg::outer::<TestBackend, 1, 2, _>(u.clone(), v.clone());
|
||||
let left = uv.matmul(w.clone().reshape([2, 1])).reshape([2]);
|
||||
|
||||
let v_dot_w = v.dot(w);
|
||||
let right = u * v_dot_w;
|
||||
|
||||
left.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&right.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_length_one_vectors() {
|
||||
let u = TestTensor::<1>::from([3.0]);
|
||||
let v = TestTensor::<1>::from([4.0, 5.0, 6.0]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::from([[12.0, 15.0, 18.0]]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_large_values() {
|
||||
let big = 1.0e10;
|
||||
let u = TestTensor::<1>::from([big, -big]);
|
||||
let v = TestTensor::<1>::from([big, big]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
let expected = TensorData::from([[big * big, big * big], [-big * big, -big * big]]);
|
||||
|
||||
let tol = Tolerance::relative(1e-6).set_half_precision_relative(1e-3);
|
||||
out.assert_approx_eq::<FloatElem>(&expected, tol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_nan_propagation() {
|
||||
let u = TestTensor::<1>::from([f32::NAN, 2.0]);
|
||||
let v = TestTensor::<1>::from([3.0, 4.0]);
|
||||
|
||||
let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();
|
||||
|
||||
let s: &[FloatElem] = out
|
||||
.as_slice::<FloatElem>()
|
||||
.expect("outer nan_propagation: as_slice failed");
|
||||
|
||||
assert!(s[0].is_nan());
|
||||
assert!(s[1].is_nan());
|
||||
assert_eq!(s[2], 6.0f32.elem::<FloatElem>());
|
||||
assert_eq!(s[3], 8.0f32.elem::<FloatElem>());
|
||||
}
|
||||
|
||||
// ---------- Batched (D=2, R=3) tests ----------
|
||||
|
||||
#[test]
|
||||
fn test_outer_batched_basic() {
|
||||
let x = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let y = TestTensor::<2>::from([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]);
|
||||
let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();
|
||||
|
||||
let expected = TensorData::from([
|
||||
[[5.0, 6.0, 7.0], [10.0, 12.0, 14.0]],
|
||||
[[24.0, 27.0, 30.0], [32.0, 36.0, 40.0]],
|
||||
]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_batched_shapes() {
|
||||
let device = Default::default();
|
||||
let x = TestTensor::<2>::zeros([3, 4], &device);
|
||||
let y = TestTensor::<2>::zeros([3, 5], &device);
|
||||
let out = linalg::outer::<TestBackend, 2, 3, _>(x, y);
|
||||
assert_eq!(out.shape().dims(), [3, 4, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_batched_zero_left() {
|
||||
let device = Default::default();
|
||||
let x = TestTensor::<2>::zeros([2, 3], &device);
|
||||
let y = TestTensor::<2>::from([[7.0, 8.0], [9.0, 10.0]]);
|
||||
let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();
|
||||
|
||||
let expected = TensorData::zeros::<FloatElem, _>([2, 3, 2]);
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_batched_zero_right() {
|
||||
let device = Default::default();
|
||||
let x = TestTensor::<2>::from([[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]]);
|
||||
let y = TestTensor::<2>::zeros([2, 4], &device);
|
||||
let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();
|
||||
|
||||
let expected = TensorData::zeros::<FloatElem, _>([2, 3, 4]);
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_batched_signs() {
|
||||
let x = TestTensor::<2>::from([[-1.0, 2.0], [3.0, -4.0]]);
|
||||
let y = TestTensor::<2>::from([[3.0, -4.0], [-5.0, 6.0]]);
|
||||
let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();
|
||||
|
||||
let expected = TensorData::from([[[-3.0, 4.0], [6.0, -8.0]], [[-15.0, 18.0], [20.0, -24.0]]]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_batched_equivalence_to_per_sample_outer() {
|
||||
let x = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let y = TestTensor::<2>::from([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]);
|
||||
let batched = linalg::outer::<TestBackend, 2, 3, _>(x.clone(), y.clone());
|
||||
|
||||
for b in 0..2 {
|
||||
let idx = TestTensorInt::<1>::from([b]);
|
||||
|
||||
let xb2d = x.clone().select(0, idx.clone()); // (1, m)
|
||||
let yb2d = y.clone().select(0, idx); // (1, n)
|
||||
|
||||
let dims_x: [usize; 2] = xb2d.shape().dims();
|
||||
let dims_y: [usize; 2] = yb2d.shape().dims();
|
||||
let (m, n) = (dims_x[1], dims_y[1]);
|
||||
|
||||
let per = linalg::outer::<TestBackend, 1, 2, _>(xb2d.reshape([m]), yb2d.reshape([n]));
|
||||
|
||||
let bat3d = batched.clone().select(0, TestTensorInt::<1>::from([b])); // (m, n)
|
||||
|
||||
let per_len = per.shape().num_elements();
|
||||
let per_flat = per.reshape([per_len]).into_data();
|
||||
|
||||
let bat_len = bat3d.shape().num_elements();
|
||||
let bat_flat = bat3d.reshape([bat_len]).into_data();
|
||||
|
||||
bat_flat.assert_approx_eq::<FloatElem>(&per_flat, Tolerance::default());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_outer_batched_mismatched_batches_panics() {
|
||||
let device = Default::default();
|
||||
let x = TestTensor::<2>::zeros([2, 3], &device);
|
||||
let y = TestTensor::<2>::zeros([3, 4], &device);
|
||||
let _ = linalg::outer::<TestBackend, 2, 3, _>(x, y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outer_dim() {
|
||||
let u = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let v = TestTensor::<2>::from([[4.0, 5.0], [5.0, 6.0]]);
|
||||
|
||||
let out = linalg::outer_dim::<TestBackend, 2, 3, _, _>(u, v, 0).into_data();
|
||||
let expected = TensorData::from([[[4.0, 10.0], [5.0, 12.0]], [[12.0, 20.0], [15.0, 24.0]]]);
|
||||
|
||||
out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
use super::*;
|
||||
use burn_tensor::linalg::trace;
|
||||
|
||||
#[test]
|
||||
fn test_trace_2d_square() {
|
||||
let device = Default::default();
|
||||
let tensor =
|
||||
TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device);
|
||||
let result = trace::<_, 2, 1>(tensor);
|
||||
let expected = TestTensor::<1>::from_data([15.0], &device); // 1 + 5 + 9 = 15
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_2d_rectangular_wide() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);
|
||||
let result = trace::<_, 2, 1>(tensor);
|
||||
let expected = TestTensor::<1>::from_data([7.0], &device); // 1 + 6 = 7
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_2d_rectangular_tall() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], &device);
|
||||
let result = trace::<_, 2, 1>(tensor);
|
||||
let expected = TestTensor::<1>::from_data([5.0], &device); // 1 + 4 = 5
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_3d_batch() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<3>::from_data(
|
||||
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let result = trace::<_, 3, 2>(tensor);
|
||||
let expected = TestTensor::<2>::from_data([[5.0], [13.0]], &device); // [1+4=5, 5+8=13]
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_4d_batch() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<4>::from_data(
|
||||
[[
|
||||
// Batch 0, Channel 0
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
// Batch 0, Channel 1
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let result = trace::<_, 4, 3>(tensor);
|
||||
let expected = TestTensor::<3>::from_data([[[5.0], [13.0]]], &device);
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_single_element() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[42.0]], &device);
|
||||
let result = trace::<_, 2, 1>(tensor);
|
||||
let expected = TestTensor::<1>::from_data([42.0], &device);
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_zeros() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::zeros([3, 3], &device);
|
||||
let result = trace::<_, 2, 1>(tensor);
|
||||
let expected = TestTensor::<1>::from_data([0.0], &device);
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trace_negative_values() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data([[-1.0, 2.0], [3.0, -4.0]], &device);
|
||||
let result = trace::<_, 2, 1>(tensor);
|
||||
let expected = TestTensor::<1>::from_data([-5.0], &device); // -1 + (-4) = -5
|
||||
|
||||
assert_eq!(result.to_data(), expected.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_trace_1d_should_panic() {
|
||||
let device = Default::default();
|
||||
// 1D tensor should panic - trace requires at least 2 dimensions
|
||||
let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device);
|
||||
let _result = trace::<_, 1, 0>(tensor);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_trace_wrong_output_rank_should_panic() {
|
||||
let device = Default::default();
|
||||
// Providing wrong output rank should panic
|
||||
let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
|
||||
let _result = trace::<_, 2, 2>(tensor); // Should be 2,1 not 2,2
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::linalg;
|
||||
|
||||
#[test]
|
||||
fn test_max_min_abs() {
|
||||
let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);
|
||||
|
||||
let expected = TestTensor::<2>::from([[3., 4.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::LInf, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::max_abs_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
let expected = TestTensor::<2>::from([[1., 2.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), -f64::INFINITY, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::vector_norm(x.clone(), f64::NEG_INFINITY, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::min_abs_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
let expected = TestTensor::<2>::from([[2.], [4.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), f64::INFINITY, 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::max_abs_norm(x.clone(), 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
let expected = TestTensor::<2>::from([[1.], [3.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), -f64::INFINITY, 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::vector_norm(x.clone(), f64::NEG_INFINITY, 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::min_abs_norm(x, 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
// Test with integer tensor
|
||||
let z = TestTensorInt::<2>::from([[1, 2], [3, 4]]);
|
||||
|
||||
linalg::max_abs_norm(z.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&TestTensorInt::<2>::from([[3, 4]]).into_data(), true);
|
||||
linalg::max_abs_norm(z.clone(), 1)
|
||||
.into_data()
|
||||
.assert_eq(&TestTensorInt::<2>::from([[2], [4]]).into_data(), true);
|
||||
|
||||
linalg::min_abs_norm(z.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&TestTensorInt::<2>::from([[1, 2]]).into_data(), true);
|
||||
linalg::min_abs_norm(z, 1)
|
||||
.into_data()
|
||||
.assert_eq(&TestTensorInt::<2>::from([[1], [3]]).into_data(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l0_norm() {
|
||||
let x = TestTensor::<2>::from([[1.0, -2.0, 0.], [0.0, 0., 4.]]);
|
||||
|
||||
let expected = TestTensor::<2>::from([[1., 1., 1.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::L0, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::l0_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
let expected = TestTensor::<2>::from([[2.], [1.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), 0.0, 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::l0_norm(x.clone(), 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
// Test with integer tensor
|
||||
let z = TestTensorInt::<2>::from([[1, -2, 0], [0, 0, 4]]);
|
||||
|
||||
linalg::l0_norm(z.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&TestTensor::<2>::from([[1, 1, 1]]).int().into_data(), true);
|
||||
linalg::l0_norm(z.clone(), 1)
|
||||
.into_data()
|
||||
.assert_eq(&TestTensor::<2>::from([[2], [1]]).int().into_data(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l1_norm() {
|
||||
let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);
|
||||
|
||||
let expected = TestTensor::<2>::from([[4.0, 6.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::L1, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::l1_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
let expected = TestTensor::<2>::from([[3.0], [7.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), 1.0, 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::l1_norm(x.clone(), 1)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lp_norm() {
|
||||
let x = TestTensor::<2>::from([[1., -2., 0.], [0., 3., 4.]]);
|
||||
let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(2e-3);
|
||||
|
||||
fn lp_norm_naive<B: Backend, const D: usize>(
|
||||
x: Tensor<B, D>,
|
||||
p: f64,
|
||||
dim: usize,
|
||||
) -> Tensor<B, D> {
|
||||
x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
|
||||
}
|
||||
|
||||
// Arbitrary P
|
||||
let expected = TestTensor::<2>::from([[1.0, 3.2710664, 4.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), 3, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
linalg::lp_norm(x.clone(), 3., 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
// L0
|
||||
let expected = TestTensor::<2>::from([[1., 2., 1.]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::L0, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::l0_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::lp_norm(x.clone(), 0.0, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
// L1
|
||||
let expected = TestTensor::<2>::from([[1.0, 5.0, 4.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::L1, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::l1_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
lp_norm_naive(x.clone(), 1.0, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::lp_norm(x.clone(), 1.0, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
|
||||
// L2
|
||||
let expected = TestTensor::<2>::from([[1.0, 3.6055512, 4.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::L2, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
linalg::l2_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
lp_norm_naive(x.clone(), 2.0, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
linalg::lp_norm(x.clone(), 2.0, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
// LInf
|
||||
let expected = TestTensor::<2>::from([[1.0, 3.0, 4.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::LInf, 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::max_abs_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::lp_norm(x.clone(), f64::INFINITY, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
// LNegInf
|
||||
let expected = TestTensor::<2>::from([[0.0, 2.0, 0.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::LNegInf, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
linalg::min_abs_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_eq(&expected, true);
|
||||
linalg::lp_norm(x.clone(), f64::NEG_INFINITY, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l2_norm() {
|
||||
let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);
|
||||
let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(1e-3);
|
||||
|
||||
let expected = TestTensor::<2>::from([[3.16227766, 4.47213595]]).into_data();
|
||||
linalg::vector_norm(x.clone(), linalg::Norm::L2, 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
linalg::l2_norm(x.clone(), 0)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
|
||||
let expected = TestTensor::<2>::from([[2.23606798], [5.0]]).into_data();
|
||||
linalg::vector_norm(x.clone(), 2.0, 1)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
linalg::l2_norm(x.clone(), 1)
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, tolerance);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize() {
|
||||
let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);
|
||||
|
||||
let expected = TensorData::from([[1. / 4., 2. / 6.], [3. / 4., 4. / 6.]]);
|
||||
let output = linalg::vector_normalize(x.clone(), 1.0, 0, 0.25).into_data();
|
||||
output.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let expected = TensorData::from([[1. / 5., 2. / 6.], [3. / 5., 4. / 6.]]);
|
||||
let output = linalg::vector_normalize(x.clone(), 1.0, 0, 5.0).into_data();
|
||||
output.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
#[allow(unused_imports)]
|
||||
pub use super::*; // re-export test types
|
||||
|
||||
mod activation;
|
||||
mod grid;
|
||||
mod linalg;
|
||||
mod module;
|
||||
mod ops;
|
||||
mod primitive;
|
||||
mod stats;
|
||||
|
||||
#[cfg(feature = "quantization")]
|
||||
mod quantization;
|
||||
@@ -0,0 +1,70 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::adaptive_avg_pool1d;
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool1d_simple() {
|
||||
let test = AdaptiveAvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
length: 8,
|
||||
length_out: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[0.5, 2.5, 4.5, 6.5],
|
||||
[8.5, 10.5, 12.5, 14.5],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool1d_dyn_filter_size() {
|
||||
let test = AdaptiveAvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
length: 7,
|
||||
length_out: 3,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[1.0, 3.0, 5.0], [8.0, 10.0, 12.0]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool1d_bigger_output() {
|
||||
let test = AdaptiveAvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
length: 4,
|
||||
length_out: 8,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0],
|
||||
[4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
length: usize,
|
||||
length_out: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool1dTestCase {
|
||||
fn assert_output(self, y: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
);
|
||||
let output = adaptive_avg_pool1d(x, self.length_out);
|
||||
|
||||
y.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::adaptive_avg_pool2d;
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool2d_simple() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 8,
|
||||
width: 6,
|
||||
height_out: 4,
|
||||
width_out: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[3.5000, 4.5000, 6.5000, 7.5000],
|
||||
[15.5000, 16.5000, 18.5000, 19.5000],
|
||||
[27.5000, 28.5000, 30.5000, 31.5000],
|
||||
[39.5000, 40.5000, 42.5000, 43.5000],
|
||||
],
|
||||
[
|
||||
[51.5000, 52.5000, 54.5000, 55.5000],
|
||||
[63.5000, 64.5000, 66.5000, 67.5000],
|
||||
[75.5000, 76.5000, 78.5000, 79.5000],
|
||||
[87.5000, 88.5000, 90.5000, 91.5000],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool2d_dyn_filter_size() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 5,
|
||||
width: 7,
|
||||
height_out: 3,
|
||||
width_out: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]],
|
||||
[[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool2d_bigger_output() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 4,
|
||||
width: 3,
|
||||
height_out: 5,
|
||||
width_out: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[0.0000, 0.5000, 1.5000, 2.0000],
|
||||
[1.5000, 2.0000, 3.0000, 3.5000],
|
||||
[4.5000, 5.0000, 6.0000, 6.5000],
|
||||
[7.5000, 8.0000, 9.0000, 9.5000],
|
||||
[9.0000, 9.5000, 10.5000, 11.0000],
|
||||
],
|
||||
[
|
||||
[12.0000, 12.5000, 13.5000, 14.0000],
|
||||
[13.5000, 14.0000, 15.0000, 15.5000],
|
||||
[16.5000, 17.0000, 18.0000, 18.5000],
|
||||
[19.5000, 20.0000, 21.0000, 21.5000],
|
||||
[21.0000, 21.5000, 22.5000, 23.0000],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,372 @@
|
||||
use super::*;
|
||||
use burn_tensor::Distribution;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::attention;
|
||||
use burn_tensor::module::attention_fallback;
|
||||
use burn_tensor::ops::AttentionModuleOptions;
|
||||
|
||||
#[test]
|
||||
fn test_attention_no_mask() {
|
||||
// Skip on metal with f16 - flash attention returns zeros
|
||||
// Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4325
|
||||
#[cfg(feature = "metal")]
|
||||
if core::any::TypeId::of::<FloatElemType>() == core::any::TypeId::of::<burn_tensor::f16>() {
|
||||
return;
|
||||
}
|
||||
|
||||
let num_batches = 1;
|
||||
let num_heads = 1;
|
||||
let seq_q = 128;
|
||||
let seq_kv = 128;
|
||||
let head_dim = 64;
|
||||
let val_dim = 64;
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_q, head_dim],
|
||||
Distribution::Uniform(0., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_kv, head_dim],
|
||||
Distribution::Uniform(0., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_kv, val_dim],
|
||||
Distribution::Uniform(0., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
None,
|
||||
None,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let expected =
|
||||
attention_fallback::<TestBackend>(query, key, value, None, None, Default::default());
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_custom_scale() {
|
||||
let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let options = AttentionModuleOptions {
|
||||
scale: Some(0.1),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
None,
|
||||
None,
|
||||
options,
|
||||
);
|
||||
|
||||
let expected = attention_fallback::<TestBackend>(query, key, value, None, None, options);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_attn_bias() {
|
||||
let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let bias = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, seq_len],
|
||||
Distribution::Uniform(-0.5, 0.5),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
None,
|
||||
Some(bias.clone()),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let expected =
|
||||
attention_fallback::<TestBackend>(query, key, value, None, Some(bias), Default::default());
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_softcap() {
|
||||
let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let options = AttentionModuleOptions {
|
||||
softcap: Some(50.0),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
None,
|
||||
None,
|
||||
options,
|
||||
);
|
||||
|
||||
let expected = attention_fallback::<TestBackend>(query, key, value, None, None, options);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_is_causal() {
|
||||
let [num_batches, num_heads, seq_len, head_dim] = [2, 4, 16, 32];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let options = AttentionModuleOptions {
|
||||
is_causal: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
None,
|
||||
None,
|
||||
options,
|
||||
);
|
||||
|
||||
let expected = attention_fallback::<TestBackend>(query, key, value, None, None, options);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
|
||||
/// Cross-attention: seq_q != seq_k, with causal masking and additive bias.
|
||||
#[test]
|
||||
fn test_attention_cross_attention_with_bias() {
|
||||
let [num_batches, num_heads, seq_q, seq_k, head_dim] = [2, 2, 8, 24, 32];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_q, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_k, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_k, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let bias = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_q, seq_k],
|
||||
Distribution::Uniform(-0.5, 0.5),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let options = AttentionModuleOptions {
|
||||
is_causal: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
None,
|
||||
Some(bias.clone()),
|
||||
options,
|
||||
);
|
||||
|
||||
let expected = attention_fallback::<TestBackend>(query, key, value, None, Some(bias), options);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression: softcap must be applied before -inf masking.
|
||||
/// With causal masking, position 0 can only attend to itself, so output[0] == value[0].
|
||||
/// If softcap were applied after masking, tanh(-inf/softcap) = -softcap (finite),
|
||||
/// and the masked position would leak into the output.
|
||||
#[test]
|
||||
fn test_attention_softcap_preserves_causal_mask() {
|
||||
let [num_batches, num_heads, seq_len, head_dim] = [1, 1, 4, 8];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let options = AttentionModuleOptions {
|
||||
softcap: Some(20.0),
|
||||
is_causal: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let output = attention_fallback::<TestBackend>(query, key, value.clone(), None, None, options);
|
||||
|
||||
// With causal masking, position 0 can only attend to itself (softmax = [1, 0, 0, 0]).
|
||||
// So output[..., 0, :] must equal value[..., 0, :].
|
||||
let output_row0 = output.slice([0..1, 0..1, 0..1, 0..head_dim]);
|
||||
let value_row0 = value.slice([0..1, 0..1, 0..1, 0..head_dim]);
|
||||
|
||||
output_row0
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&value_row0.into_data(), Tolerance::relative(1e-5));
|
||||
}
|
||||
|
||||
/// Combined: mask + bias + custom scale + softcap together.
|
||||
#[test]
|
||||
fn test_attention_all_options() {
|
||||
let [num_batches, num_heads, seq_len, head_dim] = [2, 2, 16, 32];
|
||||
|
||||
let query = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let key = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let value = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, head_dim],
|
||||
Distribution::Uniform(-1., 1.),
|
||||
&Default::default(),
|
||||
);
|
||||
let bias = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, seq_len],
|
||||
Distribution::Uniform(-0.5, 0.5),
|
||||
&Default::default(),
|
||||
);
|
||||
// Create a random bool mask by thresholding a uniform float tensor
|
||||
let mask = TestTensor::<4>::random(
|
||||
[num_batches, num_heads, seq_len, seq_len],
|
||||
Distribution::Uniform(0., 1.),
|
||||
&Default::default(),
|
||||
)
|
||||
.greater_elem(0.7);
|
||||
|
||||
let options = AttentionModuleOptions {
|
||||
scale: Some(0.05),
|
||||
softcap: Some(30.0),
|
||||
is_causal: true,
|
||||
};
|
||||
|
||||
let output = attention(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
Some(mask.clone()),
|
||||
Some(bias.clone()),
|
||||
options,
|
||||
);
|
||||
|
||||
let expected =
|
||||
attention_fallback::<TestBackend>(query, key, value, Some(mask), Some(bias), options);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected.into_data(),
|
||||
Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::avg_pool1d;
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_simple() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size: 3,
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
length: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[1., 2., 3., 4.]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_complex() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
length: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[0.33333, 2.0000, 4.0000],
|
||||
[4.33333, 8.0000, 10.0000],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_complex_dont_count_pad() {
|
||||
let test = AvgPool1dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
length: 6,
|
||||
count_include_pad: false,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[0.5000, 2.0000, 4.0000],
|
||||
[6.5000, 8.0000, 10.0000],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct AvgPool1dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
length: usize,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool1dTestCase {
|
||||
fn assert_output(self, y: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = avg_pool1d(
|
||||
x,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.count_include_pad,
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_ceil_mode() {
|
||||
// Test ceil_mode=true produces larger output when input doesn't divide evenly by stride
|
||||
// Input: 1x1x6 (values 0-5), kernel: 3, stride: 2, padding: 0
|
||||
// Floor mode: output = (6-3)/2+1 = 2 elements
|
||||
// Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 elements
|
||||
let x = TestTensor::from([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]);
|
||||
|
||||
// With ceil_mode=false (floor): output is 2 elements
|
||||
// Window 0: avg(0,1,2) = 1
|
||||
// Window 1: avg(2,3,4) = 3
|
||||
let y_floor = TestTensor::<3>::from([[[1.0, 3.0]]]);
|
||||
|
||||
let output_floor = avg_pool1d(
|
||||
x.clone(),
|
||||
3, // kernel_size
|
||||
2, // stride
|
||||
0, // padding
|
||||
true, // count_include_pad
|
||||
false,
|
||||
);
|
||||
|
||||
y_floor.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output_floor.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
|
||||
// With ceil_mode=true: output is 3 elements
|
||||
// Window 0: avg(0,1,2) = 1
|
||||
// Window 1: avg(2,3,4) = 3
|
||||
// Window 2: avg(4,5) = 4.5 (partial window, count_include_pad=false divides by 2)
|
||||
let y_ceil = TestTensor::<3>::from([[[1.0, 3.0, 4.5]]]);
|
||||
|
||||
let output_ceil = avg_pool1d(
|
||||
x, 3, // kernel_size
|
||||
2, // stride
|
||||
0, // padding
|
||||
false, // count_include_pad=false to get correct average for partial window
|
||||
true,
|
||||
);
|
||||
|
||||
y_ceil.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output_ceil.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool1d_ceil_mode_count_include_pad() {
|
||||
// Test count_include_pad=true + ceil_mode=true interaction for 1D
|
||||
// When ceil_mode creates windows that extend beyond the padded input:
|
||||
// - count_include_pad=true should count positions within padded bounds (not ceil_mode extensions)
|
||||
//
|
||||
// Input: 1x1x6, kernel 3, stride 2, padding 1, ceil_mode=true
|
||||
// Output is 4 elements
|
||||
let x = TestTensor::from([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]);
|
||||
|
||||
// Expected PyTorch output with padding=1, ceil_mode=true, count_include_pad=true:
|
||||
// Window 0: positions -1,0,1 -> values 0,0,1 (0 is padding) / 3 = 0.333
|
||||
// Window 1: positions 1,2,3 -> values 1,2,3 / 3 = 2.0
|
||||
// Window 2: positions 3,4,5 -> values 3,4,5 / 3 = 4.0
|
||||
// Window 3: positions 5,6,7 -> only 5 is valid, 6 is padding, 7 is ceil_mode extension
|
||||
// value 5 / 2 (only 2 positions within padded bounds) = 2.5
|
||||
let expected = TestTensor::<3>::from([[[0.3333, 2.0, 4.0, 2.5]]]);
|
||||
|
||||
let output = avg_pool1d(
|
||||
x, 3, // kernel_size
|
||||
2, // stride
|
||||
1, // padding
|
||||
true, // count_include_pad=true
|
||||
true, // ceil_mode=true
|
||||
);
|
||||
|
||||
expected.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-2),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::avg_pool2d;
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_simple() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[7., 8., 9., 10.],
|
||||
[13., 14., 15., 16.],
|
||||
[19., 20., 21., 22.],
|
||||
[25., 26., 27., 28.],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_complex() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
height: 4,
|
||||
width: 6,
|
||||
count_include_pad: true,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[1.1667, 3.0000, 4.3333, 2.5000],
|
||||
[3.2500, 7.5000, 9.5000, 5.2500],
|
||||
[6.2500, 13.5000, 15.5000, 8.2500],
|
||||
[5.1667, 11.0000, 12.3333, 6.5000],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_complex_dont_include_pad() {
|
||||
let test = AvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
height: 4,
|
||||
width: 6,
|
||||
count_include_pad: false,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[3.5000, 4.5000, 6.5000, 7.5000],
|
||||
[6.5000, 7.5000, 9.5000, 10.5000],
|
||||
[12.5000, 13.5000, 15.5000, 16.5000],
|
||||
[15.5000, 16.5000, 18.5000, 19.5000],
|
||||
]]]));
|
||||
}
|
||||
|
||||
struct AvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl AvgPool2dTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = avg_pool2d(
|
||||
x,
|
||||
[self.kernel_size_1, self.kernel_size_2],
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
self.count_include_pad,
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_ceil_mode() {
|
||||
// Test ceil_mode=true produces larger output when input doesn't divide evenly by stride
|
||||
// Input: 1x1x6x6 (values 0-35), kernel: 3x3, stride: 2x2, padding: 0x0
|
||||
// Floor mode: output = (6-3)/2+1 = 2 x 2
|
||||
// Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 x 3
|
||||
let x = TestTensor::from([[[
|
||||
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
[12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
|
||||
[18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
|
||||
[24.0, 25.0, 26.0, 27.0, 28.0, 29.0],
|
||||
[30.0, 31.0, 32.0, 33.0, 34.0, 35.0],
|
||||
]]]);
|
||||
|
||||
// With ceil_mode=false (floor): output is 2x2
|
||||
// Window (0,0): avg(0,1,2,6,7,8,12,13,14) = avg(63) = 7
|
||||
// Window (0,1): avg(2,3,4,8,9,10,14,15,16) = avg(81) = 9
|
||||
// Window (1,0): avg(12,13,14,18,19,20,24,25,26) = avg(171) = 19
|
||||
// Window (1,1): avg(14,15,16,20,21,22,26,27,28) = avg(189) = 21
|
||||
let y_floor = TestTensor::<4>::from([[[[7.0, 9.0], [19.0, 21.0]]]]);
|
||||
|
||||
let output_floor = avg_pool2d(
|
||||
x.clone(),
|
||||
[3, 3],
|
||||
[2, 2],
|
||||
[0, 0],
|
||||
true, // count_include_pad
|
||||
false,
|
||||
);
|
||||
|
||||
y_floor.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output_floor.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
|
||||
// With ceil_mode=true: output is 3x3
|
||||
// The extra windows at the edge include partial/padded regions
|
||||
// When count_include_pad=false, only actual values are averaged
|
||||
// Window (0,2): positions (0:3, 4:6) -> values 4,5,10,11,16,17 -> avg = 10.5
|
||||
// Window (1,2): positions (2:5, 4:6) -> values 16,17,22,23,28,29 -> avg = 22.5
|
||||
// Window (2,0): positions (4:6, 0:3) -> values 24,25,26,30,31,32 -> avg = 28
|
||||
// Window (2,1): positions (4:6, 2:5) -> values 26,27,28,32,33,34 -> avg = 30
|
||||
// Window (2,2): positions (4:6, 4:6) -> values 28,29,34,35 -> avg = 31.5
|
||||
let y_ceil =
|
||||
TestTensor::<4>::from([[[[7.0, 9.0, 10.5], [19.0, 21.0, 22.5], [28.0, 30.0, 31.5]]]]);
|
||||
|
||||
let output_ceil = avg_pool2d(
|
||||
x,
|
||||
[3, 3],
|
||||
[2, 2],
|
||||
[0, 0],
|
||||
false, // count_include_pad=false to avoid dividing by full kernel size
|
||||
true,
|
||||
);
|
||||
|
||||
y_ceil.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output_ceil.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-3),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_ceil_mode_count_include_pad() {
|
||||
// Test count_include_pad=true + ceil_mode=true interaction
|
||||
// When ceil_mode creates windows that extend beyond the padded input:
|
||||
// - count_include_pad=true should count positions within padded bounds (not ceil_mode extensions)
|
||||
//
|
||||
// For input 6x6, kernel 3, stride 2, padding 1, ceil_mode=true:
|
||||
// - Output is 4x4
|
||||
// - Corner (3,3) window covers positions beyond even the user padding
|
||||
// - Expected: 35/4 = 8.75 (divides by count of positions within padded bounds)
|
||||
|
||||
let x = TestTensor::from([[[
|
||||
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
[12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
|
||||
[18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
|
||||
[24.0, 25.0, 26.0, 27.0, 28.0, 29.0],
|
||||
[30.0, 31.0, 32.0, 33.0, 34.0, 35.0],
|
||||
]]]);
|
||||
|
||||
// Expected PyTorch output with padding=1, ceil_mode=true, count_include_pad=true
|
||||
// Note: corner (3,3) = 8.75 = 35/4, not 35/9
|
||||
let expected = TestTensor::<4>::from([[[
|
||||
[1.5556, 3.3333, 4.6667, 2.6667],
|
||||
[8.3333, 14.0000, 16.0000, 8.5000],
|
||||
[16.3333, 26.0000, 28.0000, 14.5000],
|
||||
[10.1667, 16.0000, 17.0000, 8.7500],
|
||||
]]]);
|
||||
|
||||
let output = avg_pool2d(
|
||||
x,
|
||||
[3, 3],
|
||||
[2, 2],
|
||||
[1, 1],
|
||||
true, // count_include_pad=true
|
||||
true, // ceil_mode=true
|
||||
);
|
||||
|
||||
expected.to_data().assert_approx_eq::<FloatElem>(
|
||||
&output.into_data(),
|
||||
Tolerance::default().set_half_precision_relative(1e-2),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::interpolate;
|
||||
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
|
||||
|
||||
#[test]
|
||||
fn test_upsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 2,
|
||||
channels: 1,
|
||||
height: 7,
|
||||
width: 5,
|
||||
height_out: 8,
|
||||
width_out: 7,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[
|
||||
[0.0000, 0.5741, 1.3704, 2.0000, 2.6296, 3.4259, 4.0000],
|
||||
[4.0015, 4.5755, 5.3718, 6.0015, 6.6311, 7.4274, 8.0015],
|
||||
[8.3528, 8.9268, 9.7231, 10.3528, 10.9824, 11.7787, 12.3528],
|
||||
[
|
||||
12.7697, 13.3438, 14.1400, 14.7697, 15.3993, 16.1956, 16.7697,
|
||||
],
|
||||
[
|
||||
17.2303, 17.8044, 18.6007, 19.2303, 19.8600, 20.6562, 21.2303,
|
||||
],
|
||||
[
|
||||
21.6472, 22.2213, 23.0176, 23.6472, 24.2769, 25.0731, 25.6472,
|
||||
],
|
||||
[
|
||||
25.9986, 26.5726, 27.3689, 27.9986, 28.6282, 29.4245, 29.9986,
|
||||
],
|
||||
[
|
||||
30.0000, 30.5741, 31.3704, 32.0000, 32.6296, 33.4259, 34.0000,
|
||||
],
|
||||
]],
|
||||
[[
|
||||
[
|
||||
35.0000, 35.5741, 36.3704, 37.0000, 37.6296, 38.4259, 39.0000,
|
||||
],
|
||||
[
|
||||
39.0015, 39.5755, 40.3718, 41.0015, 41.6311, 42.4274, 43.0015,
|
||||
],
|
||||
[
|
||||
43.3528, 43.9269, 44.7231, 45.3528, 45.9824, 46.7787, 47.3528,
|
||||
],
|
||||
[
|
||||
47.7697, 48.3438, 49.1400, 49.7697, 50.3993, 51.1956, 51.7697,
|
||||
],
|
||||
[
|
||||
52.2303, 52.8044, 53.6007, 54.2303, 54.8600, 55.6562, 56.2303,
|
||||
],
|
||||
[
|
||||
56.6472, 57.2213, 58.0176, 58.6472, 59.2769, 60.0731, 60.6472,
|
||||
],
|
||||
[
|
||||
60.9986, 61.5726, 62.3689, 62.9986, 63.6282, 64.4245, 64.9986,
|
||||
],
|
||||
[
|
||||
65.0000, 65.5741, 66.3704, 67.0000, 67.6296, 68.4259, 69.0000,
|
||||
],
|
||||
]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_downsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 45,
|
||||
width: 14,
|
||||
height_out: 4,
|
||||
width_out: 6,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[0.0000, 2.5760, 5.2480, 7.7520, 10.4240, 13.0000],
|
||||
[204.8148, 207.3908, 210.0628, 212.5668, 215.2388, 217.8148],
|
||||
[411.1852, 413.7612, 416.4331, 418.9371, 421.6091, 424.1852],
|
||||
[616.0000, 618.576, 621.2479, 623.7519, 626.4239, 629.0000],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_1d_bicubic() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
|
||||
// Run the model
|
||||
let input = TestTensor::<3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let output = interpolate(
|
||||
input,
|
||||
[1, 9],
|
||||
InterpolateOptions::new(InterpolateMode::Bicubic),
|
||||
);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 1, 9]);
|
||||
|
||||
// assert output data does not contain NaN
|
||||
assert!(
|
||||
!output
|
||||
.clone()
|
||||
.to_data()
|
||||
.as_slice::<FloatElem>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|&x| x.is_nan()),
|
||||
"interpolate output contains NaN"
|
||||
);
|
||||
|
||||
TestTensor::<4>::from([[[[
|
||||
1.541, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794,
|
||||
-1.3986,
|
||||
]]]])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl InterpolateTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
self.assert_output_with_align_corners(y, true);
|
||||
}
|
||||
|
||||
fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = interpolate(
|
||||
x,
|
||||
[self.height_out, self.width_out],
|
||||
InterpolateOptions::new(InterpolateMode::Bicubic).with_align_corners(align_corners),
|
||||
);
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsample_half_pixel() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
height_out: 8,
|
||||
width_out: 8,
|
||||
};
|
||||
|
||||
test.assert_output_with_align_corners(
|
||||
TestTensor::from([[[
|
||||
[
|
||||
-0.5273, -0.2305, 0.2461, 0.875, 1.2812, 1.9102, 2.3867, 2.6836,
|
||||
],
|
||||
[
|
||||
0.6602, 0.957, 1.4336, 2.0625, 2.4688, 3.0977, 3.5742, 3.8711,
|
||||
],
|
||||
[
|
||||
2.5664, 2.8633, 3.3398, 3.9688, 4.375, 5.0039, 5.4805, 5.7773,
|
||||
],
|
||||
[5.082, 5.3789, 5.8555, 6.4844, 6.8906, 7.5195, 7.9961, 8.293],
|
||||
[6.707, 7.0039, 7.4805, 8.1094, 8.5156, 9.1445, 9.6211, 9.918],
|
||||
[
|
||||
9.2227, 9.5195, 9.9961, 10.625, 11.0312, 11.6602, 12.1367, 12.4336,
|
||||
],
|
||||
[
|
||||
11.1289, 11.4258, 11.9023, 12.5312, 12.9375, 13.5664, 14.043, 14.3398,
|
||||
],
|
||||
[
|
||||
12.3164, 12.6133, 13.0898, 13.7188, 14.125, 14.7539, 15.2305, 15.5273,
|
||||
],
|
||||
]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::interpolate;
|
||||
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
|
||||
use burn_tensor::{DType, Shape};
|
||||
|
||||
#[test]
|
||||
fn test_upsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 2,
|
||||
channels: 1,
|
||||
height: 7,
|
||||
width: 5,
|
||||
height_out: 8,
|
||||
width_out: 7,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[
|
||||
[0.0000, 0.6667, 1.3333, 2.0000, 2.6667, 3.3333, 4.0000],
|
||||
[4.2857, 4.9524, 5.6190, 6.2857, 6.9524, 7.6190, 8.2857],
|
||||
[8.5714, 9.2381, 9.9048, 10.5714, 11.2381, 11.9048, 12.5714],
|
||||
[
|
||||
12.8571, 13.5238, 14.1905, 14.8571, 15.5238, 16.1905, 16.8571,
|
||||
],
|
||||
[
|
||||
17.1429, 17.8095, 18.4762, 19.1429, 19.8095, 20.4762, 21.1429,
|
||||
],
|
||||
[
|
||||
21.4286, 22.0952, 22.7619, 23.4286, 24.0952, 24.7619, 25.4286,
|
||||
],
|
||||
[
|
||||
25.7143, 26.3810, 27.0476, 27.7143, 28.3810, 29.0476, 29.7143,
|
||||
],
|
||||
[
|
||||
30.0000, 30.6667, 31.3333, 32.0000, 32.6667, 33.3333, 34.0000,
|
||||
],
|
||||
]],
|
||||
[[
|
||||
[
|
||||
35.0000, 35.6667, 36.3333, 37.0000, 37.6667, 38.3333, 39.0000,
|
||||
],
|
||||
[
|
||||
39.2857, 39.9524, 40.6190, 41.2857, 41.9524, 42.6190, 43.2857,
|
||||
],
|
||||
[
|
||||
43.5714, 44.2381, 44.9048, 45.5714, 46.2381, 46.9048, 47.5714,
|
||||
],
|
||||
[
|
||||
47.8571, 48.5238, 49.1905, 49.8571, 50.5238, 51.1905, 51.8571,
|
||||
],
|
||||
[
|
||||
52.1429, 52.8095, 53.4762, 54.1429, 54.8095, 55.4762, 56.1429,
|
||||
],
|
||||
[
|
||||
56.4286, 57.0952, 57.7619, 58.4286, 59.0952, 59.7619, 60.4286,
|
||||
],
|
||||
[
|
||||
60.7143, 61.3810, 62.0476, 62.7143, 63.3810, 64.0476, 64.7143,
|
||||
],
|
||||
[
|
||||
65.0000, 65.6667, 66.3333, 67.0000, 67.6667, 68.3333, 69.0000,
|
||||
],
|
||||
]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_downsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 45,
|
||||
width: 14,
|
||||
height_out: 4,
|
||||
width_out: 6,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[0.0, 2.6, 5.2, 7.8, 10.4, 13.],
|
||||
[205.3333, 207.9333, 210.5333, 213.1333, 215.7333, 218.3333],
|
||||
[410.6667, 413.2667, 415.8667, 418.4667, 421.0667, 423.6667],
|
||||
[616., 618.6, 621.2, 623.8, 626.4, 629.],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_1d_bilinear() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
|
||||
// Run the model
|
||||
let input = TestTensor::<3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let output = interpolate(
|
||||
input,
|
||||
[1, 9],
|
||||
InterpolateOptions::new(InterpolateMode::Bilinear),
|
||||
);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 1, 9]);
|
||||
|
||||
// assert output data does not contain NaN
|
||||
assert!(
|
||||
!output
|
||||
.clone()
|
||||
.to_data()
|
||||
.as_slice::<FloatElem>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|&x| x.is_nan()),
|
||||
"interpolate output contains NaN"
|
||||
);
|
||||
|
||||
TestTensor::<4>::from([[[[
|
||||
1.541f32,
|
||||
0.39450002,
|
||||
-0.76475,
|
||||
-1.943125,
|
||||
-0.80520004,
|
||||
0.36178753,
|
||||
-0.671275,
|
||||
-1.2022874,
|
||||
-1.3986,
|
||||
]]]])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_coord_float_precision_boundary() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 28,
|
||||
width: 4,
|
||||
height_out: 24,
|
||||
width_out: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[0.0, 3.0],
|
||||
[4.6956, 7.6956],
|
||||
[9.3913, 12.3913],
|
||||
[14.0869, 17.0869],
|
||||
[18.7826, 21.7826],
|
||||
[23.4782, 26.4782],
|
||||
[28.1739, 31.1739],
|
||||
[32.8695, 35.8695],
|
||||
[37.5652, 40.5652],
|
||||
[42.2608, 45.2608],
|
||||
[46.9565, 49.9565],
|
||||
[51.6521, 54.6521],
|
||||
[56.3478, 59.3478],
|
||||
[61.0434, 64.0434],
|
||||
[65.7391, 68.7391],
|
||||
[70.4347, 73.4347],
|
||||
[75.1304, 78.1304],
|
||||
[79.8260, 82.8260],
|
||||
[84.5217, 87.5217],
|
||||
[89.2173, 92.2173],
|
||||
[93.9130, 96.9130],
|
||||
[98.6086, 101.6086],
|
||||
[103.3043, 106.3043],
|
||||
[108.0, 111.0],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_interpolate_cast() {
|
||||
let device = Default::default();
|
||||
let shape_x = Shape::new([1, 1, 4, 4]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
)
|
||||
.cast(DType::F32); // ok for f32 backends, casts dtype for f16 tests
|
||||
let output = interpolate(
|
||||
x,
|
||||
[8, 8],
|
||||
InterpolateOptions::new(InterpolateMode::Bilinear),
|
||||
);
|
||||
|
||||
let expected = TestTensor::<4>::from([[[
|
||||
[0.0, 0.42857, 0.8571, 1.2857, 1.7142, 2.1428, 2.5714, 3.0],
|
||||
[1.7142, 2.1428, 2.5714, 3.0, 3.4285, 3.8571, 4.2857, 4.7142],
|
||||
[3.4285, 3.8571, 4.2857, 4.7142, 5.1428, 5.5714, 6.0, 6.4285],
|
||||
[5.1428, 5.5714, 6.0, 6.4285, 6.8571, 7.2857, 7.7142, 8.1428],
|
||||
[6.8571, 7.2857, 7.7142, 8.1428, 8.5714, 9.0, 9.4285, 9.8571],
|
||||
[
|
||||
8.5714, 9.0, 9.4285, 9.8571, 10.2857, 10.7142, 11.1428, 11.5714,
|
||||
],
|
||||
[
|
||||
10.2857, 10.7142, 11.1428, 11.5714, 12.0, 12.4285, 12.8571, 13.2857,
|
||||
],
|
||||
[
|
||||
12.0, 12.4285, 12.8571, 13.2857, 13.7142, 14.1428, 14.5714, 15.0,
|
||||
],
|
||||
]]]);
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected.into_data(), tolerance);
|
||||
}
|
||||
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl InterpolateTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
self.assert_output_with_align_corners(y, true);
|
||||
}
|
||||
|
||||
fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = interpolate(
|
||||
x,
|
||||
[self.height_out, self.width_out],
|
||||
InterpolateOptions::new(InterpolateMode::Bilinear).with_align_corners(align_corners),
|
||||
);
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsample_half_pixel() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
height_out: 8,
|
||||
width_out: 8,
|
||||
};
|
||||
|
||||
test.assert_output_with_align_corners(
|
||||
TestTensor::from([[[
|
||||
[0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.0],
|
||||
[1.0, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.0],
|
||||
[3.0, 3.25, 3.75, 4.25, 4.75, 5.25, 5.75, 6.0],
|
||||
[5.0, 5.25, 5.75, 6.25, 6.75, 7.25, 7.75, 8.0],
|
||||
[7.0, 7.25, 7.75, 8.25, 8.75, 9.25, 9.75, 10.0],
|
||||
[9.0, 9.25, 9.75, 10.25, 10.75, 11.25, 11.75, 12.0],
|
||||
[11.0, 11.25, 11.75, 12.25, 12.75, 13.25, 13.75, 14.0],
|
||||
[12.0, 12.25, 12.75, 13.25, 13.75, 14.25, 14.75, 15.0],
|
||||
]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::conv1d;
|
||||
use burn_tensor::ops::ConvOptions;
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_simple() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[43., 67., 82., 49.], [104., 176., 227., 158.]],
|
||||
[[139., 187., 202., 113.], [392., 584., 635., 414.]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_dilation() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[62., 38.], [159., 111.]],
|
||||
[[158., 102.], [447., 367.]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_groups() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 2,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[2., 5., 8., 3.], [42., 63., 75., 47.]],
|
||||
[[26., 29., 32., 11.], [114., 159., 171., 103.]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_complex() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 4,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[
|
||||
[[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]],
|
||||
[[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]],
|
||||
],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct Conv1dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
impl Conv1dTestCase {
|
||||
fn assert_output(self, y: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_weight)
|
||||
.into_data(),
|
||||
&device,
|
||||
);
|
||||
let bias = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
);
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
);
|
||||
let output = conv1d(
|
||||
x,
|
||||
weight,
|
||||
Some(bias),
|
||||
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
|
||||
);
|
||||
|
||||
let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(1e-3);
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,652 @@
|
||||
use super::*;
|
||||
use alloc::{vec, vec::Vec};
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::activation::gelu;
|
||||
use burn_tensor::module::conv2d;
|
||||
use burn_tensor::ops::ConvOptions;
|
||||
use burn_tensor::{TensorData, Tolerance};
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_simple() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[1196., 1796., 1916., 1264.],
|
||||
[1881., 2793., 2946., 1923.],
|
||||
[2313., 3405., 3558., 2307.],
|
||||
[1424., 2072., 2156., 1380.],
|
||||
],
|
||||
[
|
||||
[2709., 4173., 4509., 3065.],
|
||||
[4582., 7006., 7483., 5056.],
|
||||
[5878., 8914., 9391., 6304.],
|
||||
[4089., 6177., 6477., 4333.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_simple_implicit() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 1,
|
||||
channels_out: 16,
|
||||
kernel_size_1: 4,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[666., 916., 1030., 774.],
|
||||
[1124., 1500., 1620., 1190.],
|
||||
[1604., 2100., 2220., 1610.],
|
||||
[990., 1264., 1330., 936.],
|
||||
],
|
||||
[
|
||||
[1531., 2165., 2471., 1927.],
|
||||
[2757., 3805., 4181., 3207.],
|
||||
[4197., 5685., 6061., 4587.],
|
||||
[3295., 4433., 4691., 3529.],
|
||||
],
|
||||
[
|
||||
[2396., 3414., 3912., 3080.],
|
||||
[4390., 6110., 6742., 5224.],
|
||||
[6790., 9270., 9902., 7564.],
|
||||
[5600., 7602., 8052., 6122.],
|
||||
],
|
||||
[
|
||||
[3261., 4663., 5353., 4233.],
|
||||
[6023., 8415., 9303., 7241.],
|
||||
[9383., 12855., 13743., 10541.],
|
||||
[7905., 10771., 11413., 8715.],
|
||||
],
|
||||
[
|
||||
[4126., 5912., 6794., 5386.],
|
||||
[7656., 10720., 11864., 9258.],
|
||||
[11976., 16440., 17584., 13518.],
|
||||
[10210., 13940., 14774., 11308.],
|
||||
],
|
||||
[
|
||||
[4991., 7161., 8235., 6539.],
|
||||
[9289., 13025., 14425., 11275.],
|
||||
[14569., 20025., 21425., 16495.],
|
||||
[12515., 17109., 18135., 13901.],
|
||||
],
|
||||
[
|
||||
[5856., 8410., 9676., 7692.],
|
||||
[10922., 15330., 16986., 13292.],
|
||||
[17162., 23610., 25266., 19472.],
|
||||
[14820., 20278., 21496., 16494.],
|
||||
],
|
||||
[
|
||||
[6721., 9659., 11117., 8845.],
|
||||
[12555., 17635., 19547., 15309.],
|
||||
[19755., 27195., 29107., 22449.],
|
||||
[17125., 23447., 24857., 19087.],
|
||||
],
|
||||
[
|
||||
[7586., 10908., 12558., 9998.],
|
||||
[14188., 19940., 22108., 17326.],
|
||||
[22348., 30780., 32948., 25426.],
|
||||
[19430., 26616., 28218., 21680.],
|
||||
],
|
||||
[
|
||||
[8451., 12157., 13999., 11151.],
|
||||
[15821., 22245., 24669., 19343.],
|
||||
[24941., 34365., 36789., 28403.],
|
||||
[21735., 29785., 31579., 24273.],
|
||||
],
|
||||
[
|
||||
[9316., 13406., 15440., 12304.],
|
||||
[17454., 24550., 27230., 21360.],
|
||||
[27534., 37950., 40630., 31380.],
|
||||
[24040., 32954., 34940., 26866.],
|
||||
],
|
||||
[
|
||||
[10181., 14655., 16881., 13457.],
|
||||
[19087., 26855., 29791., 23377.],
|
||||
[30127., 41535., 44471., 34357.],
|
||||
[26345., 36123., 38301., 29459.],
|
||||
],
|
||||
[
|
||||
[11046., 15904., 18322., 14610.],
|
||||
[20720., 29160., 32352., 25394.],
|
||||
[32720., 45120., 48312., 37334.],
|
||||
[28650., 39292., 41662., 32052.],
|
||||
],
|
||||
[
|
||||
[11911., 17153., 19763., 15763.],
|
||||
[22353., 31465., 34913., 27411.],
|
||||
[35313., 48705., 52153., 40311.],
|
||||
[30955., 42461., 45023., 34645.],
|
||||
],
|
||||
[
|
||||
[12776., 18402., 21204., 16916.],
|
||||
[23986., 33770., 37474., 29428.],
|
||||
[37906., 52290., 55994., 43288.],
|
||||
[33260., 45630., 48384., 37238.],
|
||||
],
|
||||
[
|
||||
[13641., 19651., 22645., 18069.],
|
||||
[25619., 36075., 40035., 31445.],
|
||||
[40499., 55875., 59835., 46265.],
|
||||
[35565., 48799., 51745., 39831.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_implicit_padded_in_channels() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 16,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[4521., 6753., 7014., 4635.],
|
||||
[6858., 10197., 10548., 6939.],
|
||||
[7830., 11601., 11952., 7839.],
|
||||
[5007., 7383., 7590., 4953.],
|
||||
],
|
||||
[
|
||||
[10516., 15988., 16735., 11278.],
|
||||
[16822., 25507., 26587., 17875.],
|
||||
[19738., 29827., 30907., 20719.],
|
||||
[13594., 20506., 21199., 14188.],
|
||||
],
|
||||
[
|
||||
[16511., 25223., 26456., 17921.],
|
||||
[26786., 40817., 42626., 28811.],
|
||||
[31646., 48053., 49862., 33599.],
|
||||
[22181., 33629., 34808., 23423.],
|
||||
],
|
||||
[
|
||||
[22506., 34458., 36177., 24564.],
|
||||
[36750., 56127., 58665., 39747.],
|
||||
[43554., 66279., 68817., 46479.],
|
||||
[30768., 46752., 48417., 32658.],
|
||||
],
|
||||
[
|
||||
[28501., 43693., 45898., 31207.],
|
||||
[46714., 71437., 74704., 50683.],
|
||||
[55462., 84505., 87772., 59359.],
|
||||
[39355., 59875., 62026., 41893.],
|
||||
],
|
||||
[
|
||||
[34496., 52928., 55619., 37850.],
|
||||
[56678., 86747., 90743., 61619.],
|
||||
[67370., 102731., 106727., 72239.],
|
||||
[47942., 72998., 75635., 51128.],
|
||||
],
|
||||
[
|
||||
[40491., 62163., 65340., 44493.],
|
||||
[66642., 102057., 106782., 72555.],
|
||||
[79278., 120957., 125682., 85119.],
|
||||
[56529., 86121., 89244., 60363.],
|
||||
],
|
||||
[
|
||||
[46486., 71398., 75061., 51136.],
|
||||
[76606., 117367., 122821., 83491.],
|
||||
[91186., 139183., 144637., 97999.],
|
||||
[65116., 99244., 102853., 69598.],
|
||||
],
|
||||
[
|
||||
[52481., 80633., 84782., 57779.],
|
||||
[86570., 132677., 138860., 94427.],
|
||||
[103094., 157409., 163592., 110879.],
|
||||
[73703., 112367., 116462., 78833.],
|
||||
],
|
||||
[
|
||||
[58476., 89868., 94503., 64422.],
|
||||
[96534., 147987., 154899., 105363.],
|
||||
[115002., 175635., 182547., 123759.],
|
||||
[82290., 125490., 130071., 88068.],
|
||||
],
|
||||
[
|
||||
[64471., 99103., 104224., 71065.],
|
||||
[106498., 163297., 170938., 116299.],
|
||||
[126910., 193861., 201502., 136639.],
|
||||
[90877., 138613., 143680., 97303.],
|
||||
],
|
||||
[
|
||||
[70466., 108338., 113945., 77708.],
|
||||
[116462., 178607., 186977., 127235.],
|
||||
[138818., 212087., 220457., 149519.],
|
||||
[99464., 151736., 157289., 106538.],
|
||||
],
|
||||
[
|
||||
[76461., 117573., 123666., 84351.],
|
||||
[126426., 193917., 203016., 138171.],
|
||||
[150726., 230313., 239412., 162399.],
|
||||
[108051., 164859., 170898., 115773.],
|
||||
],
|
||||
[
|
||||
[82456., 126808., 133387., 90994.],
|
||||
[136390., 209227., 219055., 149107.],
|
||||
[162634., 248539., 258367., 175279.],
|
||||
[116638., 177982., 184507., 125008.],
|
||||
],
|
||||
[
|
||||
[88451., 136043., 143108., 97637.],
|
||||
[146354., 224537., 235094., 160043.],
|
||||
[174542., 266765., 277322., 188159.],
|
||||
[125225., 191105., 198116., 134243.],
|
||||
],
|
||||
[
|
||||
[94446., 145278., 152829., 104280.],
|
||||
[156318., 239847., 251133., 170979.],
|
||||
[186450., 284991., 296277., 201039.],
|
||||
[133812., 204228., 211725., 143478.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups_channels_out() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 16,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[73., 121., 154., 103.],
|
||||
[171., 258., 294., 186.],
|
||||
[279., 402., 438., 270.],
|
||||
[139., 187., 202., 113.],
|
||||
],
|
||||
[
|
||||
[164., 284., 371., 266.],
|
||||
[415., 664., 781., 538.],
|
||||
[739., 1132., 1249., 838.],
|
||||
[518., 782., 851., 564.],
|
||||
],
|
||||
[
|
||||
[255., 447., 588., 429.],
|
||||
[659., 1070., 1268., 890.],
|
||||
[1199., 1862., 2060., 1406.],
|
||||
[897., 1377., 1500., 1015.],
|
||||
],
|
||||
[
|
||||
[346., 610., 805., 592.],
|
||||
[903., 1476., 1755., 1242.],
|
||||
[1659., 2592., 2871., 1974.],
|
||||
[1276., 1972., 2149., 1466.],
|
||||
],
|
||||
[
|
||||
[437., 773., 1022., 755.],
|
||||
[1147., 1882., 2242., 1594.],
|
||||
[2119., 3322., 3682., 2542.],
|
||||
[1655., 2567., 2798., 1917.],
|
||||
],
|
||||
[
|
||||
[528., 936., 1239., 918.],
|
||||
[1391., 2288., 2729., 1946.],
|
||||
[2579., 4052., 4493., 3110.],
|
||||
[2034., 3162., 3447., 2368.],
|
||||
],
|
||||
[
|
||||
[619., 1099., 1456., 1081.],
|
||||
[1635., 2694., 3216., 2298.],
|
||||
[3039., 4782., 5304., 3678.],
|
||||
[2413., 3757., 4096., 2819.],
|
||||
],
|
||||
[
|
||||
[710., 1262., 1673., 1244.],
|
||||
[1879., 3100., 3703., 2650.],
|
||||
[3499., 5512., 6115., 4246.],
|
||||
[2792., 4352., 4745., 3270.],
|
||||
],
|
||||
[
|
||||
[5793., 8865., 9330., 6335.],
|
||||
[9467., 14450., 15134., 10250.],
|
||||
[11303., 17186., 17870., 12062.],
|
||||
[7971., 12099., 12546., 8457.],
|
||||
],
|
||||
[
|
||||
[6460., 9892., 10411., 7074.],
|
||||
[10575., 16152., 16917., 11466.],
|
||||
[12627., 19212., 19977., 13494.],
|
||||
[8926., 13558., 14059., 9484.],
|
||||
],
|
||||
[
|
||||
[7127., 10919., 11492., 7813.],
|
||||
[11683., 17854., 18700., 12682.],
|
||||
[13951., 21238., 22084., 14926.],
|
||||
[9881., 15017., 15572., 10511.],
|
||||
],
|
||||
[
|
||||
[7794., 11946., 12573., 8552.],
|
||||
[12791., 19556., 20483., 13898.],
|
||||
[15275., 23264., 24191., 16358.],
|
||||
[10836., 16476., 17085., 11538.],
|
||||
],
|
||||
[
|
||||
[8461., 12973., 13654., 9291.],
|
||||
[13899., 21258., 22266., 15114.],
|
||||
[16599., 25290., 26298., 17790.],
|
||||
[11791., 17935., 18598., 12565.],
|
||||
],
|
||||
[
|
||||
[9128., 14000., 14735., 10030.],
|
||||
[15007., 22960., 24049., 16330.],
|
||||
[17923., 27316., 28405., 19222.],
|
||||
[12746., 19394., 20111., 13592.],
|
||||
],
|
||||
[
|
||||
[9795., 15027., 15816., 10769.],
|
||||
[16115., 24662., 25832., 17546.],
|
||||
[19247., 29342., 30512., 20654.],
|
||||
[13701., 20853., 21624., 14619.],
|
||||
],
|
||||
[
|
||||
[10462., 16054., 16897., 11508.],
|
||||
[17223., 26364., 27615., 18762.],
|
||||
[20571., 31368., 32619., 22086.],
|
||||
[14656., 22312., 23137., 15646.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]],
|
||||
[
|
||||
[3724., 3841., 3958.],
|
||||
[4309., 4426., 4543.],
|
||||
[4894., 5011., 5128.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_groups_multiple_channels() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 4,
|
||||
channels_out: 4,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[4035., 4188., 4341.],
|
||||
[4800., 4953., 5106.],
|
||||
[5565., 5718., 5871.],
|
||||
],
|
||||
[
|
||||
[10030., 10507., 10984.],
|
||||
[12415., 12892., 13369.],
|
||||
[14800., 15277., 15754.],
|
||||
],
|
||||
[
|
||||
[56075., 56876., 57677.],
|
||||
[60080., 60881., 61682.],
|
||||
[64085., 64886., 65687.],
|
||||
],
|
||||
[
|
||||
[78270., 79395., 80520.],
|
||||
[83895., 85020., 86145.],
|
||||
[89520., 90645., 91770.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_complex() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 4,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 2,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 2,
|
||||
stride_2: 3,
|
||||
dilation_1: 1,
|
||||
dilation_2: 2,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[
|
||||
[[1845., 3789., 1926.], [3210., 6465., 3228.]],
|
||||
[[4276., 9082., 4789.], [8071., 16834., 8737.]],
|
||||
[[6707., 14375., 7652.], [12932., 27203., 14246.]],
|
||||
[[9138., 19668., 10515.], [17793., 37572., 19755.]],
|
||||
],
|
||||
[
|
||||
[[5445., 10629., 5166.], [8070., 15645., 7548.]],
|
||||
[[14356., 28882., 14509.], [22651., 45454., 22777.]],
|
||||
[[23267., 47135., 23852.], [37232., 75263., 38006.]],
|
||||
[[32178., 65388., 33195.], [51813., 105072., 53235.]],
|
||||
],
|
||||
]));
|
||||
}
|
||||
|
||||
struct Conv2dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
groups: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl Conv2dTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weight)
|
||||
.into_data(),
|
||||
);
|
||||
let bias = TestTensor::from(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = conv2d(
|
||||
x,
|
||||
weight,
|
||||
Some(bias),
|
||||
ConvOptions::new(
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
|
||||
#[rustfmt::skip] // param values are too long
|
||||
fn conv2d_weight() -> TensorData {
|
||||
TensorData::new(
|
||||
vec![0.048065186, -0.3059082, -0.10345459, -0.34643555, -0.20788574, -0.021072388, 0.13745117, -0.05102539, 0.024536133, -0.16479492, -0.19519043, 0.27270508, 0.17700195, -0.33764648, -0.08239746, -0.27929688, 0.17321777, -0.1315918, 0.04574585, -0.17980957, -0.33569336, 0.27612305, 0.30004883, -0.28979492, -0.17297363, -0.021759033, -0.27148438, 0.005657196, 0.29956055, -0.06958008, -0.29345703, -0.14440918, 0.10827637, -0.13305664, -0.20239258, 0.24890137, -0.1541748, -0.20019531, -0.2854004, 0.17016602, 0.07861328, -0.09075928, 0.30908203, -0.00013422966, 0.29589844, 0.15258789, -0.25708008, 0.20422363, -0.2529297, 0.07891846, -0.19506836, 0.23571777, 0.27124023, 0.17370605, -0.16992188, -0.23522949, 0.14648438, -0.09576416, -0.18310547, 0.21044922, -0.08911133, -0.2541504, -0.2775879, -0.2064209, -0.16271973, -0.048919678, -0.03555298, -0.11639404, 0.09661865, -0.10241699, 0.08929443, 0.2866211],
|
||||
[8, 1, 3, 3],
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_binary_broadcasted() {
|
||||
let device = Default::default();
|
||||
let x = TestTensor::<4>::full([1, 1, 28, 28], -0.42421296, &device);
|
||||
|
||||
// conv2d -> batchnorm -> activation
|
||||
let weight = TestTensor::from_data(conv2d_weight(), &device);
|
||||
let bias = TestTensor::from([
|
||||
0.082336426,
|
||||
-0.049591064,
|
||||
0.0031795502,
|
||||
0.00095653534,
|
||||
0.02357483,
|
||||
0.005569458,
|
||||
0.07525635,
|
||||
0.056396484,
|
||||
]);
|
||||
|
||||
// channels: [1, 8], kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], groups: 1, padding: [0, 0]
|
||||
let opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);
|
||||
let x = conv2d(x, weight, Some(bias), opt);
|
||||
|
||||
// simulate batchnorm binary ops with broadcasted params
|
||||
let gamma = TestTensor::<1>::from([
|
||||
1.0048828, 0.9902344, 1.0185547, 0.97558594, 1.0097656, 0.97802734, 1.0009766, 1.0146484,
|
||||
]);
|
||||
let beta = TestTensor::<1>::from([
|
||||
0.026290894,
|
||||
0.0007505417,
|
||||
0.006134033,
|
||||
0.02418518,
|
||||
0.07373047,
|
||||
0.020507813,
|
||||
0.01902771,
|
||||
0.02003479,
|
||||
]);
|
||||
let mean = TestTensor::<1>::from([
|
||||
0.029159546,
|
||||
-0.08673096,
|
||||
-0.03894043,
|
||||
-0.01108551,
|
||||
0.032440186,
|
||||
0.03237915,
|
||||
0.013839722,
|
||||
0.04397583,
|
||||
])
|
||||
.reshape([1, 8, 1, 1]);
|
||||
let var = TestTensor::<1>::from([
|
||||
0.67089844, 0.29956055, 0.5209961, 0.1862793, 0.30419922, 0.21313477, 0.7504883, 0.26342773,
|
||||
])
|
||||
.reshape([1, 8, 1, 1]);
|
||||
|
||||
let std = var.add_scalar(1e-5).sqrt();
|
||||
let x = x.sub(mean);
|
||||
let x = x.div(std);
|
||||
let x = x.mul(gamma.reshape([1, 8, 1, 1]));
|
||||
let x = x.add(beta.reshape([1, 8, 1, 1]));
|
||||
|
||||
let x = gelu(x);
|
||||
|
||||
let expected: Vec<f32> = [
|
||||
0.36432067f32,
|
||||
0.34909567,
|
||||
0.30684796,
|
||||
0.13217466,
|
||||
-0.018471397,
|
||||
-0.1389876,
|
||||
0.39402074,
|
||||
0.12394252,
|
||||
]
|
||||
.iter()
|
||||
.flat_map(|&v| core::iter::repeat_n(v, 676))
|
||||
.collect();
|
||||
let expected = TensorData::new(expected, [1, 8, 26, 26]);
|
||||
|
||||
x.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected,
|
||||
Tolerance::default().set_half_precision_absolute(1e-3),
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::conv3d;
|
||||
use burn_tensor::ops::ConvOptions;
|
||||
|
||||
#[test]
|
||||
fn test_conv3d_simple() {
|
||||
let test = Conv3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 1,
|
||||
depth: 4,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[
|
||||
[29980.0, 44860.0, 45640.0, 30324.0],
|
||||
[45072.0, 67380.0, 68496.0, 45468.0],
|
||||
[48096.0, 71844.0, 72960.0, 48396.0],
|
||||
[31780.0, 47428.0, 48136.0, 31900.0],
|
||||
],
|
||||
[
|
||||
[47292.0, 70548.0, 71556.0, 47400.0],
|
||||
[70335.0, 104823.0, 106254.0, 70317.0],
|
||||
[74223.0, 110547.0, 111978.0, 74061.0],
|
||||
[48552.0, 72240.0, 73140.0, 48324.0],
|
||||
],
|
||||
[
|
||||
[58236.0, 86676.0, 87684.0, 57960.0],
|
||||
[85887.0, 127719.0, 129150.0, 85293.0],
|
||||
[89775.0, 133443.0, 134874.0, 89037.0],
|
||||
[58344.0, 86640.0, 87540.0, 57732.0],
|
||||
],
|
||||
[
|
||||
[36148.0, 53620.0, 54184.0, 35692.0],
|
||||
[52740.0, 78144.0, 78936.0, 51936.0],
|
||||
[54900.0, 81312.0, 82104.0, 54000.0],
|
||||
[35260.0, 52156.0, 52648.0, 34580.0],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[66701.0, 100589.0, 102665.0, 68773.0],
|
||||
[102745.0, 154861.0, 157921.0, 105733.0],
|
||||
[110953.0, 167101.0, 170161.0, 113845.0],
|
||||
[75413.0, 113525.0, 115529.0, 77261.0],
|
||||
],
|
||||
[
|
||||
[112741.0, 169693.0, 172645.0, 115441.0],
|
||||
[172396.0, 259372.0, 263719.0, 176266.0],
|
||||
[184060.0, 276760.0, 281107.0, 187786.0],
|
||||
[124369.0, 186937.0, 189781.0, 126733.0],
|
||||
],
|
||||
[
|
||||
[144421.0, 216925.0, 219877.0, 146737.0],
|
||||
[219052.0, 328924.0, 333271.0, 222346.0],
|
||||
[230716.0, 346312.0, 350659.0, 233866.0],
|
||||
[154897.0, 232441.0, 235285.0, 156877.0],
|
||||
],
|
||||
[
|
||||
[100517.0, 150821.0, 152681.0, 101789.0],
|
||||
[151885.0, 227833.0, 230569.0, 153673.0],
|
||||
[159229.0, 238777.0, 241513.0, 160921.0],
|
||||
[106541.0, 159725.0, 161513.0, 107589.0],
|
||||
],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv3d_groups() {
|
||||
let test = Conv3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_3: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 2,
|
||||
depth: 5,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[
|
||||
[15219., 15570., 15921.],
|
||||
[16974., 17325., 17676.],
|
||||
[18729., 19080., 19431.],
|
||||
],
|
||||
[
|
||||
[23994., 24345., 24696.],
|
||||
[25749., 26100., 26451.],
|
||||
[27504., 27855., 28206.],
|
||||
],
|
||||
[
|
||||
[32769., 33120., 33471.],
|
||||
[34524., 34875., 35226.],
|
||||
[36279., 36630., 36981.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[172819., 173899., 174979.],
|
||||
[178219., 179299., 180379.],
|
||||
[183619., 184699., 185779.],
|
||||
],
|
||||
[
|
||||
[199819., 200899., 201979.],
|
||||
[205219., 206299., 207379.],
|
||||
[210619., 211699., 212779.],
|
||||
],
|
||||
[
|
||||
[226819., 227899., 228979.],
|
||||
[232219., 233299., 234379.],
|
||||
[237619., 238699., 239779.],
|
||||
],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv3d_complex() {
|
||||
let test = Conv3dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 4,
|
||||
kernel_size_1: 4,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 2,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
padding_3: 3,
|
||||
stride_1: 2,
|
||||
stride_2: 3,
|
||||
stride_3: 4,
|
||||
dilation_1: 1,
|
||||
dilation_2: 2,
|
||||
dilation_3: 3,
|
||||
groups: 1,
|
||||
depth: 4,
|
||||
height: 5,
|
||||
width: 6,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[
|
||||
[
|
||||
[[149148., 299070., 149850.], [147636., 295758., 148050.]],
|
||||
[[150660., 301014., 150282.], [147420., 294246., 146754.]],
|
||||
],
|
||||
[
|
||||
[[351325., 709903., 358507.], [357589., 722143., 364483.]],
|
||||
[[391717., 789607., 397819.], [396253., 798391., 402067.]],
|
||||
],
|
||||
[
|
||||
[[553502., 1120736., 567164.], [567542., 1148528., 580916.]],
|
||||
[[632774., 1278200., 645356.], [645086., 1302536., 657380.]],
|
||||
],
|
||||
[
|
||||
[[755679., 1531569., 775821.], [777495., 1574913., 797349.]],
|
||||
[[873831., 1766793., 892893.], [893919., 1806681., 912693.]],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[[408348., 810990., 402570.], [393876., 781758., 387810.]],
|
||||
[[370980., 735174., 364122.], [354780., 702486., 347634.]],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1077085., 2154943., 1077787.],
|
||||
[1070389., 2141263., 1070803.],
|
||||
],
|
||||
[
|
||||
[1078597., 2156887., 1078219.],
|
||||
[1070173., 2139751., 1069507.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1745822., 3498896., 1753004.],
|
||||
[1746902., 3500768., 1753796.],
|
||||
],
|
||||
[
|
||||
[1786214., 3578600., 1792316.],
|
||||
[1785566., 3577016., 1791380.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[2414559., 4842849., 2428221.],
|
||||
[2423415., 4860273., 2436789.],
|
||||
],
|
||||
[
|
||||
[2493831., 5000313., 2506413.],
|
||||
[2500959., 5014281., 2513253.],
|
||||
],
|
||||
],
|
||||
],
|
||||
]));
|
||||
}
|
||||
|
||||
struct Conv3dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
kernel_size_3: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
padding_3: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
stride_3: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
dilation_3: usize,
|
||||
groups: usize,
|
||||
depth: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl Conv3dTestCase {
|
||||
fn assert_output(self, y: TestTensor<5>) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels_in,
|
||||
self.depth,
|
||||
self.height,
|
||||
self.width,
|
||||
]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
self.kernel_size_3,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_weight)
|
||||
.into_data(),
|
||||
);
|
||||
let bias = TestTensor::from(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = conv3d(
|
||||
x,
|
||||
weight,
|
||||
Some(bias),
|
||||
ConvOptions::new(
|
||||
[self.stride_1, self.stride_2, self.stride_3],
|
||||
[self.padding_1, self.padding_2, self.padding_3],
|
||||
[self.dilation_1, self.dilation_2, self.dilation_3],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
|
||||
let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(2e-3);
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::conv_transpose1d;
|
||||
use burn_tensor::ops::ConvTransposeOptions;
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_diff_channels() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[270., 453., 516., 387.],
|
||||
[352., 589., 679., 505.],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_stride() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
padding_out: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[28., 62., 36., 78., 44., 94., 52., 62.],
|
||||
[41., 93., 55., 121., 69., 149., 83., 93.],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_dilation() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[30., 64., 78., 76., 94., 52.],
|
||||
[49., 101., 127., 113., 143., 77.],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_groups() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 2,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats(
|
||||
[[[0., 1., 4., 7.], [32., 59., 71., 59.]]],
|
||||
&Default::default(),
|
||||
));
|
||||
}
|
||||
|
||||
struct ConvTranspose1dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
impl ConvTranspose1dTestCase {
|
||||
fn assert_output(self, y: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_weights = Shape::new([
|
||||
self.channels_in,
|
||||
self.channels_out / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weights = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_weights)
|
||||
.into_data(),
|
||||
&device,
|
||||
);
|
||||
let bias = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
&device,
|
||||
);
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<3, _>(shape_x)
|
||||
.into_data(),
|
||||
&device,
|
||||
);
|
||||
let output = conv_transpose1d(
|
||||
x,
|
||||
weights,
|
||||
Some(bias),
|
||||
ConvTransposeOptions::new(
|
||||
[self.stride],
|
||||
[self.padding],
|
||||
[self.padding_out],
|
||||
[self.dilation],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::conv_transpose2d;
|
||||
use burn_tensor::ops::ConvTransposeOptions;
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_simple_1() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 1,
|
||||
channels_out: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_simple_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[9855., 15207., 15738., 10797.],
|
||||
[16290., 25119., 25956., 17793.],
|
||||
[18486., 28467., 29304., 20061.],
|
||||
[13593., 20913., 21498., 14703.],
|
||||
],
|
||||
[
|
||||
[11854., 18286., 18979., 13012.],
|
||||
[19612., 30223., 31303., 21439.],
|
||||
[22456., 34543., 35623., 24355.],
|
||||
[16456., 25288., 26035., 17782.],
|
||||
],
|
||||
[
|
||||
[13853., 21365., 22220., 15227.],
|
||||
[22934., 35327., 36650., 25085.],
|
||||
[26426., 40619., 41942., 28649.],
|
||||
[19319., 29663., 30572., 20861.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_simple_3() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 1,
|
||||
channels_out: 1,
|
||||
kernel_size_1: 2,
|
||||
kernel_size_2: 2,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 4.0, 6.0],
|
||||
[4.0, 12.0, 9.0],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_stride_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 1,
|
||||
channels_out: 1,
|
||||
kernel_size_1: 2,
|
||||
kernel_size_2: 2,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[0.0, 0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 2.0, 3.0],
|
||||
[0.0, 2.0, 0.0, 3.0],
|
||||
[4.0, 6.0, 6.0, 9.0],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_dilation_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_out_1: 1,
|
||||
padding_out_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 2,
|
||||
dilation_2: 2,
|
||||
groups: 1,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[126., 116., 136., 124., 146.],
|
||||
[108., 88., 114., 92., 120.],
|
||||
[156., 140., 166., 148., 176.],
|
||||
[126., 100., 132., 104., 138.],
|
||||
[186., 164., 196., 172., 206.],
|
||||
],
|
||||
[
|
||||
[217., 189., 227., 197., 237.],
|
||||
[163., 125., 169., 129., 175.],
|
||||
[247., 213., 257., 221., 267.],
|
||||
[181., 137., 187., 141., 193.],
|
||||
[277., 237., 287., 245., 297.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_stride2_out_padding() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_out_1: 1,
|
||||
padding_out_2: 1,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[352., 728., 378., 780., 404., 832., 430., 452.],
|
||||
[784., 1616., 836., 1720., 888., 1824., 940., 992.],
|
||||
[456., 936., 482., 988., 508., 1040., 534., 564.],
|
||||
[992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.],
|
||||
[560., 1144., 586., 1196., 612., 1248., 638., 676.],
|
||||
[1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.],
|
||||
[664., 1352., 690., 1404., 716., 1456., 742., 788.],
|
||||
[784., 1598., 816., 1662., 848., 1726., 880., 926.],
|
||||
],
|
||||
[
|
||||
[497., 1035., 541., 1123., 585., 1211., 629., 651.],
|
||||
[1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.],
|
||||
[673., 1387., 717., 1475., 761., 1563., 805., 835.],
|
||||
[1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.],
|
||||
[849., 1739., 893., 1827., 937., 1915., 981., 1019.],
|
||||
[1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.],
|
||||
[1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.],
|
||||
[1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_groups_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[[5., 11.], [23., 29.]],
|
||||
[[236., 258.], [302., 324.]],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_groups_different_channels() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
groups: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00],
|
||||
[0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01],
|
||||
[6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01],
|
||||
[1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01],
|
||||
],
|
||||
[
|
||||
[1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01],
|
||||
[1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01],
|
||||
[2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01],
|
||||
[3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01],
|
||||
],
|
||||
[
|
||||
[2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01],
|
||||
[3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01],
|
||||
[4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01],
|
||||
[5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01],
|
||||
],
|
||||
[
|
||||
[1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02],
|
||||
[2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02],
|
||||
[3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02],
|
||||
[2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02],
|
||||
],
|
||||
[
|
||||
[1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02],
|
||||
[3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02],
|
||||
[4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02],
|
||||
[2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02],
|
||||
],
|
||||
[
|
||||
[1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02],
|
||||
[4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02],
|
||||
[4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02],
|
||||
[3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct ConvTranspose2dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
padding_out_1: usize,
|
||||
padding_out_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
groups: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl ConvTranspose2dTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let shape_weights = Shape::new([
|
||||
self.channels_in,
|
||||
self.channels_out / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weights = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weights)
|
||||
.into_data(),
|
||||
);
|
||||
let bias = TestTensor::from(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = conv_transpose2d(
|
||||
x,
|
||||
weights,
|
||||
Some(bias),
|
||||
ConvTransposeOptions::new(
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.padding_out_1, self.padding_out_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
|
||||
y.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::rel_abs(1e-1, 0.01));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,749 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::conv_transpose3d;
|
||||
use burn_tensor::ops::ConvTransposeOptions;
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_simple_1() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 1,
|
||||
channels_out: 1,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
padding_out_3: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 1,
|
||||
depth: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[[96., 124.], [180., 208.]],
|
||||
[[348., 376.], [432., 460.]],
|
||||
]]]));
|
||||
}
|
||||
#[test]
|
||||
fn test_conv_transpose3d_simple_2() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
padding_out_3: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 1,
|
||||
depth: 4,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[
|
||||
[238452., 360588., 363756., 244488.],
|
||||
[367929., 556353., 561186., 377163.],
|
||||
[380745., 575685., 580518., 390123.],
|
||||
[261192., 394896., 398172., 267564.],
|
||||
],
|
||||
[
|
||||
[394083., 595827., 600822., 403749.],
|
||||
[607635., 918648., 926262., 622404.],
|
||||
[627831., 949104., 956718., 642816.],
|
||||
[430353., 650529., 655686., 440523.],
|
||||
],
|
||||
[
|
||||
[447075., 675747., 680742., 457317.],
|
||||
[688419., 1040472., 1048086., 704052.],
|
||||
[708615., 1070928., 1078542., 724464.],
|
||||
[485073., 733041., 738198., 495819.],
|
||||
],
|
||||
[
|
||||
[328656., 496632., 500124., 335892.],
|
||||
[505611., 763983., 769302., 516645.],
|
||||
[519723., 785259., 790578., 530901.],
|
||||
[355428., 536988., 540588., 363000.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[286729., 433489., 437629., 294061.],
|
||||
[442288., 668620., 674911., 453466.],
|
||||
[458992., 693784., 700075., 470314.],
|
||||
[314653., 475573., 479821., 322321.],
|
||||
],
|
||||
[
|
||||
[474274., 716842., 723295., 485884.],
|
||||
[730837., 1104544., 1114345., 748522.],
|
||||
[756865., 1143748., 1153549., 774766.],
|
||||
[518320., 783208., 789823., 530434.],
|
||||
],
|
||||
[
|
||||
[542818., 820090., 826543., 555004.],
|
||||
[834949., 1261360., 1271161., 853498.],
|
||||
[860977., 1300564., 1310365., 879742.],
|
||||
[588592., 889048., 895663., 601282.],
|
||||
],
|
||||
[
|
||||
[397669., 600637., 605101., 406201.],
|
||||
[611074., 922906., 929683., 624052.],
|
||||
[629074., 950014., 956791., 642196.],
|
||||
[429625., 648769., 653341., 438493.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[335006., 506390., 511502., 343634.],
|
||||
[516647., 780887., 788636., 529769.],
|
||||
[537239., 811883., 819632., 550505.],
|
||||
[368114., 556250., 561470., 377078.],
|
||||
],
|
||||
[
|
||||
[554465., 837857., 845768., 568019.],
|
||||
[854039., 1290440., 1302428., 874640.],
|
||||
[885899., 1338392., 1350380., 906716.],
|
||||
[606287., 915887., 923960., 620345.],
|
||||
],
|
||||
[
|
||||
[638561., 964433., 972344., 652691.],
|
||||
[981479., 1482248., 1494236., 1002944.],
|
||||
[1013339., 1530200., 1542188., 1035020.],
|
||||
[692111., 1045055., 1053128., 706745.],
|
||||
],
|
||||
[
|
||||
[466682., 704642., 710078., 476510.],
|
||||
[716537., 1081829., 1090064., 731459.],
|
||||
[738425., 1114769., 1123004., 753491.],
|
||||
[503822., 760550., 766094., 513986.],
|
||||
],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_stride_2() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 1,
|
||||
channels_out: 1,
|
||||
kernel_size_1: 2,
|
||||
kernel_size_2: 2,
|
||||
kernel_size_3: 2,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_3: 0,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
padding_out_3: 0,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
stride_3: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 1,
|
||||
depth: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[
|
||||
[0., 0., 0., 1.],
|
||||
[0., 0., 2., 3.],
|
||||
[0., 2., 0., 3.],
|
||||
[4., 6., 6., 9.],
|
||||
],
|
||||
[
|
||||
[0., 0., 4., 5.],
|
||||
[0., 0., 6., 7.],
|
||||
[8., 10., 12., 15.],
|
||||
[12., 14., 18., 21.],
|
||||
],
|
||||
[
|
||||
[0., 4., 0., 5.],
|
||||
[8., 12., 10., 15.],
|
||||
[0., 6., 0., 7.],
|
||||
[12., 18., 14., 21.],
|
||||
],
|
||||
[
|
||||
[16., 20., 20., 25.],
|
||||
[24., 28., 30., 35.],
|
||||
[24., 30., 28., 35.],
|
||||
[36., 42., 42., 49.],
|
||||
],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_dilation_2() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
padding_out_1: 1,
|
||||
padding_out_2: 1,
|
||||
padding_out_3: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 2,
|
||||
dilation_2: 2,
|
||||
dilation_3: 2,
|
||||
groups: 1,
|
||||
depth: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[
|
||||
[810., 776., 832., 796., 854.],
|
||||
[756., 712., 774., 728., 792.],
|
||||
[876., 836., 898., 856., 920.],
|
||||
[810., 760., 828., 776., 846.],
|
||||
[942., 896., 964., 916., 986.],
|
||||
],
|
||||
[
|
||||
[720., 660., 734., 672., 748.],
|
||||
[606., 536., 616., 544., 626.],
|
||||
[762., 696., 776., 708., 790.],
|
||||
[636., 560., 646., 568., 656.],
|
||||
[804., 732., 818., 744., 832.],
|
||||
],
|
||||
[
|
||||
[1008., 956., 1030., 976., 1052.],
|
||||
[918., 856., 936., 872., 954.],
|
||||
[1074., 1016., 1096., 1036., 1118.],
|
||||
[972., 904., 990., 920., 1008.],
|
||||
[1140., 1076., 1162., 1096., 1184.],
|
||||
],
|
||||
[
|
||||
[846., 768., 860., 780., 874.],
|
||||
[696., 608., 706., 616., 716.],
|
||||
[888., 804., 902., 816., 916.],
|
||||
[726., 632., 736., 640., 746.],
|
||||
[930., 840., 944., 852., 958.],
|
||||
],
|
||||
[
|
||||
[1206., 1136., 1228., 1156., 1250.],
|
||||
[1080., 1000., 1098., 1016., 1116.],
|
||||
[1272., 1196., 1294., 1216., 1316.],
|
||||
[1134., 1048., 1152., 1064., 1170.],
|
||||
[1338., 1256., 1360., 1276., 1382.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1405., 1317., 1427., 1337., 1449.],
|
||||
[1243., 1145., 1261., 1161., 1279.],
|
||||
[1471., 1377., 1493., 1397., 1515.],
|
||||
[1297., 1193., 1315., 1209., 1333.],
|
||||
[1537., 1437., 1559., 1457., 1581.],
|
||||
],
|
||||
[
|
||||
[1099., 985., 1113., 997., 1127.],
|
||||
[877., 753., 887., 761., 897.],
|
||||
[1141., 1021., 1155., 1033., 1169.],
|
||||
[907., 777., 917., 785., 927.],
|
||||
[1183., 1057., 1197., 1069., 1211.],
|
||||
],
|
||||
[
|
||||
[1603., 1497., 1625., 1517., 1647.],
|
||||
[1405., 1289., 1423., 1305., 1441.],
|
||||
[1669., 1557., 1691., 1577., 1713.],
|
||||
[1459., 1337., 1477., 1353., 1495.],
|
||||
[1735., 1617., 1757., 1637., 1779.],
|
||||
],
|
||||
[
|
||||
[1225., 1093., 1239., 1105., 1253.],
|
||||
[967., 825., 977., 833., 987.],
|
||||
[1267., 1129., 1281., 1141., 1295.],
|
||||
[997., 849., 1007., 857., 1017.],
|
||||
[1309., 1165., 1323., 1177., 1337.],
|
||||
],
|
||||
[
|
||||
[1801., 1677., 1823., 1697., 1845.],
|
||||
[1567., 1433., 1585., 1449., 1603.],
|
||||
[1867., 1737., 1889., 1757., 1911.],
|
||||
[1621., 1481., 1639., 1497., 1657.],
|
||||
[1933., 1797., 1955., 1817., 1977.],
|
||||
],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_stride2_out_padding() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
padding_out_1: 1,
|
||||
padding_out_2: 1,
|
||||
padding_out_3: 1,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
stride_3: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 1,
|
||||
depth: 2,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[
|
||||
[2144., 4366., 2224., 4526., 2304., 4686., 2384., 2422.],
|
||||
[4584., 9324., 4744., 9644., 4904., 9964., 5064., 5148.],
|
||||
[2464., 5006., 2544., 5166., 2624., 5326., 2704., 2750.],
|
||||
[5224., 10604., 5384., 10924., 5544., 11244., 5704., 5804.],
|
||||
[2784., 5646., 2864., 5806., 2944., 5966., 3024., 3078.],
|
||||
[5864., 11884., 6024., 12204., 6184., 12524., 6344., 6460.],
|
||||
[3104., 6286., 3184., 6446., 3264., 6606., 3344., 3406.],
|
||||
[3272., 6628., 3358., 6800., 3444., 6972., 3530., 3592.],
|
||||
],
|
||||
[
|
||||
[5280., 10716., 5440., 11036., 5600., 11356., 5760., 5868.],
|
||||
[
|
||||
11152., 22616., 11472., 23256., 11792., 23896., 12112., 12344.,
|
||||
],
|
||||
[5920., 11996., 6080., 12316., 6240., 12636., 6400., 6524.],
|
||||
[
|
||||
12432., 25176., 12752., 25816., 13072., 26456., 13392., 13656.,
|
||||
],
|
||||
[6560., 13276., 6720., 13596., 6880., 13916., 7040., 7180.],
|
||||
[
|
||||
13712., 27736., 14032., 28376., 14352., 29016., 14672., 14968.,
|
||||
],
|
||||
[7200., 14556., 7360., 14876., 7520., 15196., 7680., 7836.],
|
||||
[7632., 15432., 7804., 15776., 7976., 16120., 8148., 8304.],
|
||||
],
|
||||
[
|
||||
[3424., 6926., 3504., 7086., 3584., 7246., 3664., 3734.],
|
||||
[7144., 14444., 7304., 14764., 7464., 15084., 7624., 7772.],
|
||||
[3744., 7566., 3824., 7726., 3904., 7886., 3984., 4062.],
|
||||
[7784., 15724., 7944., 16044., 8104., 16364., 8264., 8428.],
|
||||
[4064., 8206., 4144., 8366., 4224., 8526., 4304., 4390.],
|
||||
[8424., 17004., 8584., 17324., 8744., 17644., 8904., 9084.],
|
||||
[4384., 8846., 4464., 9006., 4544., 9166., 4624., 4718.],
|
||||
[4648., 9380., 4734., 9552., 4820., 9724., 4906., 5000.],
|
||||
],
|
||||
[
|
||||
[4000., 8096., 4098., 8292., 4196., 8488., 4294., 4364.],
|
||||
[8368., 16928., 8564., 17320., 8760., 17712., 8956., 9104.],
|
||||
[4392., 8880., 4490., 9076., 4588., 9272., 4686., 4764.],
|
||||
[9152., 18496., 9348., 18888., 9544., 19280., 9740., 9904.],
|
||||
[4784., 9664., 4882., 9860., 4980., 10056., 5078., 5164.],
|
||||
[
|
||||
9936., 20064., 10132., 20456., 10328., 20848., 10524., 10704.,
|
||||
],
|
||||
[5176., 10448., 5274., 10644., 5372., 10840., 5470., 5564.],
|
||||
[5440., 10982., 5544., 11190., 5648., 11398., 5752., 5846.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[3009., 6149., 3143., 6417., 3277., 6685., 3411., 3449.],
|
||||
[6529., 13321., 6797., 13857., 7065., 14393., 7333., 7417.],
|
||||
[3545., 7221., 3679., 7489., 3813., 7757., 3947., 3993.],
|
||||
[7601., 15465., 7869., 16001., 8137., 16537., 8405., 8505.],
|
||||
[4081., 8293., 4215., 8561., 4349., 8829., 4483., 4537.],
|
||||
[8673., 17609., 8941., 18145., 9209., 18681., 9477., 9593.],
|
||||
[4617., 9365., 4751., 9633., 4885., 9901., 5019., 5081.],
|
||||
[4785., 9707., 4925., 9987., 5065., 10267., 5205., 5267.],
|
||||
],
|
||||
[
|
||||
[7873., 16009., 8141., 16545., 8409., 17081., 8677., 8785.],
|
||||
[
|
||||
16769., 34065., 17305., 35137., 17841., 36209., 18377., 18609.,
|
||||
],
|
||||
[8945., 18153., 9213., 18689., 9481., 19225., 9749., 9873.],
|
||||
[
|
||||
18913., 38353., 19449., 39425., 19985., 40497., 20521., 20785.,
|
||||
],
|
||||
[
|
||||
10017., 20297., 10285., 20833., 10553., 21369., 10821., 10961.,
|
||||
],
|
||||
[
|
||||
21057., 42641., 21593., 43713., 22129., 44785., 22665., 22961.,
|
||||
],
|
||||
[
|
||||
11089., 22441., 11357., 22977., 11625., 23513., 11893., 12049.,
|
||||
],
|
||||
[
|
||||
11521., 23317., 11801., 23877., 12081., 24437., 12361., 12517.,
|
||||
],
|
||||
],
|
||||
[
|
||||
[5153., 10437., 5287., 10705., 5421., 10973., 5555., 5625.],
|
||||
[
|
||||
10817., 21897., 11085., 22433., 11353., 22969., 11621., 11769.,
|
||||
],
|
||||
[5689., 11509., 5823., 11777., 5957., 12045., 6091., 6169.],
|
||||
[
|
||||
11889., 24041., 12157., 24577., 12425., 25113., 12693., 12857.,
|
||||
],
|
||||
[6225., 12581., 6359., 12849., 6493., 13117., 6627., 6713.],
|
||||
[
|
||||
12961., 26185., 13229., 26721., 13497., 27257., 13765., 13945.,
|
||||
],
|
||||
[6761., 13653., 6895., 13921., 7029., 14189., 7163., 7257.],
|
||||
[7025., 14187., 7165., 14467., 7305., 14747., 7445., 7539.],
|
||||
],
|
||||
[
|
||||
[5729., 11607., 5881., 11911., 6033., 12215., 6185., 6255.],
|
||||
[
|
||||
12041., 24381., 12345., 24989., 12649., 25597., 12953., 13101.,
|
||||
],
|
||||
[6337., 12823., 6489., 13127., 6641., 13431., 6793., 6871.],
|
||||
[
|
||||
13257., 26813., 13561., 27421., 13865., 28029., 14169., 14333.,
|
||||
],
|
||||
[6945., 14039., 7097., 14343., 7249., 14647., 7401., 7487.],
|
||||
[
|
||||
14473., 29245., 14777., 29853., 15081., 30461., 15385., 15565.,
|
||||
],
|
||||
[7553., 15255., 7705., 15559., 7857., 15863., 8009., 8103.],
|
||||
[7817., 15789., 7975., 16105., 8133., 16421., 8291., 8385.],
|
||||
],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_groups_2() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_3: 1,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
padding_out_3: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 2,
|
||||
depth: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[[[96., 124.], [180., 208.]], [[348., 376.], [432., 460.]]],
|
||||
[
|
||||
[[2997., 3089.], [3273., 3365.]],
|
||||
[[3825., 3917.], [4101., 4193.]],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose3d_groups_different_channels() {
|
||||
let test = ConvTranspose3dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
kernel_size_3: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
padding_3: 0,
|
||||
padding_out_1: 0,
|
||||
padding_out_2: 0,
|
||||
padding_out_3: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
stride_3: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
dilation_3: 1,
|
||||
groups: 2,
|
||||
depth: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[
|
||||
[
|
||||
[0., 0., 1., 2.],
|
||||
[0., 5., 11., 11.],
|
||||
[6., 23., 29., 23.],
|
||||
[12., 32., 37., 24.],
|
||||
],
|
||||
[
|
||||
[0., 13., 23., 21.],
|
||||
[30., 96., 124., 86.],
|
||||
[66., 180., 208., 134.],
|
||||
[66., 161., 179., 107.],
|
||||
],
|
||||
[
|
||||
[36., 103., 113., 75.],
|
||||
[138., 348., 376., 230.],
|
||||
[174., 432., 460., 278.],
|
||||
[138., 323., 341., 197.],
|
||||
],
|
||||
[
|
||||
[72., 166., 175., 100.],
|
||||
[192., 433., 455., 255.],
|
||||
[222., 499., 521., 291.],
|
||||
[144., 318., 331., 182.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1., 28., 29., 30.],
|
||||
[55., 168., 174., 120.],
|
||||
[61., 186., 192., 132.],
|
||||
[67., 168., 173., 106.],
|
||||
],
|
||||
[
|
||||
[109., 284., 294., 184.],
|
||||
[355., 853., 881., 519.],
|
||||
[391., 937., 965., 567.],
|
||||
[283., 648., 666., 378.],
|
||||
],
|
||||
[
|
||||
[145., 374., 384., 238.],
|
||||
[463., 1105., 1133., 663.],
|
||||
[499., 1189., 1217., 711.],
|
||||
[355., 810., 828., 468.],
|
||||
],
|
||||
[
|
||||
[181., 410., 419., 236.],
|
||||
[463., 1028., 1050., 580.],
|
||||
[493., 1094., 1116., 616.],
|
||||
[307., 670., 683., 372.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[2., 56., 57., 58.],
|
||||
[110., 331., 337., 229.],
|
||||
[116., 349., 355., 241.],
|
||||
[122., 304., 309., 188.],
|
||||
],
|
||||
[
|
||||
[218., 555., 565., 347.],
|
||||
[680., 1610., 1638., 952.],
|
||||
[716., 1694., 1722., 1000.],
|
||||
[500., 1135., 1153., 649.],
|
||||
],
|
||||
[
|
||||
[254., 645., 655., 401.],
|
||||
[788., 1862., 1890., 1096.],
|
||||
[824., 1946., 1974., 1144.],
|
||||
[572., 1297., 1315., 739.],
|
||||
],
|
||||
[
|
||||
[290., 654., 663., 372.],
|
||||
[734., 1623., 1645., 905.],
|
||||
[764., 1689., 1711., 941.],
|
||||
[470., 1022., 1035., 562.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[651., 1388., 1405., 750.],
|
||||
[1485., 3150., 3188., 1690.],
|
||||
[1539., 3264., 3302., 1750.],
|
||||
[873., 1840., 1861., 982.],
|
||||
],
|
||||
[
|
||||
[1695., 3578., 3620., 1910.],
|
||||
[3789., 7967., 8059., 4233.],
|
||||
[3921., 8243., 8335., 4377.],
|
||||
[2181., 4566., 4616., 2416.],
|
||||
],
|
||||
[
|
||||
[1875., 3956., 3998., 2108.],
|
||||
[4185., 8795., 8887., 4665.],
|
||||
[4317., 9071., 9163., 4809.],
|
||||
[2397., 5016., 5066., 2650.],
|
||||
],
|
||||
[
|
||||
[1191., 2490., 2515., 1316.],
|
||||
[2613., 5450., 5504., 2870.],
|
||||
[2691., 5612., 5666., 2954.],
|
||||
[1473., 3062., 3091., 1608.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[868., 1848., 1865., 994.],
|
||||
[1972., 4177., 4215., 2231.],
|
||||
[2026., 4291., 4329., 2291.],
|
||||
[1144., 2408., 2429., 1280.],
|
||||
],
|
||||
[
|
||||
[2236., 4713., 4755., 2505.],
|
||||
[4978., 10452., 10544., 5530.],
|
||||
[5110., 10728., 10820., 5674.],
|
||||
[2830., 5917., 5967., 3119.],
|
||||
],
|
||||
[
|
||||
[2416., 5091., 5133., 2703.],
|
||||
[5374., 11280., 11372., 5962.],
|
||||
[5506., 11556., 11648., 6106.],
|
||||
[3046., 6367., 6417., 3353.],
|
||||
],
|
||||
[
|
||||
[1516., 3166., 3191., 1668.],
|
||||
[3316., 6909., 6963., 3627.],
|
||||
[3394., 7071., 7125., 3711.],
|
||||
[1852., 3846., 3875., 2014.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1085., 2308., 2325., 1238.],
|
||||
[2459., 5204., 5242., 2772.],
|
||||
[2513., 5318., 5356., 2832.],
|
||||
[1415., 2976., 2997., 1578.],
|
||||
],
|
||||
[
|
||||
[2777., 5848., 5890., 3100.],
|
||||
[6167., 12937., 13029., 6827.],
|
||||
[6299., 13213., 13305., 6971.],
|
||||
[3479., 7268., 7318., 3822.],
|
||||
],
|
||||
[
|
||||
[2957., 6226., 6268., 3298.],
|
||||
[6563., 13765., 13857., 7259.],
|
||||
[6695., 14041., 14133., 7403.],
|
||||
[3695., 7718., 7768., 4056.],
|
||||
],
|
||||
[
|
||||
[1841., 3842., 3867., 2020.],
|
||||
[4019., 8368., 8422., 4384.],
|
||||
[4097., 8530., 8584., 4468.],
|
||||
[2231., 4630., 4659., 2420.],
|
||||
],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct ConvTranspose3dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
kernel_size_3: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
padding_3: usize,
|
||||
padding_out_1: usize,
|
||||
padding_out_2: usize,
|
||||
padding_out_3: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
stride_3: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
dilation_3: usize,
|
||||
groups: usize,
|
||||
depth: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl ConvTranspose3dTestCase {
|
||||
fn assert_output(self, y: TestTensor<5>) {
|
||||
let shape_x = Shape::new([
|
||||
self.batch_size,
|
||||
self.channels_in,
|
||||
self.depth,
|
||||
self.height,
|
||||
self.width,
|
||||
]);
|
||||
let shape_weights = Shape::new([
|
||||
self.channels_in,
|
||||
self.channels_out / self.groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
self.kernel_size_3,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weights = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_weights)
|
||||
.into_data(),
|
||||
);
|
||||
let bias = TestTensor::from(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<5, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
let output = conv_transpose3d(
|
||||
x,
|
||||
weights,
|
||||
Some(bias),
|
||||
ConvTransposeOptions::new(
|
||||
[self.stride_1, self.stride_2, self.stride_3],
|
||||
[self.padding_1, self.padding_2, self.padding_3],
|
||||
[self.padding_out_1, self.padding_out_2, self.padding_out_3],
|
||||
[self.dilation_1, self.dilation_2, self.dilation_3],
|
||||
self.groups,
|
||||
),
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,438 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::deform_conv2d;
|
||||
use burn_tensor::ops::DeformConvOptions;
|
||||
use burn_tensor::{Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_simple() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 5,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[[0.9074, 0.6387], [0.5160, 0.4196]],
|
||||
[[2.4259, 1.8008], [1.5449, 1.3112]],
|
||||
[[3.9444, 2.9629], [2.5738, 2.2027]],
|
||||
[[5.4629, 4.1250], [3.6027, 3.0943]],
|
||||
[[6.9814, 5.2871], [4.6316, 3.9859]],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_batched() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 5,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([
|
||||
[
|
||||
[[0.215466, 0.192846], [0.193407, 0.175496]],
|
||||
[[0.725073, 0.675926], [0.687746, 0.648506]],
|
||||
[[1.234679, 1.159006], [1.182085, 1.121516]],
|
||||
[[1.744286, 1.642086], [1.676423, 1.594526]],
|
||||
[[2.253892, 2.125167], [2.170762, 2.067536]],
|
||||
],
|
||||
[
|
||||
[[1.652976, 1.136937], [0.984030, 0.718403]],
|
||||
[[4.836801, 3.472453], [3.177263, 2.418021]],
|
||||
[[8.020626, 5.807969], [5.370497, 4.117639]],
|
||||
[[11.204453, 8.143486], [7.563731, 5.817256]],
|
||||
[[14.388277, 10.479003], [9.756965, 7.516875]],
|
||||
],
|
||||
]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_weight_groups() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 3,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[[0.101823, 0.065756], [0.046691, 0.036233]],
|
||||
[[0.412523, 0.336674], [0.306863, 0.282386]],
|
||||
[[1.307585, 1.024152], [0.902454, 0.800008]],
|
||||
[[1.840507, 1.458072], [1.299371, 1.158781]],
|
||||
[[3.402235, 2.634555], [2.305198, 2.014265]],
|
||||
[[4.157379, 3.231476], [2.838861, 2.485659]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_offset_groups() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 3,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[[1.0794, 0.7676], [0.7209, 0.5337]],
|
||||
[[2.7059, 2.0216], [1.9740, 1.5419]],
|
||||
[[4.3325, 3.2755], [3.2271, 2.5501]],
|
||||
[[5.9590, 4.5295], [4.4802, 3.5582]],
|
||||
[[7.5855, 5.7835], [5.7333, 4.5664]],
|
||||
[[9.2120, 7.0375], [6.9864, 5.5746]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_kernel_size() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[[1.0669], [0.6329]],
|
||||
[[2.9741], [2.0383]],
|
||||
[[4.8812], [3.4437]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_padding_size() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 2,
|
||||
padding_2: 3,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[
|
||||
[
|
||||
0.199779, 0.376176, 0.528501, 0.605256, 0.384365, 0.198675, 0.048145, 0.000000,
|
||||
],
|
||||
[
|
||||
0.287923, 0.551719, 0.777562, 0.890479, 0.580469, 0.304325, 0.079554, 0.000000,
|
||||
],
|
||||
[
|
||||
0.372947, 0.721405, 1.013668, 1.151988, 0.756444, 0.393098, 0.101582, 0.000000,
|
||||
],
|
||||
[
|
||||
0.132138, 0.324872, 0.495372, 0.584617, 0.453122, 0.250084, 0.075703, 0.000000,
|
||||
],
|
||||
[
|
||||
0.059332, 0.160658, 0.244789, 0.297057, 0.239464, 0.132701, 0.047114, 0.000000,
|
||||
],
|
||||
[
|
||||
0.014338, 0.051338, 0.078303, 0.094190, 0.081278, 0.041954, 0.014506, 0.000000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
0.766652, 1.164805, 1.521938, 1.711110, 1.230500, 0.807579, 0.450423, 0.333333,
|
||||
],
|
||||
[
|
||||
0.981162, 1.601005, 2.152534, 2.440920, 1.745547, 1.091843, 0.536749, 0.333333,
|
||||
],
|
||||
[
|
||||
1.196386, 2.044845, 2.785330, 3.152243, 2.242613, 1.351308, 0.604905, 0.333333,
|
||||
],
|
||||
[
|
||||
0.669465, 1.178133, 1.644096, 1.902188, 1.573183, 1.033924, 0.553577, 0.333333,
|
||||
],
|
||||
[
|
||||
0.495048, 0.786124, 1.039796, 1.204721, 1.052342, 0.743887, 0.483380, 0.333333,
|
||||
],
|
||||
[
|
||||
0.378767, 0.498209, 0.592867, 0.654230, 0.615487, 0.488202, 0.390890, 0.333333,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
1.333524, 1.953435, 2.515375, 2.816964, 2.076636, 1.416483, 0.852701, 0.666667,
|
||||
],
|
||||
[
|
||||
1.674402, 2.650291, 3.527507, 3.991360, 2.910625, 1.879361, 0.993943, 0.666667,
|
||||
],
|
||||
[
|
||||
2.019825, 3.368286, 4.556992, 5.152499, 3.728782, 2.309520, 1.108229, 0.666667,
|
||||
],
|
||||
[
|
||||
1.206791, 2.031395, 2.792820, 3.219759, 2.693245, 1.817763, 1.031452, 0.666667,
|
||||
],
|
||||
[
|
||||
0.930765, 1.411590, 1.834802, 2.112385, 1.865221, 1.355072, 0.919646, 0.666667,
|
||||
],
|
||||
[
|
||||
0.743195, 0.945081, 1.107431, 1.214270, 1.149695, 0.934451, 0.767274, 0.666667,
|
||||
],
|
||||
],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_stride() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[[1.0647], [0.5783]],
|
||||
[[2.9289], [1.8829]],
|
||||
[[4.7931], [3.1875]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_dilation() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 2,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[[0.6162], [0.7611], [0.4666]],
|
||||
[[1.8578], [2.2684], [1.6208]],
|
||||
[[3.0994], [3.7757], [2.7749]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_width() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 6,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::<4>::from([[
|
||||
[
|
||||
[0.8909, 0.6016],
|
||||
[1.0697, 0.7186],
|
||||
[1.2618, 0.8433],
|
||||
[0.6424, 0.5032],
|
||||
],
|
||||
[
|
||||
[2.4670, 1.8168],
|
||||
[2.9529, 2.1497],
|
||||
[3.4805, 2.5090],
|
||||
[2.0925, 1.7411],
|
||||
],
|
||||
[
|
||||
[4.0432, 3.0321],
|
||||
[4.8362, 3.5809],
|
||||
[5.6992, 4.1746],
|
||||
[3.5425, 2.9790],
|
||||
],
|
||||
]]))
|
||||
}
|
||||
|
||||
struct DeformConv2dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
weight_groups: usize,
|
||||
offset_groups: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl DeformConv2dTestCase {
|
||||
fn assert_output(self, y: Tensor<TestBackend, 4>) {
|
||||
let out_height =
|
||||
(self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)
|
||||
/ self.stride_1
|
||||
+ 1;
|
||||
let out_width =
|
||||
(self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)
|
||||
/ self.stride_2
|
||||
+ 1;
|
||||
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.weight_groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
]);
|
||||
let shape_offset = Shape::new([
|
||||
self.batch_size,
|
||||
self.kernel_size_1 * self.kernel_size_2 * self.offset_groups * 2,
|
||||
out_height,
|
||||
out_width,
|
||||
]);
|
||||
let shape_mask = Shape::new([
|
||||
self.batch_size,
|
||||
self.kernel_size_1 * self.kernel_size_2 * self.offset_groups,
|
||||
out_height,
|
||||
out_width,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestTensor::<4>::from(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_weight.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_weight.num_elements() as f32);
|
||||
let bias = TestTensor::<1>::from(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
)
|
||||
.div_scalar(self.channels_out as f32);
|
||||
let x = TestTensor::<4>::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_x.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_x.num_elements() as f32);
|
||||
let offset = TestTensor::<4>::from(
|
||||
TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_offset.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_offset.num_elements() as f32);
|
||||
let mask = TestTensor::<4>::from(
|
||||
TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
|
||||
.reshape::<4, _>(shape_mask.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_mask.num_elements() as f32);
|
||||
|
||||
let output = deform_conv2d(
|
||||
x,
|
||||
offset,
|
||||
weight,
|
||||
Some(mask),
|
||||
Some(bias),
|
||||
DeformConvOptions::new(
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
self.weight_groups,
|
||||
self.offset_groups,
|
||||
),
|
||||
);
|
||||
|
||||
let tolerance = Tolerance::permissive();
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, module::embedding};
|
||||
|
||||
#[test]
|
||||
fn test_embedding_forward() {
|
||||
let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indices = TensorData::from([[0, 1], [1, 1]]);
|
||||
let weights = TestTensor::<2>::from(weights);
|
||||
let indices = TestTensorInt::<2>::from(indices);
|
||||
|
||||
let output = embedding(weights, indices);
|
||||
let expected = TensorData::from([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::linear;
|
||||
|
||||
#[test]
|
||||
fn test_linear_1d() {
|
||||
let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
let x = TestTensor::<1>::from([1.0, 2.0]);
|
||||
let output = linear(x, weight, None);
|
||||
|
||||
let expected = TensorData::from([7.0, 10.0]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linear_1d_one_element_output() {
|
||||
let weight = TestTensor::<2>::from([[3.0], [4.0]]);
|
||||
|
||||
let x = TestTensor::<1>::from([1.0, 2.0]);
|
||||
let output = linear(x, weight, None);
|
||||
|
||||
let expected = TensorData::from([11.0]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linear_forward_no_bias() {
|
||||
let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);
|
||||
|
||||
let output = linear(x, weight, None);
|
||||
|
||||
let expected = TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linear_forward_with_bias() {
|
||||
let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let bias = Some(TestTensor::<1>::from([1.0, -1.0]));
|
||||
|
||||
let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);
|
||||
|
||||
let output = linear(x, weight, bias);
|
||||
|
||||
let expected = TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::{max_pool1d, max_pool1d_with_indices};
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_simple() {
|
||||
let kernel_size = 3;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from([[
|
||||
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
|
||||
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||
]]);
|
||||
let y = TestTensor::<3>::from([[
|
||||
[0.9861, 0.5474, 0.4477, 0.8221],
|
||||
[0.949, 0.949, 0.949, 0.789],
|
||||
]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_different_padding_stride_kernel() {
|
||||
let kernel_size = 3;
|
||||
let padding = 1;
|
||||
let stride = 2;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from([[[0.6309, 0.6112, 0.6998, 0.4708]]]);
|
||||
let y = TestTensor::<3>::from([[[0.6309, 0.6998]]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_neg() {
|
||||
let kernel_size = 3;
|
||||
let padding = 1;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);
|
||||
let y = TestTensor::<3>::from([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_dilation() {
|
||||
let kernel_size = 2;
|
||||
let padding = 1;
|
||||
let stride = 1;
|
||||
let dilation = 2;
|
||||
|
||||
let x = TestTensor::from([[
|
||||
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
|
||||
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||
]]);
|
||||
let y = TestTensor::<3>::from([[
|
||||
[0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548],
|
||||
[0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537],
|
||||
]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_indices() {
|
||||
let kernel_size = 2;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
|
||||
let indices = TensorData::from([[[1, 1, 3]]]);
|
||||
let y = TestTensor::<3>::from([[[0.6386, 0.6386, 0.5742]]]);
|
||||
|
||||
let (output, output_indices) =
|
||||
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
output_indices.into_data().assert_eq(&indices, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_complex() {
|
||||
let kernel_size = 4;
|
||||
let padding = 2;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);
|
||||
let indices = TensorData::from([[[0, 2, 3, 3, 3, 3]]]);
|
||||
let y = TestTensor::<3>::from([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);
|
||||
|
||||
let (output, output_indices) =
|
||||
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
output_indices.into_data().assert_eq(&indices, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_ceil_mode() {
|
||||
// Test ceil_mode=true produces larger output when input doesn't divide evenly by stride
|
||||
// Input: 1x1x6, kernel: 3, stride: 2, padding: 0
|
||||
// Floor mode: output = (6-3)/2+1 = 2 elements
|
||||
// Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 elements
|
||||
let kernel_size = 3;
|
||||
let padding = 0;
|
||||
let stride = 2;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]]);
|
||||
|
||||
// With ceil_mode=false (floor): output is 2 elements
|
||||
// Window 0: positions [0:3] -> max(1,2,3) = 3
|
||||
// Window 1: positions [2:5] -> max(3,4,5) = 5
|
||||
let y_floor = TestTensor::<3>::from([[[3.0, 5.0]]]);
|
||||
|
||||
let output_floor = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);
|
||||
|
||||
y_floor
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output_floor.into_data(), Tolerance::default());
|
||||
|
||||
// With ceil_mode=true: output is 3 elements
|
||||
// Window 0: positions [0:3] -> max(1,2,3) = 3
|
||||
// Window 1: positions [2:5] -> max(3,4,5) = 5
|
||||
// Window 2: positions [4:7] -> max(5,6) = 6 (partial window)
|
||||
let y_ceil = TestTensor::<3>::from([[[3.0, 5.0, 6.0]]]);
|
||||
|
||||
let output_ceil = max_pool1d(x, kernel_size, stride, padding, dilation, true);
|
||||
|
||||
y_ceil
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output_ceil.into_data(), Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,523 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::{max_pool2d, max_pool2d_with_indices};
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_simple() {
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from([
|
||||
[
|
||||
[
|
||||
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
|
||||
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||
[0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182],
|
||||
[0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392],
|
||||
[0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605],
|
||||
[0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068],
|
||||
],
|
||||
[
|
||||
[0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473],
|
||||
[0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947],
|
||||
[0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149],
|
||||
[0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426],
|
||||
[0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544],
|
||||
[0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910],
|
||||
[0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546],
|
||||
[0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584],
|
||||
[0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441],
|
||||
[0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881],
|
||||
[0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467],
|
||||
],
|
||||
[
|
||||
[0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952],
|
||||
[0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454],
|
||||
[0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628],
|
||||
[0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527],
|
||||
[0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827],
|
||||
[0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421],
|
||||
],
|
||||
],
|
||||
]);
|
||||
let y = TestTensor::<4>::from([
|
||||
[
|
||||
[
|
||||
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
|
||||
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
|
||||
[0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
|
||||
[0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
|
||||
[0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
|
||||
[0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
|
||||
],
|
||||
[
|
||||
[0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947],
|
||||
[0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149],
|
||||
[0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149],
|
||||
[0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149],
|
||||
[0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025],
|
||||
[0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
|
||||
[0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
|
||||
[0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546],
|
||||
[0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
|
||||
[0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
|
||||
[0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037],
|
||||
],
|
||||
[
|
||||
[0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378],
|
||||
[0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378],
|
||||
[0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445],
|
||||
[0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128],
|
||||
[0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128],
|
||||
[0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128],
|
||||
],
|
||||
],
|
||||
]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_different_padding_stride_kernel() {
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 1;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from([[[
|
||||
[0.6309, 0.6112, 0.6998],
|
||||
[0.4708, 0.9161, 0.5402],
|
||||
[0.4577, 0.7397, 0.9870],
|
||||
[0.6380, 0.4352, 0.5884],
|
||||
[0.6277, 0.5139, 0.4525],
|
||||
[0.9333, 0.9846, 0.5006],
|
||||
]]]);
|
||||
let y = TestTensor::<4>::from([[[
|
||||
[0.6309, 0.6998],
|
||||
[0.6309, 0.9870],
|
||||
[0.6380, 0.9870],
|
||||
[0.6380, 0.9870],
|
||||
[0.9333, 0.5884],
|
||||
[0.9333, 0.5006],
|
||||
]]]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_neg() {
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from([[[
|
||||
[0.6309, 0.6112, 0.6998],
|
||||
[0.4708, 0.9161, 0.5402],
|
||||
[0.4577, 0.7397, 0.9870],
|
||||
[0.6380, 0.4352, 0.5884],
|
||||
[0.6277, 0.5139, 0.4525],
|
||||
[0.9333, 0.9846, 0.5006],
|
||||
]]])
|
||||
.neg();
|
||||
let y = TestTensor::<4>::from([[[
|
||||
[-0.4708, -0.4708, -0.5402],
|
||||
[-0.4577, -0.4577, -0.5402],
|
||||
[-0.4352, -0.4352, -0.4352],
|
||||
[-0.4352, -0.4352, -0.4352],
|
||||
[-0.4352, -0.4352, -0.4352],
|
||||
[-0.5139, -0.4525, -0.4525],
|
||||
]]]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_dilation() {
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 2;
|
||||
let dilation_2 = 2;
|
||||
|
||||
let x = TestTensor::from([[[
|
||||
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
|
||||
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
|
||||
[0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
|
||||
[0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
|
||||
[0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
|
||||
[0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
|
||||
]]]);
|
||||
let y = TestTensor::<4>::from([[[
|
||||
[0.9861, 0.9861, 0.9540, 0.9490],
|
||||
[0.9861, 0.9861, 0.9540, 0.9490],
|
||||
[0.9540, 0.9540, 0.9540, 0.9490],
|
||||
[0.9540, 0.9540, 0.9540, 0.9432],
|
||||
]]]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_indices() {
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from([[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
[0.7065, 0.1940, 0.6305, 0.8959],
|
||||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]]);
|
||||
let indices = TensorData::from([[[
|
||||
[0, 1, 1, 3, 3],
|
||||
[4, 4, 1, 7, 7],
|
||||
[4, 9, 9, 7, 7],
|
||||
[8, 9, 9, 14, 11],
|
||||
[12, 12, 14, 14, 15],
|
||||
]]]);
|
||||
let y = TestTensor::<4>::from([[[
|
||||
[0.2479, 0.6386, 0.6386, 0.5742, 0.5742],
|
||||
[0.7065, 0.7065, 0.6386, 0.8959, 0.8959],
|
||||
[0.7065, 0.8602, 0.8602, 0.8959, 0.8959],
|
||||
[0.5416, 0.8602, 0.8602, 0.8293, 0.1662],
|
||||
[0.3358, 0.3358, 0.8293, 0.8293, 0.0990],
|
||||
]]]);
|
||||
|
||||
let (output, output_indices) = max_pool2d_with_indices(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
output_indices.into_data().assert_eq(&indices, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_complex() {
|
||||
let kernel_size_1 = 4;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 2;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from([[[
|
||||
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
|
||||
[0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
|
||||
[0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
|
||||
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
|
||||
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
|
||||
]]]);
|
||||
let indices = TensorData::from([[[
|
||||
[5, 7, 3],
|
||||
[5, 7, 3],
|
||||
[5, 16, 3],
|
||||
[5, 16, 8],
|
||||
[15, 16, 24],
|
||||
[15, 16, 24],
|
||||
]]]);
|
||||
let y = TestTensor::<4>::from([[[
|
||||
[0.9154, 0.9089, 0.8316],
|
||||
[0.9154, 0.9089, 0.8316],
|
||||
[0.9154, 0.9963, 0.8316],
|
||||
[0.9154, 0.9963, 0.8016],
|
||||
[0.4384, 0.9963, 0.688],
|
||||
[0.4384, 0.9963, 0.688],
|
||||
]]]);
|
||||
let (output, output_indices) = max_pool2d_with_indices(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
output_indices.into_data().assert_eq(&indices, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_ceil_mode() {
|
||||
// Test ceil_mode=true which produces larger output when input doesn't divide evenly by stride
|
||||
// Using 1x1x6x6 with kernel 3x3, stride 2x2, padding 0:
|
||||
// Floor mode: output = (6+0-1*(3-1)-1)/2+1 = 3/2+1 = 2 x 2
|
||||
// Ceil mode: output = ceil(3/2)+1 = 2+1 = 3 x 3
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 2;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
// Input (values 1-36 arranged row by row):
|
||||
// col: 0 1 2 3 4 5
|
||||
// row 0: 1 2 3 4 5 6
|
||||
// row 1: 7 8 9 10 11 12
|
||||
// row 2: 13 14 15 16 17 18
|
||||
// row 3: 19 20 21 22 23 24
|
||||
// row 4: 25 26 27 28 29 30
|
||||
// row 5: 31 32 33 34 35 36
|
||||
let x = TestTensor::from([[[
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
||||
[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
|
||||
[13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
|
||||
[19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
|
||||
[25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
|
||||
[31.0, 32.0, 33.0, 34.0, 35.0, 36.0],
|
||||
]]]);
|
||||
|
||||
// With ceil_mode=false (floor): output is 2x2
|
||||
// (0,0): rows 0-2, cols 0-2 -> max(1,2,3,7,8,9,13,14,15) = 15
|
||||
// (0,1): rows 0-2, cols 2-4 -> max(3,4,5,9,10,11,15,16,17) = 17
|
||||
// (1,0): rows 2-4, cols 0-2 -> max(13,14,15,19,20,21,25,26,27) = 27
|
||||
// (1,1): rows 2-4, cols 2-4 -> max(15,16,17,21,22,23,27,28,29) = 29
|
||||
let y_floor = TestTensor::<4>::from([[[[15.0, 17.0], [27.0, 29.0]]]]);
|
||||
|
||||
let output_floor = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
false,
|
||||
);
|
||||
|
||||
y_floor
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output_floor.into_data(), Tolerance::default());
|
||||
|
||||
// With ceil_mode=true: output is 3x3
|
||||
// Extra windows at edges use only available input values (padded with -inf for max pooling)
|
||||
// (0,0): rows 0-2, cols 0-2 -> max = 15
|
||||
// (0,1): rows 0-2, cols 2-4 -> max = 17
|
||||
// (0,2): rows 0-2, cols 4-5 -> max(5,6,11,12,17,18) = 18
|
||||
// (1,0): rows 2-4, cols 0-2 -> max = 27
|
||||
// (1,1): rows 2-4, cols 2-4 -> max = 29
|
||||
// (1,2): rows 2-4, cols 4-5 -> max(17,18,23,24,29,30) = 30
|
||||
// (2,0): rows 4-5, cols 0-2 -> max(25,26,27,31,32,33) = 33
|
||||
// (2,1): rows 4-5, cols 2-4 -> max(27,28,29,33,34,35) = 35
|
||||
// (2,2): rows 4-5, cols 4-5 -> max(29,30,35,36) = 36
|
||||
let y_ceil =
|
||||
TestTensor::<4>::from([[[[15.0, 17.0, 18.0], [27.0, 29.0, 30.0], [33.0, 35.0, 36.0]]]]);
|
||||
|
||||
let output_ceil = max_pool2d(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
true,
|
||||
);
|
||||
|
||||
y_ceil
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output_ceil.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_ceil_mode_with_indices() {
|
||||
// Test ceil_mode=true with indices to verify correct index calculation
|
||||
// when pooling windows extend beyond original input bounds
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 2;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
// Input 6x6 (indices 0-35 in row-major order):
|
||||
// row 0: 0 1 2 3 4 5
|
||||
// row 1: 6 7 8 9 10 11
|
||||
// row 2: 12 13 14 15 16 17
|
||||
// row 3: 18 19 20 21 22 23
|
||||
// row 4: 24 25 26 27 28 29
|
||||
// row 5: 30 31 32 33 34 35
|
||||
let x = TestTensor::from([[[
|
||||
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
[12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
|
||||
[18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
|
||||
[24.0, 25.0, 26.0, 27.0, 28.0, 29.0],
|
||||
[30.0, 31.0, 32.0, 33.0, 34.0, 35.0],
|
||||
]]]);
|
||||
|
||||
// With ceil_mode=true: output is 3x3
|
||||
// (0,0): rows 0-2, cols 0-2 -> max at index 14
|
||||
// (0,1): rows 0-2, cols 2-4 -> max at index 16
|
||||
// (0,2): rows 0-2, cols 4-5 -> max at index 17
|
||||
// (1,0): rows 2-4, cols 0-2 -> max at index 26
|
||||
// (1,1): rows 2-4, cols 2-4 -> max at index 28
|
||||
// (1,2): rows 2-4, cols 4-5 -> max at index 29
|
||||
// (2,0): rows 4-5, cols 0-2 -> max at index 32
|
||||
// (2,1): rows 4-5, cols 2-4 -> max at index 34
|
||||
// (2,2): rows 4-5, cols 4-5 -> max at index 35
|
||||
let expected_values =
|
||||
TestTensor::<4>::from([[[[14.0, 16.0, 17.0], [26.0, 28.0, 29.0], [32.0, 34.0, 35.0]]]]);
|
||||
let expected_indices = TensorData::from([[[[14i64, 16, 17], [26, 28, 29], [32, 34, 35]]]]);
|
||||
|
||||
let (output, output_indices) = max_pool2d_with_indices(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
true,
|
||||
);
|
||||
|
||||
expected_values
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
output_indices
|
||||
.into_data()
|
||||
.assert_eq(&expected_indices, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_ceil_mode_with_indices_and_padding() {
|
||||
// Test ceil_mode=true with padding and indices to verify correct index calculation
|
||||
// This exercises the case where both user padding and ceil_mode extra padding apply
|
||||
let kernel_size_1 = 3;
|
||||
let kernel_size_2 = 3;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 2;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
// Input 5x5 (indices 0-24 in row-major order):
|
||||
// row 0: 0 1 2 3 4
|
||||
// row 1: 5 6 7 8 9
|
||||
// row 2: 10 11 12 13 14
|
||||
// row 3: 15 16 17 18 19
|
||||
// row 4: 20 21 22 23 24
|
||||
let x = TestTensor::from([[[
|
||||
[0.0, 1.0, 2.0, 3.0, 4.0],
|
||||
[5.0, 6.0, 7.0, 8.0, 9.0],
|
||||
[10.0, 11.0, 12.0, 13.0, 14.0],
|
||||
[15.0, 16.0, 17.0, 18.0, 19.0],
|
||||
[20.0, 21.0, 22.0, 23.0, 24.0],
|
||||
]]]);
|
||||
|
||||
// With padding=1, ceil_mode=true:
|
||||
// Effective input is 7x7 (5 + 2*1)
|
||||
// Output size: ceil((5 + 2*1 - 3) / 2) + 1 = ceil(4/2) + 1 = 3
|
||||
//
|
||||
// Windows (with -inf padding at boundaries):
|
||||
// (0,0): rows -1 to 1, cols -1 to 1 -> valid: (0,0) to (1,1), max at (1,1)=6
|
||||
// (0,1): rows -1 to 1, cols 1 to 3 -> max at (1,3)=8
|
||||
// (0,2): rows -1 to 1, cols 3 to 5 -> max at (1,4)=9
|
||||
// (1,0): rows 1 to 3, cols -1 to 1 -> max at (3,1)=16
|
||||
// (1,1): rows 1 to 3, cols 1 to 3 -> max at (3,3)=18
|
||||
// (1,2): rows 1 to 3, cols 3 to 5 -> max at (3,4)=19
|
||||
// (2,0): rows 3 to 5, cols -1 to 1 -> max at (4,1)=21
|
||||
// (2,1): rows 3 to 5, cols 1 to 3 -> max at (4,3)=23
|
||||
// (2,2): rows 3 to 5, cols 3 to 5 -> max at (4,4)=24
|
||||
let expected_values =
|
||||
TestTensor::<4>::from([[[[6.0, 8.0, 9.0], [16.0, 18.0, 19.0], [21.0, 23.0, 24.0]]]]);
|
||||
let expected_indices = TensorData::from([[[[6i64, 8, 9], [16, 18, 19], [21, 23, 24]]]]);
|
||||
|
||||
let (output, output_indices) = max_pool2d_with_indices(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
true,
|
||||
);
|
||||
|
||||
expected_values
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
output_indices
|
||||
.into_data()
|
||||
.assert_eq(&expected_indices, false);
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
use super::*;
|
||||
|
||||
mod adaptive_avgpool1d;
|
||||
mod adaptive_avgpool2d;
|
||||
mod attention;
|
||||
mod avgpool1d;
|
||||
mod avgpool2d;
|
||||
mod bicubic_interpolate;
|
||||
mod bilinear_interpolate;
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv3d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
mod forward;
|
||||
mod linear;
|
||||
mod maxpool1d;
|
||||
mod maxpool2d;
|
||||
mod nearest_interpolate;
|
||||
mod unfold4d;
|
||||
@@ -0,0 +1,127 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::interpolate;
|
||||
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
|
||||
|
||||
#[test]
|
||||
fn test_upsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 2,
|
||||
channels: 1,
|
||||
height: 7,
|
||||
width: 5,
|
||||
height_out: 8,
|
||||
width_out: 7,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([
|
||||
[[
|
||||
[0., 0., 1., 2., 2., 3., 4.],
|
||||
[0., 0., 1., 2., 2., 3., 4.],
|
||||
[5., 5., 6., 7., 7., 8., 9.],
|
||||
[10., 10., 11., 12., 12., 13., 14.],
|
||||
[15., 15., 16., 17., 17., 18., 19.],
|
||||
[20., 20., 21., 22., 22., 23., 24.],
|
||||
[25., 25., 26., 27., 27., 28., 29.],
|
||||
[30., 30., 31., 32., 32., 33., 34.],
|
||||
]],
|
||||
[[
|
||||
[35., 35., 36., 37., 37., 38., 39.],
|
||||
[35., 35., 36., 37., 37., 38., 39.],
|
||||
[40., 40., 41., 42., 42., 43., 44.],
|
||||
[45., 45., 46., 47., 47., 48., 49.],
|
||||
[50., 50., 51., 52., 52., 53., 54.],
|
||||
[55., 55., 56., 57., 57., 58., 59.],
|
||||
[60., 60., 61., 62., 62., 63., 64.],
|
||||
[65., 65., 66., 67., 67., 68., 69.],
|
||||
]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_downsample_interpolation() {
|
||||
let test = InterpolateTestCase {
|
||||
batch_size: 1,
|
||||
channels: 1,
|
||||
height: 45,
|
||||
width: 14,
|
||||
height_out: 4,
|
||||
width_out: 6,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[[
|
||||
[0., 2., 4., 7., 9., 11.],
|
||||
[154., 156., 158., 161., 163., 165.],
|
||||
[308., 310., 312., 315., 317., 319.],
|
||||
[462., 464., 466., 469., 471., 473.],
|
||||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_1d_nearest() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
|
||||
// Run the model
|
||||
let input = TestTensor::<3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let output = interpolate(
|
||||
input,
|
||||
[1, 9],
|
||||
InterpolateOptions::new(InterpolateMode::Nearest),
|
||||
);
|
||||
assert_eq!(output.dims(), [1, 1, 1, 9]);
|
||||
|
||||
// assert output data does not contain NaN
|
||||
assert!(
|
||||
!output
|
||||
.clone()
|
||||
.to_data()
|
||||
.as_slice::<FloatElem>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|&x| x.is_nan()),
|
||||
"interpolate output contains NaN"
|
||||
);
|
||||
|
||||
TestTensor::<4>::from([[[[
|
||||
1.541, 1.541, -0.2934, -2.1788, -2.1788, 0.5684, -1.0845, -1.0845, -1.3986,
|
||||
]]]])
|
||||
.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl InterpolateTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data()
|
||||
.convert::<f32>(),
|
||||
);
|
||||
let output = interpolate(
|
||||
x,
|
||||
[self.height_out, self.width_out],
|
||||
InterpolateOptions::new(InterpolateMode::Nearest),
|
||||
);
|
||||
|
||||
y.to_data()
|
||||
.assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
use super::*;
|
||||
use burn_tensor::Shape;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::module::unfold4d;
|
||||
use burn_tensor::ops::UnfoldOptions;
|
||||
|
||||
#[test]
|
||||
fn test_unfold4d_shape() {
|
||||
let test = Unfold4dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 5,
|
||||
kernel_size: [2, 3],
|
||||
padding: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
height: 3,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_shape([2, 30, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unfold4d_simple() {
|
||||
let test = Unfold4dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
kernel_size: [2, 2],
|
||||
padding: [0, 0],
|
||||
stride: [1, 1],
|
||||
dilation: [1, 1],
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[0., 1., 2., 4., 5., 6., 8., 9., 10.],
|
||||
[1., 2., 3., 5., 6., 7., 9., 10., 11.],
|
||||
[4., 5., 6., 8., 9., 10., 12., 13., 14.],
|
||||
[5., 6., 7., 9., 10., 11., 13., 14., 15.],
|
||||
[16., 17., 18., 20., 21., 22., 24., 25., 26.],
|
||||
[17., 18., 19., 21., 22., 23., 25., 26., 27.],
|
||||
[20., 21., 22., 24., 25., 26., 28., 29., 30.],
|
||||
[21., 22., 23., 25., 26., 27., 29., 30., 31.],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unfold4d_complex() {
|
||||
let test = Unfold4dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
kernel_size: [2, 3],
|
||||
padding: [0, 1],
|
||||
stride: [1, 2],
|
||||
dilation: [1, 2],
|
||||
height: 3,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from([[
|
||||
[0., 0.],
|
||||
[1., 5.],
|
||||
[3., 7.],
|
||||
[0., 0.],
|
||||
[5., 9.],
|
||||
[7., 11.],
|
||||
[0., 0.],
|
||||
[13., 17.],
|
||||
[15., 19.],
|
||||
[0., 0.],
|
||||
[17., 21.],
|
||||
[19., 23.],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct Unfold4dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
kernel_size: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl Unfold4dTestCase {
|
||||
fn assert_shape(self, expected_shape: [usize; 3]) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data()
|
||||
.convert::<f32>(),
|
||||
);
|
||||
|
||||
let output = unfold4d(
|
||||
x,
|
||||
self.kernel_size,
|
||||
UnfoldOptions::new(self.stride, self.padding, self.dilation),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
output.shape().as_slice(),
|
||||
expected_shape,
|
||||
"Expected shape doesn't match the actual shape"
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_output(self, expected: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
|
||||
.reshape::<4, _>(shape_x)
|
||||
.into_data(),
|
||||
);
|
||||
|
||||
let output = unfold4d(
|
||||
x,
|
||||
self.kernel_size,
|
||||
UnfoldOptions::new(self.stride, self.padding, self.dilation),
|
||||
);
|
||||
|
||||
let tolerance = Tolerance::default()
|
||||
.set_half_precision_relative(2e-3)
|
||||
.set_half_precision_absolute(2e-3);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected.into_data(), tolerance);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn should_support_abs_ops_float() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]);
|
||||
|
||||
let output = tensor.abs();
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), false);
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
use super::*;
|
||||
use burn_tensor::{TensorData, backend::Backend};
|
||||
|
||||
#[test]
|
||||
fn test_add_d2() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);
|
||||
|
||||
let output = tensor_1 + tensor_2;
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_broadcast() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0]]);
|
||||
let tensor_2 = TestTensor::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
|
||||
|
||||
let output = tensor_1 + tensor_2;
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_different_strides_rhs() {
|
||||
// We need to execute an operation after `from data` to trigger inplace in some backends.
|
||||
// Which is the operation that might be problematic in this case.
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1;
|
||||
let tensor_2 = TestTensor::from([[4.0, 5.0], [6.0, 7.0]]) * 1;
|
||||
|
||||
let output = tensor_1 + tensor_2.transpose();
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_different_strides_lhs() {
|
||||
// We need to execute an operation after `from data` to trigger inplace in some backends.
|
||||
// Which is the operation that might be problematic in this case.
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1;
|
||||
let tensor_2 = TestTensor::from([[4.0, 5.0], [6.0, 7.0]]) * 1;
|
||||
|
||||
let output = tensor_1.transpose() + tensor_2;
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_different_strides_broadcast() {
|
||||
// We need to execute an operation after `from data` to trigger inplace in some backends.
|
||||
// Which is the operation that might be problematic in this case.
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1;
|
||||
let tensor_2 = TestTensor::from([[4.0, 5.0]]) * 1;
|
||||
|
||||
let output = tensor_1.transpose() + tensor_2;
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[4.0, 7.0], [5.0, 8.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_add_scalar_ops() {
|
||||
let scalar = 2.0;
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor + scalar;
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_maybe_fused_not_contiguous() {
|
||||
let tensor1 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
let tensor2 = TestTensorInt::arange(8..16, &Default::default()).float();
|
||||
let tensor1 = tensor1.reshape([2, 4]);
|
||||
let tensor2 = tensor2.reshape([4, 2]);
|
||||
let tensor2 = tensor2.swap_dims(0, 1);
|
||||
|
||||
TestBackend::sync(&tensor2.device()).unwrap();
|
||||
|
||||
let output = tensor1 + tensor2;
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[8.0, 11.0, 14.0, 17.0], [13.0, 16.0, 19.0, 22.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_maybe_fused_not_contiguous_broadcasted() {
|
||||
let tensor1 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
let tensor2 = TestTensorInt::arange(8..10, &Default::default()).float();
|
||||
let tensor1 = tensor1.reshape([2, 4]);
|
||||
let tensor2 = tensor2.reshape([1, 2]);
|
||||
let tensor2 = tensor2.swap_dims(0, 1);
|
||||
|
||||
TestBackend::sync(&tensor2.device()).unwrap();
|
||||
|
||||
let output = tensor2 + tensor1;
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[8.0, 9.0, 10.0, 11.0], [13.0, 14.0, 15.0, 16.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,460 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
#[test]
|
||||
fn test_should_mean() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.mean();
|
||||
let expected = TensorData::from([15.0 / 6.0]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_sum() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.sum();
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([15.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_sum_dim_maybe_fused() {
|
||||
let tensor = TestTensor::<2>::from([[5.0], [-12.0]]);
|
||||
let tensor1 = TestTensor::<2>::from([[2.0, 3.0], [-1.0, -5.0]]);
|
||||
let ones = TestTensor::<2>::ones([2, 2], &Default::default());
|
||||
let _x = ones.clone() * tensor;
|
||||
let y = ones * tensor1;
|
||||
|
||||
let output = y.clone().sum_dim(1);
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[5.0], [-6.0]]), false);
|
||||
|
||||
// Negative Indexing.
|
||||
let output = y.clone().sum_dim(-1);
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[5.0], [-6.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_mean_last_dim() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.clone().mean_dim(1);
|
||||
let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
// Negative Indexing.
|
||||
let output = tensor.clone().mean_dim(-1);
|
||||
let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_sum_last_dim() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.sum_dim(1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[3.0], [12.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_sum_first_dim() {
|
||||
let tensor = TestTensor::<2>::from([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);
|
||||
|
||||
let output = tensor.sum_dim(0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[7.0, 3.0, 5.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_mean_first_dim() {
|
||||
let tensor = TestTensor::<2>::from([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);
|
||||
|
||||
let output = tensor.mean_dim(0);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_sum_mid_dim_3d_non_contiguous_1() {
|
||||
let tensor = TestTensor::<3>::from([
|
||||
[[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]],
|
||||
[[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]],
|
||||
]);
|
||||
|
||||
let output = tensor.swap_dims(0, 2).sum_dim(1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], [3, 1, 2]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_sum_mid_dim_3d_non_contiguous_2() {
|
||||
let tensor = TestTensor::<3>::from([
|
||||
[[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]],
|
||||
[[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]],
|
||||
]);
|
||||
|
||||
let output = tensor.swap_dims(0, 1).sum_dim(1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], [2, 1, 3]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prod_float() {
|
||||
let tensor = TestTensor::<2>::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let output = tensor.prod();
|
||||
|
||||
// 2 * 1 * 2 * 3 * 4 * 5 = 240 but we need to check the precision because of the float
|
||||
let expected = TensorData::from([240.0]);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let tensor_with_zero = TestTensor::<2>::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let output = tensor_with_zero.prod();
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([0.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prod_dim_float() {
|
||||
let tensor = TestTensor::<2>::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let output = tensor.prod_dim(1);
|
||||
let expected = TensorData::from([[4.0], [60.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let tensor_with_zero = TestTensor::<2>::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let output = tensor_with_zero.prod_dim(1);
|
||||
let expected = TensorData::from([[0.0], [60.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_2d() {
|
||||
let tensor =
|
||||
TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());
|
||||
|
||||
let output = tensor.clone().sum_dim(1);
|
||||
let expected = TensorData::from([[3.], [12.]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
|
||||
let output = tensor.sum_dim(0);
|
||||
let expected = TensorData::from([[3., 5., 7.]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dims_2d() {
|
||||
let tensor =
|
||||
TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());
|
||||
|
||||
tensor
|
||||
.clone()
|
||||
.sum_dims(&[1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.], [12.]]), false);
|
||||
|
||||
tensor
|
||||
.clone()
|
||||
.sum_dims(&[-1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[3.], [12.]]), false);
|
||||
|
||||
tensor
|
||||
.clone()
|
||||
.sum_dims(&[0, 1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[15.]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_and_squeeze_dims() {
|
||||
let tensor = TestTensor::<3>::from_floats(
|
||||
[
|
||||
[[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
|
||||
[[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
|
||||
],
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
tensor
|
||||
.sum_dims_squeeze::<1, _>(&[0, 1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([20., 16., 21.]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_1_reshape_maybe_fused() {
|
||||
let tensor = TestTensorInt::arange(0..9, &Default::default()).float();
|
||||
TestBackend::sync(&tensor.device()).unwrap();
|
||||
|
||||
let output = tensor.reshape([3, 3]) + 2;
|
||||
let output = output.sum_dim(1);
|
||||
let expected = TensorData::from([[9.0], [18.0], [27.0]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_1_swap_dims_maybe_fused() {
|
||||
let tensor = TestTensorInt::arange(0..9, &Default::default()).float();
|
||||
let tensor = tensor.reshape([3, 3]);
|
||||
TestBackend::sync(&tensor.device()).unwrap();
|
||||
|
||||
let output = tensor.swap_dims(0, 1) + 2;
|
||||
let output = output.sum_dim(1);
|
||||
let expected = TensorData::from([[15.0], [18.0], [21.0]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_2_reshape_maybe_fused_broadcast() {
|
||||
let tensor = TestTensorInt::arange(0..9, &Default::default()).float();
|
||||
TestBackend::sync(&tensor.device()).unwrap();
|
||||
|
||||
let output = tensor.reshape([1, 3, 3]) + 2;
|
||||
let output = output.sum_dim(2);
|
||||
let expected = TensorData::from([[[9.0], [18.0], [27.0]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_2_maybe_fused_on_write() {
|
||||
let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
let tensor_2 = TestTensorInt::arange(10..12, &Default::default()).float();
|
||||
let tensor_1 = tensor_1.reshape([1, 2, 4]);
|
||||
let tensor_2 = tensor_2.reshape([1, 2, 1]);
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
|
||||
let output = (tensor_1 + tensor_2.clone()).sum_dim(2) + tensor_2;
|
||||
TestBackend::sync(&output.device()).unwrap();
|
||||
let expected = TensorData::from([[[56.0], [77.0]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_3_maybe_fused_on_read_not_contiguous() {
|
||||
let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float();
|
||||
|
||||
let tensor_1 = tensor_1.reshape([4, 2, 1]);
|
||||
let tensor_1 = tensor_1.swap_dims(0, 2);
|
||||
|
||||
let tensor_2 = tensor_2.reshape([1, 4, 2]);
|
||||
let tensor_2 = tensor_2.swap_dims(1, 2);
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
|
||||
let output = (tensor_1 + tensor_2).sum_dim(2);
|
||||
let expected = TensorData::from([[[88.0], [96.0]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_4_maybe_fused_on_read_not_contiguous_mixed() {
|
||||
let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float();
|
||||
let tensor_3 = TestTensorInt::arange(32..40, &Default::default()).float();
|
||||
|
||||
let tensor_1 = tensor_1.reshape([4, 2, 1]);
|
||||
let tensor_3 = tensor_3.reshape([1, 2, 4]);
|
||||
let tensor_1 = tensor_1.swap_dims(0, 2);
|
||||
|
||||
let tensor_2 = tensor_2.reshape([1, 4, 2]);
|
||||
let tensor_2 = tensor_2.swap_dims(1, 2);
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
|
||||
let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(2);
|
||||
let expected = TensorData::from([[[222.0], [246.0]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_5_maybe_fused_on_read_not_contiguous_mixed() {
|
||||
let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float();
|
||||
let tensor_3 = TestTensorInt::arange(32..40, &Default::default()).float();
|
||||
|
||||
let tensor_1 = tensor_1.reshape([4, 2, 1]);
|
||||
let tensor_3 = tensor_3.reshape([1, 2, 4]);
|
||||
let tensor_1 = tensor_1.swap_dims(0, 2);
|
||||
|
||||
let tensor_2 = tensor_2.reshape([1, 4, 2]);
|
||||
let tensor_2 = tensor_2.swap_dims(1, 2);
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
|
||||
let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(1);
|
||||
let expected = TensorData::from([[[102.0, 112.0, 122.0, 132.0]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_6_maybe_fused_on_read_not_contiguous_broadcasted() {
|
||||
let tensor_1 = TestTensorInt::arange(0..32, &Default::default()).float();
|
||||
let tensor_2 = TestTensorInt::arange(0..8, &Default::default()).float();
|
||||
|
||||
let tensor_1 = tensor_1.reshape([4, 2, 2, 2]);
|
||||
let tensor_1 = tensor_1.swap_dims(3, 2);
|
||||
let tensor_1 = tensor_1.swap_dims(1, 2);
|
||||
|
||||
let tensor_2 = tensor_2.reshape([1, 2, 2, 2]);
|
||||
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
let sum = tensor_2.clone().sum_dim(0);
|
||||
let sum = sum.sum_dim(1);
|
||||
let sum = sum.sum_dim(2);
|
||||
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
|
||||
let _tmp = sum.clone() + 2;
|
||||
let output = (tensor_1 + tensor_2 + sum).sum_dim(1);
|
||||
let expected = TensorData::from([
|
||||
[[[29.0, 43.0], [41.0, 55.0]]],
|
||||
[[[45.0, 59.0], [57.0, 71.0]]],
|
||||
[[[61.0, 75.0], [73.0, 87.0]]],
|
||||
[[[77.0, 91.0], [89.0, 103.0]]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_dim_7_maybe_fused_on_read_reshaped() {
|
||||
let tensor_1 = TestTensorInt::arange(0..16, &Default::default()).float();
|
||||
|
||||
let tensor_1 = tensor_1.reshape([4, 4]);
|
||||
|
||||
TestBackend::sync(&tensor_1.device()).unwrap();
|
||||
|
||||
let reshaped = tensor_1.reshape([1, 4, 4]);
|
||||
let tmp = reshaped + 5.0;
|
||||
let output = tmp.sum_dim(2);
|
||||
let expected = TensorData::from([[[26.0], [42.0], [58.0], [74.0]]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean_dim_fused_on_read_on_write() {
|
||||
// https://github.com/tracel-ai/burn/issues/3987
|
||||
let device = Default::default();
|
||||
let x = TestTensor::ones([128, 32, 1], &device);
|
||||
|
||||
let weight = TestTensor::ones([1, 32, 1], &device);
|
||||
let options = burn_tensor::ops::ConvOptions::new([1], [0], [1], 1);
|
||||
let x = burn_tensor::module::conv1d(x, weight, None, options);
|
||||
let global = x.clone().powi_scalar(2).sum_dim(2).add_scalar(1e-5).sqrt();
|
||||
let norm = global.clone().div(global.mean_dim(1));
|
||||
let x = x.clone().mul(norm).add(x);
|
||||
|
||||
let out = x.sum();
|
||||
|
||||
out.into_data()
|
||||
.assert_eq(&TensorData::from([8192.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean_dim_2d() {
|
||||
let tensor =
|
||||
TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());
|
||||
|
||||
let output = tensor.clone().mean_dim(1);
|
||||
let expected = TensorData::from([[1.], [4.]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
|
||||
let output = tensor.mean_dim(0);
|
||||
let expected = TensorData::from([[1.5, 2.5, 3.5]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean_dims_2d() {
|
||||
let tensor =
|
||||
TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());
|
||||
|
||||
tensor
|
||||
.clone()
|
||||
.mean_dims(&[1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1.], [4.]]), false);
|
||||
|
||||
tensor
|
||||
.clone()
|
||||
.mean_dims(&[-1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1.], [4.]]), false);
|
||||
|
||||
tensor
|
||||
.clone()
|
||||
.mean_dims(&[0, 1])
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[2.5]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_reduce_dims_permuted() {
|
||||
// Regression test for https://github.com/tracel-ai/burn/issues/4461
|
||||
let tensor = TestTensorInt::arange(0..2 * 2 * 256, &Default::default())
|
||||
.float()
|
||||
.reshape([2, 2, 256]);
|
||||
|
||||
let output = tensor
|
||||
.permute([1, 2, 0])
|
||||
.mean_dim(0)
|
||||
.mean_dim(1)
|
||||
.squeeze_dims::<1>(&[0, 1]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&TensorData::from([255.5, 767.5]), Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_all() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
|
||||
let data_actual = tensor.all().into_data();
|
||||
let data_expected = TensorData::from([false]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_dim() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
|
||||
let data_actual = tensor.all_dim(1).into_data();
|
||||
let data_expected = TensorData::from([[false], [true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_any() {
|
||||
// test float tensor
|
||||
let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([true]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([false]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
// test int tensor
|
||||
let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([true]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
let tensor = TestTensorInt::<2>::from([[0, 0, 0], [0, 0, 0]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([false]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
// test bool tensor
|
||||
let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([true]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
let tensor = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);
|
||||
let data_actual = tensor.any().into_data();
|
||||
let data_expected = TensorData::from([false]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_any_dim() {
|
||||
let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]);
|
||||
let data_actual = tensor.any_dim(1).into_data();
|
||||
let data_expected = TensorData::from([[false], [true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
// test int tensor
|
||||
let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]);
|
||||
let data_actual = tensor.any_dim(1).into_data();
|
||||
let data_expected = TensorData::from([[false], [true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
|
||||
// test bool tensor
|
||||
let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);
|
||||
let data_actual = tensor.any_dim(1).into_data();
|
||||
let data_expected = TensorData::from([[false], [true]]);
|
||||
data_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_argmax_2d_dim0() {
|
||||
let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.argmax(0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0, 0, 1]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argmin_2d_dim0() {
|
||||
let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.argmin(0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0, 1, 0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argmax_2d_dim1() {
|
||||
let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.argmax(1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1], [2]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argmin_2d_dim1() {
|
||||
let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);
|
||||
|
||||
let output = tensor.argmin(1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[2], [1]]), false);
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
|
||||
#[test]
|
||||
fn cast_float_to_bool() {
|
||||
let tensor1 = TestTensor::<2>::from([[0.0, 43.0, 0.0], [2.0, -4.2, 31.33]]);
|
||||
let data_actual = tensor1.bool().into_data();
|
||||
let data_expected = TensorData::from([[false, true, false], [true, true, true]]);
|
||||
data_actual.assert_eq(&data_expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_float_to_int() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]).int();
|
||||
let expected = TensorData::from([[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
tensor.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_int_to_float_tensor() {
|
||||
let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]).float();
|
||||
|
||||
let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
|
||||
tensor.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_bool_to_float_tensor() {
|
||||
let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).float();
|
||||
|
||||
let expected = TensorData::from([[1., 0., 1.], [0., 0., 1.]]);
|
||||
|
||||
tensor.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_float_precision() {
|
||||
let data = TensorData::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]);
|
||||
let tensor = TestTensor::<2>::from(data.clone());
|
||||
|
||||
let output = tensor.cast(DType::F32);
|
||||
|
||||
assert_eq!(output.dtype(), DType::F32);
|
||||
// Use precision 2 for parameterized tests in f16 and bf16
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&data, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
use super::*;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_ops_2d_dim0() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);
|
||||
let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);
|
||||
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);
|
||||
let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_ops_2d_dim1() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);
|
||||
let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);
|
||||
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2], 1);
|
||||
let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_ops_3d() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);
|
||||
let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device);
|
||||
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);
|
||||
let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_dimensions_are_not_the_same() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device);
|
||||
let tensor_2 = TestTensor::from_data([[4.0, 5.0]], &device);
|
||||
|
||||
TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_list_of_vectors_is_empty() {
|
||||
let tensor: Vec<TestTensor<2>> = vec![];
|
||||
TestTensor::cat(tensor, 0).into_data();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_cat_exceeds_dimension() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);
|
||||
let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device);
|
||||
|
||||
TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_ops_cast_dtype() {
|
||||
let device = Default::default();
|
||||
// ok for f32 backends, casts dtype for f16 tests
|
||||
let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device)
|
||||
.cast(DType::F32);
|
||||
let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device).cast(DType::F32);
|
||||
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);
|
||||
let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_with_empty_tensor() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);
|
||||
let tensor_2: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor with size 0 on dim 1
|
||||
|
||||
// Concatenating with an empty tensor should just return the non-empty tensor
|
||||
let output = TestTensor::cat(vec![tensor_1.clone(), tensor_2], 1);
|
||||
let expected = TensorData::from([[1.0, 2.0, 3.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_with_empty_tensor_first() {
|
||||
let device = Default::default();
|
||||
let tensor_1: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor
|
||||
let tensor_2 = TestTensor::<2>::from_data([[4.0, 5.0, 6.0]], &device);
|
||||
|
||||
// Empty tensor first, then non-empty
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2.clone()], 1);
|
||||
let expected = TensorData::from([[4.0, 5.0, 6.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_with_multiple_empty_tensors() {
|
||||
let device = Default::default();
|
||||
let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
|
||||
let tensor_3: TestTensor<2> = TestTensor::empty([2, 0], &device);
|
||||
let tensor_4 = TestTensor::<2>::from_data([[5.0], [6.0]], &device);
|
||||
|
||||
// Mix of empty and non-empty tensors
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2, tensor_3, tensor_4], 1);
|
||||
let expected = TensorData::from([[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_all_empty_tensors() {
|
||||
let device = Default::default();
|
||||
let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device);
|
||||
let tensor_2: TestTensor<2> = TestTensor::empty([2, 0], &device);
|
||||
|
||||
// All empty tensors should produce an empty tensor
|
||||
let output = TestTensor::cat(vec![tensor_1, tensor_2], 1);
|
||||
|
||||
assert_eq!(output.shape().as_slice(), [2, 0]);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_support_ceil_ops() {
|
||||
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
|
||||
let tensor = TestTensor::<2>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.ceil();
|
||||
let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_chunk_evenly_divisible() {
|
||||
let tensors = TestTensorInt::arange(0..12, &Default::default())
|
||||
.float()
|
||||
.chunk(6, 0);
|
||||
assert_eq!(tensors.len(), 6);
|
||||
|
||||
let expected = [
|
||||
TensorData::from([0, 1]),
|
||||
TensorData::from([2, 3]),
|
||||
TensorData::from([4, 5]),
|
||||
TensorData::from([6, 7]),
|
||||
TensorData::from([8, 9]),
|
||||
TensorData::from([10, 11]),
|
||||
];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
tensor.to_data().assert_eq(&expected[index], false);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_evenly_divisible() {
|
||||
let tensors = TestTensorInt::arange(0..11, &Default::default())
|
||||
.float()
|
||||
.chunk(6, 0);
|
||||
assert_eq!(tensors.len(), 6);
|
||||
|
||||
let expected = [
|
||||
TensorData::from([0, 1]),
|
||||
TensorData::from([2, 3]),
|
||||
TensorData::from([4, 5]),
|
||||
TensorData::from([6, 7]),
|
||||
TensorData::from([8, 9]),
|
||||
TensorData::from([10]),
|
||||
];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
tensor.to_data().assert_eq(&expected[index], false);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_evenly_divisible_remains_several() {
|
||||
let tensors = TestTensorInt::arange(0..100, &Default::default())
|
||||
.float()
|
||||
.chunk(8, 0);
|
||||
assert_eq!(tensors.len(), 8);
|
||||
|
||||
let expected = [13, 13, 13, 13, 13, 13, 13, 9];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
assert_eq!(tensor.shape()[0], expected[index]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_divisible() {
|
||||
let tensors = TestTensorInt::arange(0..6, &Default::default())
|
||||
.float()
|
||||
.chunk(7, 0);
|
||||
assert_eq!(tensors.len(), 6);
|
||||
|
||||
let expected = [
|
||||
TensorData::from([0]),
|
||||
TensorData::from([1]),
|
||||
TensorData::from([2]),
|
||||
TensorData::from([3]),
|
||||
TensorData::from([4]),
|
||||
TensorData::from([5]),
|
||||
];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
tensor.to_data().assert_eq(&expected[index], false);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_invalid_dim() {
|
||||
let _tensors = TestTensorInt::arange(0..12, &Default::default()).chunk(6, 1);
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn clamp_min() {
|
||||
let device = Default::default();
|
||||
// test float tensor
|
||||
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = TestTensor::<2>::from_data(data, &device);
|
||||
|
||||
let output = tensor.clamp_min(2.0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]), false);
|
||||
|
||||
// test int tensor
|
||||
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
|
||||
let tensor = TestTensorInt::<2>::from_data(data, &device);
|
||||
let output = tensor.clamp_min(2);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[2, 2, 2], [3, 4, 5]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_max() {
|
||||
let device = Default::default();
|
||||
// test float tensor
|
||||
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = TestTensor::<2>::from_data(data, &device);
|
||||
|
||||
let output = tensor.clamp_max(2.0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]), false);
|
||||
|
||||
// test int tensor
|
||||
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
|
||||
let tensor = TestTensorInt::<2>::from_data(data, &device);
|
||||
let output = tensor.clamp_max(4);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0, 1, 2], [3, 4, 4]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_min_max() {
|
||||
let device = Default::default();
|
||||
// test float tensor
|
||||
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = TestTensor::<2>::from_data(data, &device);
|
||||
let output = tensor.clamp(1.0, 4.0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]), false);
|
||||
|
||||
// test int tensor
|
||||
let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);
|
||||
let tensor = TestTensorInt::<2>::from_data(data, &device);
|
||||
let output = tensor.clamp(1, 4);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1, 1, 2], [3, 4, 4]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_min_max_vec_should_compile() {
|
||||
let input = TestTensor::<2>::ones([2, 4], &Default::default());
|
||||
let output = input.clamp(0., 0.5);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
use super::*;
|
||||
use burn_tensor::{DEFAULT_ATOL, DEFAULT_RTOL, TensorData};
|
||||
|
||||
#[test]
|
||||
fn test_is_close() {
|
||||
let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
|
||||
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;
|
||||
|
||||
let data_actual = tensor1
|
||||
.clone()
|
||||
.is_close(tensor2.clone(), None, None)
|
||||
.into_data();
|
||||
let defaults_expected = TensorData::from([[true, true, true], [true, true, false]]);
|
||||
defaults_expected.assert_eq(&data_actual, false);
|
||||
|
||||
// Using the defaults.
|
||||
let data_actual = tensor1
|
||||
.is_close(tensor2, Some(DEFAULT_RTOL), Some(DEFAULT_ATOL))
|
||||
.into_data();
|
||||
defaults_expected.assert_eq(&data_actual, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_close() {
|
||||
let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);
|
||||
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;
|
||||
assert!(!tensor1.clone().all_close(tensor2.clone(), None, None));
|
||||
|
||||
let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-9;
|
||||
assert!(tensor1.all_close(tensor2, None, None));
|
||||
|
||||
// non finite values
|
||||
let inf_plus = TestTensor::<2>::from([[f32::INFINITY]]);
|
||||
let one = TestTensor::<2>::from([[1.]]);
|
||||
let inf_minus = TestTensor::<2>::from([[-f32::INFINITY]]);
|
||||
assert!(!inf_plus.clone().all_close(inf_minus.clone(), None, None));
|
||||
assert!(!one.clone().all_close(inf_minus.clone(), None, None));
|
||||
assert!(!one.all_close(inf_plus.clone(), None, None));
|
||||
assert!(inf_plus.clone().all_close(inf_plus, None, None));
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_equal_inf() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [f32::INFINITY, 4.0, f32::NEG_INFINITY]]);
|
||||
let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_equal_inf() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, f32::INFINITY, 5.0]]);
|
||||
let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.not_equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[true, false, true], [true, true, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_equal() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, false], [false, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_equal() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.not_equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[true, false, true], [true, true, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_equal_elem() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().equal_elem(2);
|
||||
let data_actual_inplace = tensor_1.equal_elem(2);
|
||||
|
||||
let data_expected = TensorData::from([[false, false, true], [false, true, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_equal_elem() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().not_equal_elem(2);
|
||||
let data_actual_inplace = tensor_1.not_equal_elem(2);
|
||||
|
||||
let data_expected = TensorData::from([[true, true, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn greater_elem() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().greater_elem(4);
|
||||
let data_actual_inplace = tensor_1.greater_elem(4);
|
||||
|
||||
let data_expected = TensorData::from([[false, false, false], [false, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_greater_equal_elem() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0);
|
||||
let data_actual_inplace = tensor_1.greater_equal_elem(4.0);
|
||||
|
||||
let data_expected = TensorData::from([[false, false, false], [false, true, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_greater() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.greater(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[false, false, true], [false, true, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_greater_equal() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.greater_equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, true], [false, true, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lower_elem() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().lower_elem(4.0);
|
||||
let data_actual_inplace = tensor_1.lower_elem(4.0);
|
||||
|
||||
let data_expected = TensorData::from([[true, true, true], [true, false, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lower_equal_elem() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0);
|
||||
let data_actual_inplace = tensor_1.lower_equal_elem(4.0);
|
||||
|
||||
let data_expected = TensorData::from([[true, true, true], [true, true, false]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lower() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.lower(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[true, false, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lower_equal() {
|
||||
let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.lower_equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[true, true, false], [true, false, true]]);
|
||||
data_expected.assert_eq(&data_actual_cloned.into_data(), false);
|
||||
data_expected.assert_eq(&data_actual_inplace.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_greater_broadcast() {
|
||||
// Test broadcasting with shape [1, 4] vs [4, 4]
|
||||
let device = Default::default();
|
||||
let data_1 = TensorData::from([[1.0, 2.0, 3.0, 4.0]]);
|
||||
let data_2 = TensorData::from([
|
||||
[0.5, 1.5, 2.5, 3.5],
|
||||
[1.5, 2.5, 3.5, 4.5],
|
||||
[2.5, 3.5, 4.5, 5.5],
|
||||
[3.5, 4.5, 5.5, 6.5],
|
||||
]);
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let result = tensor_1.greater(tensor_2);
|
||||
|
||||
let expected = TensorData::from([
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[false, false, false, false],
|
||||
[false, false, false, false],
|
||||
]);
|
||||
expected.assert_eq(&result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_greater_equal_broadcast() {
|
||||
// Test broadcasting with shape [4, 1] vs [1, 4]
|
||||
let device = Default::default();
|
||||
let data_1 = TensorData::from([[1.0], [2.0], [3.0], [4.0]]);
|
||||
let data_2 = TensorData::from([[1.0, 2.0, 3.0, 4.0]]);
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let result = tensor_1.greater_equal(tensor_2);
|
||||
|
||||
let expected = TensorData::from([
|
||||
[true, false, false, false],
|
||||
[true, true, false, false],
|
||||
[true, true, true, false],
|
||||
[true, true, true, true],
|
||||
]);
|
||||
expected.assert_eq(&result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lower_broadcast() {
|
||||
// Test broadcasting mimicking CLIP pattern: [1, 5] vs [5, 1]
|
||||
let device = Default::default();
|
||||
let data_1 = TensorData::from([[0.0, 1.0, -1.0, 2.0, -2.0]]);
|
||||
let data_2 = TensorData::from([[0.5], [1.5], [-0.5], [-1.5], [2.5]]);
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let result = tensor_1.lower(tensor_2);
|
||||
|
||||
let expected = TensorData::from([
|
||||
[true, false, true, false, true],
|
||||
[true, true, true, false, true],
|
||||
[false, false, true, false, true],
|
||||
[false, false, false, false, true],
|
||||
[true, true, true, true, true],
|
||||
]);
|
||||
expected.assert_eq(&result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lower_equal_broadcast() {
|
||||
// Test broadcasting with shape [1, 1] vs [2, 4]
|
||||
let device = Default::default();
|
||||
let data_1 = TensorData::from([[2.5]]);
|
||||
let data_2 = TensorData::from([[1.0, 2.0, 3.0, 4.0], [2.0, 2.5, 3.0, 3.5]]);
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let result = tensor_1.lower_equal(tensor_2);
|
||||
|
||||
let expected = TensorData::from([[false, false, true, true], [false, true, true, true]]);
|
||||
expected.assert_eq(&result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_equal_broadcast() {
|
||||
// Test broadcasting with different ranks
|
||||
let device = Default::default();
|
||||
let data_1 = TensorData::from([[2.0], [3.0], [4.0]]);
|
||||
let data_2 = TensorData::from([[2.0, 3.0, 4.0, 2.0]]);
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let result = tensor_1.equal(tensor_2);
|
||||
|
||||
let expected = TensorData::from([
|
||||
[true, false, false, true],
|
||||
[false, true, false, false],
|
||||
[false, false, true, false],
|
||||
]);
|
||||
expected.assert_eq(&result.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_equal_broadcast() {
|
||||
// Test broadcasting with shape [3, 1] vs [1, 3]
|
||||
let device = Default::default();
|
||||
let data_1 = TensorData::from([[1.0], [2.0], [3.0]]);
|
||||
let data_2 = TensorData::from([[1.0, 2.0, 3.0]]);
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let result = tensor_1.not_equal(tensor_2);
|
||||
|
||||
let expected = TensorData::from([
|
||||
[false, true, true],
|
||||
[true, false, true],
|
||||
[true, true, false],
|
||||
]);
|
||||
expected.assert_eq(&result.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
use super::*;
|
||||
use burn_tensor::Tolerance;
|
||||
use burn_tensor::{Distribution, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_support_zeros_like() {
|
||||
let tensor = TestTensor::<3>::from_floats(
|
||||
[
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
],
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let tensor = tensor.zeros_like();
|
||||
let expected = TensorData::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]);
|
||||
|
||||
tensor
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_ones_like() {
|
||||
let tensor = TestTensor::<3>::from_floats(
|
||||
[
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
],
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let tensor = tensor.ones_like();
|
||||
let expected = TensorData::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);
|
||||
|
||||
tensor
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_randoms_like() {
|
||||
let tensor = TestTensor::<3>::from_floats(
|
||||
[
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
],
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let tensor = tensor.random_like(Distribution::Uniform(0.99999, 1.));
|
||||
let expected = TensorData::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);
|
||||
|
||||
tensor
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use burn_backend_tests::might_panic;
|
||||
|
||||
#[test]
|
||||
fn test_cross_3d_last_dim() {
|
||||
let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0], [2.0, -1.0, 4.0]]);
|
||||
let tensor_2 = TestTensor::from([[4.0, -2.0, 1.0], [3.0, 5.0, -2.0]]);
|
||||
|
||||
let output = tensor_1.cross(tensor_2, -1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[-7.0, -21.0, -14.0], [-18.0, 16.0, 13.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_3d_non_contiguous_last_dim() {
|
||||
let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0], [2.0, -1.0, 4.0]]);
|
||||
let tensor_2 = TestTensor::from([[4.0, 3.0], [-2.0, 5.0], [1.0, -2.0]]);
|
||||
|
||||
let output = tensor_1.cross(tensor_2.permute([1, 0]), -1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[-7.0, -21.0, -14.0], [-18.0, 16.0, 13.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
|
||||
#[test]
|
||||
fn test_cross_3d_dim0() {
|
||||
let tensor_1 = TestTensor::<2>::from([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]);
|
||||
let tensor_2 = TestTensor::from([[0.0, 1.0], [0.0, 0.0], [1.0, 0.0]]);
|
||||
|
||||
let output = tensor_1.cross(tensor_2, 0);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[0.0, 0.0], [-1.0, 0.0], [0.0, -1.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_3d_broadcast() {
|
||||
let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0]]);
|
||||
let tensor_2 = TestTensor::from([[4.0, -2.0, 1.0], [3.0, 5.0, -2.0]]);
|
||||
|
||||
let output = tensor_1.cross(tensor_2, -1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[-7.0, -21.0, -14.0], [19.0, -13.0, -4.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_4d_last_dim() {
|
||||
let tensor_1 = TestTensor::<3>::from([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]);
|
||||
let tensor_2 = TestTensor::from([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]);
|
||||
|
||||
let output = tensor_1.cross(tensor_2, -1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
// Helper to compute expected cross product for 2-D (N × 3) tensors.
|
||||
fn manual_cross(a: &[[f32; 3]], b: &[[f32; 3]]) -> Vec<[f32; 3]> {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| {
|
||||
[
|
||||
x[1] * y[2] - x[2] * y[1],
|
||||
x[2] * y[0] - x[0] * y[2],
|
||||
x[0] * y[1] - x[1] * y[0],
|
||||
]
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_matches_manual_cross() {
|
||||
let a_raw = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
|
||||
let b_raw = [[7.0, 8.0, 9.0], [1.0, 0.0, -1.0]];
|
||||
let a = TestTensor::<2>::from(a_raw);
|
||||
let b = TestTensor::<2>::from(b_raw);
|
||||
|
||||
let out = a.cross(b.clone(), 1);
|
||||
let expected_vec = manual_cross(&a_raw, &b_raw);
|
||||
let expected: [[f32; 3]; 2] = [expected_vec[0], expected_vec[1]];
|
||||
|
||||
out.into_data()
|
||||
.assert_eq(&TensorData::from(expected), false);
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_cumsum_float_dim_0() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
|
||||
let output = tensor.cumsum(0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cumsum_float_dim_1() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
|
||||
let output = tensor.cumsum(1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cumsum_non_contiguous() {
|
||||
let tensor = TestTensor::<2>::from([[1., 2.], [3., 4.]]).swap_dims(0, 1);
|
||||
|
||||
let output = tensor.cumsum(1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[1., 4.], [2., 6.]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cumsum_float_3d() {
|
||||
let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
|
||||
|
||||
let output = tensor.cumsum(2);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[[1.0, 3.0], [3.0, 7.0]], [[5.0, 11.0], [7.0, 15.0]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cumprod_float_dim_0() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
|
||||
let output = tensor.cumprod(0);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cumprod_float_dim_1() {
|
||||
let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
|
||||
let output = tensor.cumprod(1);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cumprod_float_3d() {
|
||||
let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
|
||||
|
||||
let output = tensor.cumprod(2);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[[1.0, 2.0], [3.0, 12.0]], [[5.0, 30.0], [7.0, 56.0]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cummin_float_dim_0() {
|
||||
let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]);
|
||||
|
||||
let output = tensor.cummin(0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [2.0, 1.0, 1.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cummin_float_dim_1() {
|
||||
let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]);
|
||||
|
||||
let output = tensor.cummin(1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 1.0, 1.0], [2.0, 2.0, 1.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cummin_float_3d() {
|
||||
let tensor = TestTensor::<3>::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 6.0], [7.0, 8.0]]]);
|
||||
|
||||
let output = tensor.cummin(2);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 5.0], [7.0, 7.0]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cummax_float_dim_0() {
|
||||
let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 2.0]]);
|
||||
|
||||
let output = tensor.cummax(0);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [3.0, 5.0, 4.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cummax_float_dim_1() {
|
||||
let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 2.0]]);
|
||||
|
||||
let output = tensor.cummax(1);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[3.0, 3.0, 4.0], [1.0, 5.0, 5.0]]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cummax_float_3d() {
|
||||
let tensor = TestTensor::<3>::from([[[1.0, 3.0], [2.0, 4.0]], [[5.0, 2.0], [6.0, 1.0]]]);
|
||||
|
||||
let output = tensor.cummax(2);
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[[1.0, 3.0], [2.0, 4.0]], [[5.0, 5.0], [6.0, 6.0]]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_support_div_ops() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let output = tensor_1 / tensor_2;
|
||||
let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div_broadcast() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0, 2.0]]);
|
||||
let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<2>::from_data(data_1, &device);
|
||||
let tensor_2 = TestTensor::<2>::from_data(data_2, &device);
|
||||
|
||||
let output = tensor_1 / tensor_2;
|
||||
|
||||
output.into_data().assert_eq(
|
||||
&TensorData::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_div_scalar_ops() {
|
||||
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let scalar = 2.0;
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::<2>::from_data(data, &device);
|
||||
|
||||
let output = tensor / scalar;
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]), false);
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
#[test]
|
||||
fn test_float() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device);
|
||||
let tensor_2 = TestTensor::<1>::from_data([0.0, -1.0, 4.0], &device);
|
||||
|
||||
let output = tensor_1.dot(tensor_2);
|
||||
let expected = TensorData::from([10.0]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_int() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<1>::from_data([1, 2, 3], &device);
|
||||
let tensor_2 = TestTensor::<1>::from_data([0, -1, 4], &device);
|
||||
|
||||
let output = tensor_1.dot(tensor_2);
|
||||
let expected = TensorData::from([10]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_panics_for_different_sizes() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestTensor::<1>::from_data([1, 2], &device);
|
||||
let tensor_2 = TestTensor::<1>::from_data([1, 2, 3], &device);
|
||||
let _output = tensor_1.dot(tensor_2);
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
use super::*;
|
||||
use burn_tensor::TensorData;
|
||||
use burn_tensor::Tolerance;
|
||||
|
||||
#[test]
|
||||
fn should_support_erf_ops() {
|
||||
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = TestTensor::<2>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.erf();
|
||||
let expected = TensorData::from([[0.0000, 0.8427, 0.99532], [0.99998, 1.0000, 1.0000]]);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected,
|
||||
Tolerance::default().set_half_precision_absolute(2e-3),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_erf_ops_with_negative_number() {
|
||||
let data = TensorData::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]);
|
||||
let tensor = TestTensor::<2>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.erf();
|
||||
let expected = TensorData::from([
|
||||
[-0.06312324, -0.048490416, -0.10016122],
|
||||
[0.99998, 1.0000, 1.0000],
|
||||
]);
|
||||
|
||||
output.into_data().assert_approx_eq::<FloatElem>(
|
||||
&expected,
|
||||
Tolerance::default().set_half_precision_absolute(3e-3),
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user