use super::*;
use burn_tensor::Tolerance;
use burn_tensor::{Distribution, Int, Tensor};

/// Checks that `select` along dimension 1 gives the same result as a
/// reference computation on identical input data.
///
/// The input is a random `[6, 256]` tensor. It is indexed with `0..100` along
/// dim 1, so the result should be `[6, 100]`. A second copy of the data and
/// indices is built via `to_data()` / `from_data`. The same `select` runs on
/// both copies, and the outputs are compared with `Tolerance::default()`.
///
/// The test name says it targets the multi-workgroup path. Whether 600
/// output elements actually span more than one workgroup depends on the
/// kernel's workgroup size, which is not visible here.
///
/// NOTE(review): the generic arguments on `Tensor::` and
/// `assert_approx_eq::` appear to have been stripped in this copy.
/// `Tensor::::random` does not parse, and the imported `Int` has no visible
/// use. They are presumably something like `Tensor::<TestBackend, 2>`,
/// `Tensor::<TestBackend, 1, Int>`, `Tensor::<ReferenceBackend, 2>` and
/// `Tensor::<ReferenceBackend, 1, Int>`. TODO: restore them from version
/// control.
#[test]
fn select_should_work_with_multiple_workgroups() {
    // Random 6 x 256 input; `Distribution::Default` picks the values.
    let tensor = Tensor::::random([6, 256], Distribution::Default, &Default::default());
    // Integer indices 0, 1, ..., 99; all are < 256, so every index is valid for dim 1.
    let indices = Tensor::::arange(0..100, &Default::default());
    // Copy the exact same data into the tensors used for the expected result
    // (presumably on a reference backend — see NOTE above).
    let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default());
    let indices_ref = Tensor::::from_data(indices.to_data(), &Default::default());

    // Same operation on both copies: pick columns `indices` along dim 1.
    let actual = tensor.select(1, indices);
    let expected = tensor_ref.select(1, indices_ref);

    // Approximate (not exact) comparison so small floating-point differences
    // between the two computations do not fail the test.
    expected
        .into_data()
        .assert_approx_eq::(&actual.into_data(), Tolerance::default());
}