use burn::{ tensor::{ backend::Backend, activation::relu, Tensor, Int, Bool, Float, TensorKind, BasicOps, Numeric, Element, }, }; use num_traits::ToPrimitive; pub fn tensor_max_scalar(x: Tensor, max: f64) -> Tensor { relu(x.sub_scalar(max)).add_scalar(max) } pub fn tensor_min_scalar(x: Tensor, min: f64) -> Tensor { -tensor_max_scalar(-x, -min) } pub fn tensor_max(x: Tensor, max: Tensor) -> Tensor { relu(x - max.clone()) + max } pub fn tensor_min(x: Tensor, min: Tensor) -> Tensor { -tensor_max(-x, -min) } pub fn tensor_log10(x: Tensor) -> Tensor { let ln10 = (10.0f64).ln(); x.log() / ln10 } pub fn tensor_max_element(x: Tensor) -> f64 { let flat: Tensor = x.flatten(0, D - 1); let max_index = flat.clone().argmax(0); flat.select(0, max_index).into_scalar().to_f64().unwrap() } pub fn all_zeros(x: Tensor) -> bool { x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0 } pub fn max_dim(x: Tensor, dim: usize) -> Tensor { let indices = x.clone().argmax(dim).flatten(0, 1); x.select(dim, indices) } pub fn _10pow(x: Tensor) -> Tensor { let log10 = (10.0f64).ln(); (x * log10).exp() } pub fn to_float(x: Tensor) -> Tensor { let device = x.device(); Tensor::from_data( x .into_data() .convert() ).to_device(&device) } pub fn to_float_bool(x: Tensor) -> Tensor { let device = x.device(); Tensor::from_data( x .into_data() .convert() ).to_device(&device) } pub fn reverse + BasicOps + Numeric>(x: Tensor, dim: usize) -> Tensor where >::Elem: Element { let len = x.dims()[dim]; let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64; x.select(dim, indices) } pub fn div_roundup(x: usize, y: usize) -> usize { (x + y - 1) / y }