use super::*; use burn_tensor::{Distribution, Int, Tensor, backend::Backend}; use burn_tensor::{IndexingUpdateOp, Tolerance}; #[test] fn scatter_should_work_with_multiple_workgroups_2d_dim0() { same_as_reference_same_shape(0, [256, 32]); } #[test] fn scatter_should_work_with_multiple_workgroups_2d_dim1() { same_as_reference_same_shape(1, [32, 256]); } #[test] fn scatter_should_work_with_multiple_workgroups_3d_dim0() { same_as_reference_same_shape(0, [256, 6, 6]); } #[test] fn scatter_should_work_with_multiple_workgroups_3d_dim1() { same_as_reference_same_shape(1, [6, 256, 6]); } #[test] fn scatter_should_work_with_multiple_workgroups_3d_dim2() { same_as_reference_same_shape(2, [6, 6, 256]); } #[test] fn scatter_should_work_with_multiple_workgroups_diff_shapes() { same_as_reference_diff_shape(1, [32, 128], [32, 1]); } fn same_as_reference_diff_shape( dim: usize, shape1: [usize; D], shape2: [usize; D], ) { let test_device = Default::default(); TestBackend::seed(&test_device, 0); let tensor = Tensor::::random(shape1, Distribution::Default, &test_device); let value = Tensor::::random(shape2, Distribution::Default, &test_device); let indices = Tensor::::random( [shape2.iter().product::()], Distribution::Uniform(0., shape2[dim] as f64), &test_device, ) .reshape(shape2); let ref_device = Default::default(); let tensor_ref = Tensor::::from_data(tensor.to_data(), &ref_device); let value_ref = Tensor::::from_data(value.to_data(), &ref_device); let indices_ref = Tensor::::from_data(indices.to_data(), &ref_device); let actual = tensor.scatter(dim, indices, value, IndexingUpdateOp::Add); let expected = tensor_ref.scatter(dim, indices_ref, value_ref, IndexingUpdateOp::Add); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } fn same_as_reference_same_shape(dim: usize, shape: [usize; D]) { same_as_reference_diff_shape(dim, shape, shape); }