use super::*; use burn_tensor::Tensor; #[test] fn tanh_should_not_have_numerical_bugs_on_macos() { fn tanh_one_value(input: f32) -> f32 { let tensor = Tensor::::ones([1], &Default::default()) * input; let output = tensor.tanh().into_primitive(); Tensor::::from_primitive(output) .into_data() .as_slice() .unwrap()[0] } let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36 let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36 let neg = tanh_one_value(-45.0); // metal works correctly here assert!(!ok.is_nan() && ok == 1.0); assert!(!zero.is_nan() && zero == 1.0); assert!(!nan.is_nan() && nan == 1.0); assert!(!neg.is_nan() && neg == -1.0); }