// NOTE(review): in this copy of the file, generic argument lists appear to have been stripped
// (e.g. `PhantomData`, `SharedArray`, `0.elem::()`, `SliceInfo::, IxDyn, IxDyn>`). The code
// tokens below are reproduced exactly as found; confirm the element/index types against the
// upstream source before relying on any signature shown here.
use alloc::{vec, vec::Vec};
use burn_backend::element::{Element, ElementConversion};
#[cfg(feature = "simd")]
use burn_backend::{DType, quantization::QuantValue};
use core::fmt::Debug;
use core::marker::PhantomData;
use ndarray::IntoDimension;
use ndarray::SliceInfo;
use ndarray::Zip;
use ndarray::s;
use ndarray::{Array2, ArrayD};
use num_traits::Signed;
#[cfg(feature = "simd")]
use paste::paste;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

#[cfg(feature = "simd")]
use crate::ops::simd::{
    binary::try_binary_simd,
    binary_elemwise::{
        VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub,
        try_binary_scalar_simd,
    },
    cmp::{
        VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd,
        try_cmp_simd,
    },
    unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd},
};
use crate::reshape;
use crate::{
    IntNdArrayElement, ShapeOps,
    ops::macros::{
        cummax_dim, cummin_dim, cumprod_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim,
    },
};
use crate::{SharedArray, element::NdArrayElement};
use burn_backend::ops::unfold::calculate_unfold_shape;
use burn_backend::{Shape, Slice};
use ndarray::ArrayView;
use ndarray::Axis;
use ndarray::Dim;
use ndarray::IxDyn;
use ndarray::SliceInfoElem;

/// Namespace for structural tensor operations (slicing, masking, gather/scatter, reshape,
/// concatenation, permutation, flip, unfold). Holds no data; only a type marker.
pub struct NdArrayOps {
    e: PhantomData,
}

/// Namespace for element-wise math, reductions and comparisons. Holds no data; only a type
/// marker.
pub(crate) struct NdArrayMathOps {
    e: PhantomData,
}

impl NdArrayOps
where
    E: Copy + Debug + Element + crate::AddAssignElement,
{
    /// Returns the region of `tensor` selected by `slices` as a shared array.
    ///
    /// Dimensions not covered by `slices` are taken in full (see
    /// `to_slice_args_with_steps`).
    pub fn slice(tensor: ArrayView, slices: &[Slice]) -> SharedArray {
        let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims());
        tensor.slice_move(slices.as_slice()).to_shared()
    }

    /// Overwrites the region of `tensor` selected by `slices` with `value` and returns the
    /// result.
    pub fn slice_assign(
        tensor: SharedArray,
        slices: &[Slice],
        value: SharedArray,
    ) -> SharedArray {
        let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims());
        let mut array = tensor.into_owned();
        array.slice_mut(slices.as_slice()).assign(&value);
        array.into_shared()
    }

    /// Element-wise select: takes the value from `source` where `mask` is true, otherwise the
    /// value from `tensor`.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if `tensor` or `source` cannot be broadcast to the mask's
    /// dimensions.
    pub fn mask_where(
        tensor: SharedArray,
        mask: SharedArray,
        source: SharedArray,
    ) -> SharedArray {
        // Both inputs are broadcast to the mask's shape, so the output has the mask's shape.
        let tensor = tensor.broadcast(mask.dim()).unwrap();
        let source = source.broadcast(mask.dim()).unwrap();
        Zip::from(&tensor)
            .and(&mask)
            .and(&source)
            .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
            .into_shared()
    }

    /// Replaces every element of `tensor` whose mask entry is true with `value`.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if `mask` cannot be broadcast to the tensor's dimensions.
    pub fn mask_fill(tensor: SharedArray, mask: SharedArray, value: E) -> SharedArray {
        // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique
        let mut output = tensor.into_owned();
        let broadcast_mask = mask.broadcast(output.dim()).unwrap();
        Zip::from(&mut output)
            .and(&broadcast_mask)
            .for_each(|out, &mask_val| {
                if mask_val {
                    *out = value;
                }
            });
        output.into_shared()
    }

    /// Gathers values from `tensor` along `dim` at the positions given by `indices`.
    ///
    /// The output has the shape of `indices`. All dimensions other than `dim` must match
    /// between `tensor` and `indices` (enforced by `gather_batch_size`, which panics
    /// otherwise).
    pub fn gather(
        dim: usize,
        mut tensor: SharedArray,
        mut indices: SharedArray,
    ) -> SharedArray {
        let ndims = tensor.shape().num_dims();
        // Move the gather dimension to the last axis so each "row" can be indexed directly.
        if dim != ndims - 1 {
            tensor.swap_axes(ndims - 1, dim);
            indices.swap_axes(ndims - 1, dim);
        }
        let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape().into_shape());
        let (size_tensor, size_index) = (shape_tensor[ndims - 1], shape_indices[ndims - 1]);
        let batch_size = Self::gather_batch_size(shape_tensor, &shape_indices);

        // Flatten all leading dimensions into a single batch dimension.
        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index]));
        let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor]));
        let mut output = Array2::from_elem((batch_size, size_index), 0.elem::());

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));
            for (i, index) in indices.iter().enumerate() {
                // An out-of-range index panics in the array indexing below.
                output[[b, i]] = tensor[[b, index.elem::() as usize]];
            }
        }

        let mut output = NdArrayOps::reshape(output.into_shared().into_dyn(), shape_indices);

        // Restore the original axis order.
        if dim != ndims - 1 {
            output.swap_axes(ndims - 1, dim);
        }

        output
    }

    /// Scatter-add: adds each element of `value` into `tensor` along `dim` at the position
    /// given by the matching element of `indices`.
    ///
    /// Repeated indices accumulate, because the update uses `add_assign`.
    ///
    /// # Panics
    ///
    /// Panics if `value` and `indices` have different shapes, or if any dimension other than
    /// `dim` differs between `tensor` and `indices`.
    pub fn scatter(
        dim: usize,
        mut tensor: SharedArray,
        mut indices: SharedArray,
        mut value: SharedArray,
    ) -> SharedArray {
        let ndims = tensor.shape().num_dims();
        // Move the scatter dimension to the last axis.
        if dim != ndims - 1 {
            tensor.swap_axes(ndims - 1, dim);
            indices.swap_axes(ndims - 1, dim);
            value.swap_axes(ndims - 1, dim);
        }

        let (shape_tensor, shape_indices, shape_value) =
            (tensor.shape().into_shape(), indices.shape(), value.shape());
        let (size_tensor, size_index, size_value) = (
            shape_tensor[ndims - 1],
            shape_indices[ndims - 1],
            shape_value[ndims - 1],
        );

        let batch_size = Self::gather_batch_size(&shape_tensor, shape_indices);

        if shape_value != shape_indices {
            // NOTE(review): the `\ ` inside the message looks like a `\`-newline string
            // continuation that was collapsed onto one line — confirm against upstream.
            panic!(
                "Invalid dimension: the shape of the index tensor should be the same as the value \ tensor: Index {:?} value {:?}",
                shape_indices, shape_value
            );
        }

        // Flatten all leading dimensions into a single batch dimension.
        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index]));
        let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value]));
        let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor]));

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));
            for (i, index) in indices.iter().enumerate() {
                let index = index.elem::() as usize;
                tensor[[b, index]].add_assign(value[[b, i]]);
            }
        }

        let mut output = NdArrayOps::reshape(tensor.into_shared().into_dyn(), shape_tensor);
        // Restore the original axis order.
        if dim != ndims - 1 {
            output.swap_axes(ndims - 1, dim);
        }
        output
    }

    /// Returns the product of all but the last dimension of `shape_indices`.
    ///
    /// # Panics
    ///
    /// Panics if any of those leading dimensions differs between the two shapes.
    fn gather_batch_size(shape_tensor: &[usize], shape_indices: &[usize]) -> usize {
        let ndims = shape_tensor.num_dims();
        let mut batch_size = 1;

        // NOTE(review): `ndims - 1` assumes at least one dimension — confirm callers never
        // pass a rank-0 shape.
        for i in 0..ndims - 1 {
            if shape_tensor[i] != shape_indices[i] {
                // NOTE(review): the `\ ` inside the message looks like a collapsed `\`-newline
                // string continuation — confirm against upstream.
                panic!(
                    "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \ {:?}",
                    shape_tensor, shape_indices
                );
            }
            batch_size *= shape_indices[i];
        }

        batch_size
    }

    /// Reshapes `tensor` to `shape` by delegating to the crate's `reshape!` macro.
    pub fn reshape(tensor: SharedArray, shape: Shape) -> SharedArray {
        reshape!(
            ty E,
            shape shape,
            array tensor,
            d shape.num_dims()
        )
    }

    /// Concatenates the given views along `dim` and returns the result reshaped to its own
    /// shape (see the layout comment below).
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if `ndarray::concatenate` fails.
    pub(crate) fn concatenate(
        arrays: &[ndarray::ArrayView],
        dim: usize,
    ) -> SharedArray {
        let array = ndarray::concatenate(Axis(dim), arrays)
            .unwrap()
            .into_shared();

        // Transform column-major layout into row-major (standard) layout. (fix #1053)
        // Get shape first (via reference), then pass ownership to avoid clone
        let shape = array.shape().into_shape();
        Self::reshape(array, shape)
    }

    /// Concatenates owned tensors along `dim`; a thin wrapper over `concatenate`.
    pub fn cat(tensors: Vec>, dim: usize) -> SharedArray {
        let arrays: Vec<_> = tensors.iter().map(|t| t.view()).collect();
        Self::concatenate(&arrays, dim)
    }

    /// Converts burn `Slice`s into one `SliceInfoElem` per dimension.
    ///
    /// Dimensions beyond `burn_slices.len()` get a full-range, step-1 slice.
    #[allow(clippy::wrong_self_convention)]
    fn to_slice_args_with_steps(
        burn_slices: &[burn_backend::Slice],
        ndims: usize,
    ) -> Vec {
        // `NewAxis` is only a placeholder; every entry is overwritten in the loop below.
        let mut slices = vec![SliceInfoElem::NewAxis; ndims];

        for i in 0..ndims {
            slices[i] = if i < burn_slices.len() {
                let slice = &burn_slices[i];

                // Check for empty range (would result in no elements)
                if let Some(end) = slice.end
                    && slice.start == end
                {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: Some(0),
                        step: 1,
                    }
                } else {
                    // Pass slice parameters directly to ndarray
                    // ndarray handles both positive and negative steps correctly:
                    // - Positive step: iterates forward from start
                    // - Negative step: iterates backward from the last element in range
                    SliceInfoElem::Slice {
                        start: slice.start,
                        end: slice.end,
                        step: slice.step,
                    }
                }
            } else {
                // Dimension not specified in slices - use full range
                SliceInfoElem::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                }
            }
        }

        slices
    }

    /// Swaps axes `dim1` and `dim2` of `tensor`.
    pub fn swap_dims(mut tensor: SharedArray, dim1: usize, dim2: usize) -> SharedArray {
        tensor.swap_axes(dim1, dim2);

        tensor
    }

    /// Reorders the axes of `tensor` according to `axes`.
    pub fn permute(tensor: SharedArray, axes: &[usize]) -> SharedArray {
        tensor.permuted_axes(axes.into_dimension())
    }

    /// Broadcasts the tensor to the given shape
    ///
    /// # Panics
    ///
    /// Panics if the tensor cannot be broadcast to `shape`.
    pub(crate) fn expand(tensor: SharedArray, shape: Shape) -> SharedArray {
        tensor
            .broadcast(shape.into_dimension())
            .expect("The shapes should be broadcastable")
            // need to convert view to owned array because NdArrayTensor expects owned array
            // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides)
            .into_owned()
            .into_shared()
    }

    /// Reverses the element order along every axis listed in `axes`; other axes are left
    /// unchanged. Returns an owned copy.
    pub fn flip(tensor: SharedArray, axes: &[usize]) -> SharedArray {
        // A step of -1 over the full range reverses that axis.
        let slice_items: Vec<_> = (0..tensor.shape().num_dims())
            .map(|i| {
                if axes.contains(&i) {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: None,
                        step: -1,
                    }
                } else {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: None,
                        step: 1,
                    }
                }
            })
            .collect();
        let slice_info =
            SliceInfo::, IxDyn, IxDyn>::try_from(slice_items).unwrap();
        tensor.slice(slice_info).into_owned().into_shared()
    }

    /// Unfold windows along a dimension.
    ///
    /// # Warning
    ///
    /// This is a copy impl; `ndarray` doesn't expose the layout machinery
    /// necessary to build the stride view.
    ///
    /// Returns a copy of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
    ///
    /// NOTE(review): the window count actually used is `calculate_unfold_shape(..)[dim]`;
    /// confirm that it matches the formula above.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the dimension to unfold.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A new tensor (a copy, not a view) with shape ``[pre=..., windows, post=..., size]``.
    #[allow(unused)]
    pub(crate) fn unfold(
        tensor: SharedArray,
        dim: usize,
        size: usize,
        step: usize,
    ) -> SharedArray {
        let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
        let windows = result_shape[dim];

        // Full-range slices for every dimension; only `slices[dim]` changes per window.
        let mut slices = vec![Slice::new(0, None, 1); tensor.shape().len()];
        // Index of the extra trailing axis appended to each window.
        let new_axis = slices.len();

        let mut stack = Vec::with_capacity(windows);
        for widx in 0..windows {
            let start = widx * step;
            let end = start + size;
            slices[dim] = Slice::new(start as isize, Some(end as isize), 1);

            let mut window_slice =
                tensor.slice(Self::to_slice_args_with_steps(&slices, slices.len()).as_slice());
            // Append a length-1 axis, then swap it with `dim`: the window's `size` elements end
            // up on the last axis and `dim` has length 1, ready for concatenation.
            window_slice.insert_axis_inplace(Axis(new_axis));
            window_slice.swap_axes(dim, new_axis);

            stack.push(window_slice);
        }
        Self::concatenate(&stack, dim)
    }
}

// Tries the SIMD implementation of a binary tensor-tensor op.
//
// Matches on `$elem::dtype()` and, for each listed `$ty`, calls `try_binary_simd`. On success
// the macro `return`s the result from the *enclosing function*; otherwise it evaluates to the
// `(lhs, rhs)` pair so the caller can run its fallback. The `noq` arm has no quantized
// (`DType::QFloat`) branch; the other arm routes `Q8F`/`Q8S` through the `i8` kernel.
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                    _ => Err(($lhs, $rhs)),
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

// Without the `simd` feature: passes the operands straight through to the fallback path.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}

// Tries the SIMD implementation of a binary tensor-scalar op.
//
// Same protocol as `dispatch_binary_simd`: `return`s from the enclosing function on success,
// otherwise evaluates to `$lhs` for the fallback. The quantized branch lists every
// non-8-bit `QuantValue` variant explicitly instead of using a wildcard.
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_scalar_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S
                    | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs)
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

// Without the `simd` feature: passes the tensor operand straight through.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_scalar_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}

// Tries the SIMD implementation of a tensor-tensor comparison.
//
// `return`s from the enclosing function on success, otherwise evaluates to `(lhs, rhs)`.
#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs),
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S
                    | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err(($lhs, $rhs))
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

// Without the `simd` feature: passes the operands straight through.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}

// Tries the SIMD implementation of a tensor-scalar comparison.
//
// `return`s from the enclosing function on success, otherwise evaluates to `$lhs`.
#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_scalar_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs),
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S
                    | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs)
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

// Without the `simd` feature: passes the tensor operand straight through.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_scalar_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}

// Tries the SIMD implementation of a unary op.
//
// `return`s from the enclosing function on success, otherwise evaluates to `$lhs`.
// There is no quantized branch here.
#[cfg(feature = "simd")]
macro_rules! dispatch_unary_simd {
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

// Without the `simd` feature: passes the tensor straight through.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_unary_simd {
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }};
}

// Helper function to broadcast two tensors to a common shape for comparison operations
// Returns broadcasted views that can be safely zipped
//
// Shapes are aligned from the trailing dimension; a missing or size-1 dimension stretches to
// match the other operand. Panics if two dimensions differ and neither is 1.
fn broadcast_for_comparison<'a, E: Copy, S1, S2>(
    lhs: &'a ndarray::ArrayBase,
    rhs: &'a ndarray::ArrayBase,
) -> (
    ndarray::ArrayView<'a, E, ndarray::IxDyn>,
    ndarray::ArrayView<'a, E, ndarray::IxDyn>,
)
where
    S1: ndarray::Data,
    S2: ndarray::Data,
{
    // Get shapes
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();

    // Compute broadcast shape using ndarray's broadcast compatibility rules
    let ndims = lhs_shape.len().max(rhs_shape.len());
    let mut broadcast_shape = vec![1; ndims];

    // `i` counts dimensions from the end (i == 0 is the last dimension).
    for i in 0..ndims {
        let lhs_dim = if i < lhs_shape.len() {
            lhs_shape[lhs_shape.len() - 1 - i]
        } else {
            1
        };
        let rhs_dim = if i < rhs_shape.len() {
            rhs_shape[rhs_shape.len() - 1 - i]
        } else {
            1
        };

        if lhs_dim == rhs_dim {
            broadcast_shape[ndims - 1 - i] = lhs_dim;
        } else if lhs_dim == 1 {
            broadcast_shape[ndims - 1 - i] = rhs_dim;
        } else if rhs_dim == 1 {
            broadcast_shape[ndims - 1 - i] = lhs_dim;
        } else {
            panic!(
                "Incompatible shapes for broadcasting: {:?} and {:?}",
                lhs_shape, rhs_shape
            );
        }
    }

    // Create IxDyn from broadcast shape
    let broadcast_dim = ndarray::IxDyn(&broadcast_shape);

    // Broadcast both arrays
    let lhs_broadcast = lhs
        .broadcast(broadcast_dim.clone())
        .expect("Failed to broadcast lhs");
    let rhs_broadcast = rhs
        .broadcast(broadcast_dim)
        .expect("Failed to broadcast rhs");

    (lhs_broadcast, rhs_broadcast)
}

impl NdArrayMathOps
where
    E: Copy + NdArrayElement,
{
    /// Element-wise `lhs + rhs`. With the `simd` feature, a successful SIMD dispatch returns
    /// early from inside the macro.
    pub fn add(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_binary_simd!(
            E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
        );

        let array = &lhs + &rhs;
        array.into_shared()
    }

    /// Adds the scalar `rhs` to every element of `lhs`.
    pub fn add_scalar(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(
            E,
            VecAdd,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        let array = lhs + rhs;
        array.into_shared()
    }

    /// Element-wise `lhs - rhs`.
    pub fn sub(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_binary_simd!(
            E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
        );

        let array = lhs - rhs;
        array.into_shared()
    }

    /// Subtracts the scalar `rhs` from every element of `lhs`.
    pub fn sub_scalar(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(
            E,
            VecSub,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        let array = lhs - rhs;
        array.into_shared()
    }

    /// Element-wise `lhs * rhs`. The SIMD dispatch uses the `noq` arm (no quantized branch)
    /// and a narrower type list than add/sub.
    pub fn mul(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) =
            dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64);

        let array = lhs * rhs;
        array.into_shared()
    }

    /// Multiplies every element of `lhs` by the scalar `rhs`.
    pub fn mul_scalar(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(
            noq,
            E,
            VecMul,
            lhs,
            rhs.elem(),
            u16,
            i16,
            u32,
            i32,
            f32,
            f64
        );

        let array = lhs * rhs;
        array.into_shared()
    }

    /// Element-wise `lhs / rhs`. SIMD is only attempted for `f32` and `f64`.
    pub fn div(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64);

        let array = lhs / rhs;
        array.into_shared()
    }

    /// Divides every element of `lhs` by the scalar `rhs`.
    pub fn div_scalar(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64);

        let array = lhs / rhs;
        array.into_shared()
    }

    /// Element-wise floored remainder, computed as `a - b * floor(a / b)` after converting
    /// both operands with `to_f64`.
    ///
    /// NOTE(review): the `f64` round-trip presumably loses precision for large 64-bit
    /// integers — confirm whether that matters for callers.
    pub fn remainder(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique
        let mut out = lhs.into_owned();
        Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| {
            // out_elem holds lhs value; read it before overwriting with remainder
            let a_f = (*out_elem).to_f64();
            let b_f = b.to_f64();
            let r = a_f - b_f * (a_f / b_f).floor();
            *out_elem = r.elem();
        });
        out.into_shared()
    }

    /// Remainder of each element by the scalar `rhs`, computed as `((x % rhs) + rhs) % rhs`.
    pub fn remainder_scalar(lhs: SharedArray, rhs: E) -> SharedArray
    where
        E: core::ops::Rem,
    {
        let array = lhs.mapv(|x| ((x % rhs) + rhs) % rhs);
        array.into_shared()
    }

    /// Element-wise reciprocal (`1 / x`). SIMD is only attempted for `f32`.
    pub fn recip(tensor: SharedArray) -> SharedArray {
        let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32);

        let array = tensor.map(|x| 1.elem::() / *x);
        array.into_shared()
    }

    /// Sum all elements - zero-copy for borrowed storage.
    ///
    /// The result is a one-element array of shape `[1]`.
    pub fn sum_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray {
        let sum = view.sum();
        ArrayD::from_elem(IxDyn(&[1]), sum).into_shared()
    }

    /// Mean of all elements - zero-copy for borrowed storage.
    ///
    /// The result is a one-element array of shape `[1]`. Panics (via `unwrap`) when
    /// `view.mean()` returns `None`.
    pub fn mean_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray {
        let mean = view.mean().unwrap();
        ArrayD::from_elem(IxDyn(&[1]), mean).into_shared()
    }

    /// Product of all elements - zero-copy for borrowed storage.
    ///
    /// The result is a one-element array of shape `[1]`; the fold starts from `E::one()`.
    pub fn prod_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray {
        let prod = view.iter().fold(E::one(), |acc, &x| acc * x);
        ArrayD::from_elem(IxDyn(&[1]), prod).into_shared()
    }

    /// Mean along `dim` via the `keepdim!` macro. Supports 1 to 6 dimensions; panics otherwise.
    pub fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean),
            _ => panic!("Dim not supported {ndims}"),
        }
    }

    /// Sum along `dim` via the `keepdim!` macro. Supports 1 to 6 dimensions; panics otherwise.
    pub fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum),
            _ => panic!("Dim not supported {ndims}"),
        }
    }

    /// Product along `dim` via the `keepdim!` macro. Supports 1 to 6 dimensions; panics
    /// otherwise.
    pub fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod),
            _ => panic!("Dim not supported {ndims}"),
        }
    }

    /// Cumulative sum along `dim`; delegates to `cumsum_dim`.
    pub fn cumsum(tensor: SharedArray, dim: usize) -> SharedArray {
        cumsum_dim(tensor, dim)
    }

    /// Cumulative product along `dim`; delegates to `cumprod_dim`.
    pub fn cumprod(tensor: SharedArray, dim: usize) -> SharedArray {
        cumprod_dim(tensor, dim)
    }

    /// Selects the sub-arrays of `tensor` along `dim` at the positions listed in `indices`.
    pub fn select(
        tensor: SharedArray,
        dim: usize,
        indices: SharedArray,
    ) -> SharedArray {
        let array = tensor.select(
            Axis(dim),
            &indices
                .into_iter()
                .map(|i| i.elem::() as usize)
                .collect::>(),
        );

        array.into_shared()
    }

    /// Adds slices of `value` into `tensor` along `dim`: the `k`-th slice of `value` is added
    /// to the slice of `tensor` at position `indices[k]`. Repeated indices accumulate.
    pub fn select_assign(
        tensor: SharedArray,
        dim: usize,
        indices: SharedArray,
        value: SharedArray,
    ) -> SharedArray {
        let mut output_array = tensor.into_owned();

        for (index_value, index) in indices.into_iter().enumerate() {
            let mut view = output_array.index_axis_mut(Axis(dim), index.elem::() as usize);
            let value = value.index_axis(Axis(dim), index_value);

            view.zip_mut_with(&value, |a, b| *a += *b);
        }

        output_array.into_shared()
    }

    /// Applies `var_name` to each pair of elements from `lhs` and `rhs`.
    ///
    /// `lhs` is broadcast to the shape of `rhs` when possible (otherwise used as-is); `rhs` is
    /// then broadcast to the resulting `lhs` shape under the same rule.
    pub(crate) fn elementwise_op(
        lhs: SharedArray,
        rhs: SharedArray,
        var_name: impl FnMut(&E, &OtherE) -> E,
    ) -> SharedArray {
        let lhs = lhs.broadcast(rhs.dim()).unwrap_or(lhs.view());
        let rhs = rhs.broadcast(lhs.dim()).unwrap_or(rhs.view());

        Zip::from(lhs).and(rhs).map_collect(var_name).into_shared()
    }

    /// Applies `var_name` to every element of `lhs`.
    pub(crate) fn elementwise_op_scalar(
        lhs: SharedArray,
        var_name: impl FnMut(E) -> E,
    ) -> SharedArray {
        lhs.mapv(var_name).into_shared()
    }

    /// Element-wise absolute value (via `abs_elem` on the fallback path).
    pub(crate) fn abs(tensor: SharedArray) -> SharedArray {
        let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64);

        tensor.mapv_into(|a| a.abs_elem()).into_shared()
    }

    /// Element-wise `lhs == rhs` with broadcasting.
    pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and compare
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs == rhs)
            .into_shared()
    }

    /// Compares every element of `lhs` with the scalar `rhs` for equality.
    pub(crate) fn equal_elem(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecEquals,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        lhs.mapv(|a| a == rhs).into_shared()
    }

    /// Element-wise sign: zero stays zero, values for which `is_positive()` holds map to one,
    /// everything else maps to minus one.
    pub(crate) fn sign_op(tensor: SharedArray) -> SharedArray
    where
        E: Signed,
    {
        let zero = 0.elem();
        let one = 1.elem::();

        tensor
            .mapv(|x| {
                if x == zero {
                    zero
                } else {
                    match x.is_positive() {
                        true => one,
                        false => -one,
                    }
                }
            })
            .into_shared()
    }
}

impl NdArrayMathOps
where
    E: Copy + NdArrayElement + PartialOrd,
{
    /// Max of all elements - zero-copy for borrowed storage.
    ///
    /// Returns a one-element array of shape `[1]`. Panics on an empty view.
    pub fn max_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray {
        let max = view
            .iter()
            .copied()
            .reduce(|a, b| if a > b { a } else { b })
            .expect("Cannot compute max of empty tensor");
        ArrayD::from_elem(IxDyn(&[1]), max).into_shared()
    }

    /// Min of all elements - zero-copy for borrowed storage.
    ///
    /// Returns a one-element array of shape `[1]`. Panics on an empty view.
    pub fn min_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray {
        let min = view
            .iter()
            .copied()
            .reduce(|a, b| if a < b { a } else { b })
            .expect("Cannot compute min of empty tensor");
        ArrayD::from_elem(IxDyn(&[1]), min).into_shared()
    }

    /// Argmax along dimension - zero-copy for borrowed storage.
    pub fn argmax_view(
        view: ArrayView<'_, E, IxDyn>,
        dim: usize,
    ) -> SharedArray {
        arg_view(view, dim, CmpType::Max)
    }

    /// Argmin along dimension - zero-copy for borrowed storage.
    pub fn argmin_view(
        view: ArrayView<'_, E, IxDyn>,
        dim: usize,
    ) -> SharedArray {
        arg_view(view, dim, CmpType::Min)
    }

    /// Cumulative minimum along `dim`; delegates to `cummin_dim`.
    pub fn cummin(tensor: SharedArray, dim: usize) -> SharedArray {
        cummin_dim(tensor, dim)
    }

    /// Cumulative maximum along `dim`; delegates to `cummax_dim`.
    pub fn cummax(tensor: SharedArray, dim: usize) -> SharedArray {
        cummax_dim(tensor, dim)
    }

    /// Index of the maximum along `dim`; delegates to `arg`.
    pub fn argmax(
        tensor: SharedArray,
        dim: usize,
    ) -> SharedArray {
        arg(tensor, dim, CmpType::Max)
    }

    /// Index of the minimum along `dim`; delegates to `arg`.
    pub fn argmin(
        tensor: SharedArray,
        dim: usize,
    ) -> SharedArray {
        arg(tensor, dim, CmpType::Min)
    }

    /// Raises every element below `min` up to `min`.
    pub fn clamp_min(tensor: SharedArray, min: E) -> SharedArray {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecMax,
            tensor,
            min.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.mapv_inplace(|x| match x < min {
            true => min,
            false => x,
        });

        tensor
    }

    /// Lowers every element above `max` down to `max`.
    pub fn clamp_max(tensor: SharedArray, max: E) -> SharedArray {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecMin,
            tensor,
            max.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.mapv_inplace(|x| match x > max {
            true => max,
            false => x,
        });

        tensor
    }

    /// Clamps every element into `[min, max]`. The lower bound is checked first, so an element
    /// below `min` becomes `min` regardless of `max`.
    pub fn clamp(tensor: SharedArray, min: E, max: E) -> SharedArray {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecClamp,
            tensor,
            (min.elem(), max.elem()),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.mapv_inplace(|x| match x < min {
            true => min,
            false => match x > max {
                true => max,
                false => x,
            },
        });

        tensor
    }

    /// Element-wise `lhs > rhs` with broadcasting.
    pub(crate) fn greater(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and compare
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs > rhs)
            .into_shared()
    }

    /// Element-wise `lhs > rhs` against a scalar.
    pub(crate) fn greater_elem(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecGreater,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        lhs.mapv(|a| a > rhs).into_shared()
    }

    /// Element-wise `lhs >= rhs` with broadcasting.
    pub(crate) fn greater_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E,
            VecGreaterEq,
            lhs,
            rhs,
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and compare
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs >= rhs)
            .into_shared()
    }

    /// Element-wise `lhs >= rhs` against a scalar.
    pub(crate) fn greater_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecGreaterEq,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        lhs.mapv(|a| a >= rhs).into_shared()
    }

    /// Element-wise `lhs <= rhs` with broadcasting.
    pub(crate) fn lower_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and compare
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs <= rhs)
            .into_shared()
    }

    /// Element-wise `lhs <= rhs` against a scalar.
    pub(crate) fn lower_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecLowerEq,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        lhs.mapv(|a| a <= rhs).into_shared()
    }

    /// Element-wise `lhs < rhs` with broadcasting.
    pub(crate) fn lower(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and compare
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs < rhs)
            .into_shared()
    }

    /// Element-wise `lhs < rhs` against a scalar.
    pub(crate) fn lower_elem(lhs: SharedArray, rhs: E) -> SharedArray {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecLower,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        lhs.mapv(|a| a < rhs).into_shared()
    }
}

/// Namespace for bitwise operations on integer tensors. Holds no data; only a type marker.
pub struct NdArrayBitOps(PhantomData);

impl NdArrayBitOps {
    /// Element-wise bitwise AND of two tensors.
    pub(crate) fn bitand(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) =
            dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);

        // Fallback: convert both operands with `elem`, combine, and convert back.
        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
            (a.elem::() & (b.elem::())).elem()
        })
    }

    /// Bitwise AND of every element with the scalar `rhs`.
    pub(crate) fn bitand_scalar(lhs: SharedArray, rhs: I) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(
            I,
            VecBitAnd,
            lhs,
            rhs.elem(),
            i8,
            u8,
            i16,
            u16,
            i32,
            u32,
            i64,
            u64
        );

        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
            (a.elem::() & rhs.elem::()).elem()
        })
    }

    /// Element-wise bitwise OR of two tensors.
    pub(crate) fn bitor(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) =
            dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
            (a.elem::() | (b.elem::())).elem()
        })
    }

    /// Bitwise OR of every element with the scalar `rhs`.
    pub(crate) fn bitor_scalar(lhs: SharedArray, rhs: I) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(
            I,
            VecBitOr,
            lhs,
            rhs.elem(),
            i8,
            u8,
            i16,
            u16,
            i32,
            u32,
            i64,
            u64
        );

        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
            (a.elem::() | rhs.elem::()).elem()
        })
    }

    /// Element-wise bitwise XOR of two tensors.
    pub(crate) fn bitxor(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        let (lhs, rhs) =
            dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
            (a.elem::() ^ (b.elem::())).elem()
        })
    }

    /// Bitwise XOR of every element with the scalar `rhs`.
    pub(crate) fn bitxor_scalar(lhs: SharedArray, rhs: I) -> SharedArray {
        let lhs = dispatch_binary_scalar_simd!(
            I,
            VecBitXor,
            lhs,
            rhs.elem(),
            i8,
            u8,
            i16,
            u16,
            i32,
            u32,
            i64,
            u64
        );

        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
            (a.elem::() ^ rhs.elem::()).elem()
        })
    }

    /// Element-wise bitwise NOT.
    pub(crate) fn bitnot(tensor: SharedArray) -> SharedArray {
        let tensor =
            dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem())
    }
}

/// Namespace for operations on boolean tensors.
pub struct NdArrayBoolOps;

// Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would
// produce invalid values.
impl NdArrayBoolOps {
    /// Element-wise `lhs == rhs` on boolean tensors, with broadcasting.
    ///
    /// With the `simd` feature, a successful `try_cmp_simd` call returns early; otherwise the
    /// operands fall through to the broadcast-and-zip path.
    pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        #[cfg(feature = "simd")]
        let (lhs, rhs) = match try_cmp_simd::(lhs, rhs) {
            Ok(out) => return out,
            Err(args) => args,
        };

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and compare
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs == rhs)
            .into_shared()
    }

    /// Compares every element of `lhs` with the scalar `rhs` for equality.
    pub(crate) fn equal_elem(lhs: SharedArray, rhs: bool) -> SharedArray {
        #[cfg(feature = "simd")]
        let lhs = match try_cmp_scalar_simd::(lhs, rhs.elem()) {
            Ok(out) => return out,
            Err(args) => args,
        };

        lhs.mapv(|a| a == rhs).into_shared()
    }

    /// Element-wise logical AND, with broadcasting.
    pub(crate) fn and(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        #[cfg(feature = "simd")]
        let (lhs, rhs) = match try_binary_simd::(lhs, rhs) {
            Ok(out) => return out,
            Err(args) => args,
        };

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and combine
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs && rhs)
            .into_shared()
    }

    /// Element-wise logical OR, with broadcasting.
    pub(crate) fn or(lhs: SharedArray, rhs: SharedArray) -> SharedArray {
        #[cfg(feature = "simd")]
        let (lhs, rhs) = match try_binary_simd::(lhs, rhs) {
            Ok(out) => return out,
            Err(args) => args,
        };

        // Use the helper to broadcast both arrays to a common shape
        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
        // Now we can safely zip and combine
        Zip::from(&lhs_broadcast)
            .and(&rhs_broadcast)
            .map_collect(|&lhs, &rhs| lhs || rhs)
            .into_shared()
    }

    /// Any element is true - zero-copy for borrowed storage.
    ///
    /// Returns `false` for an empty view.
    pub fn any_view(view: ArrayView<'_, bool, IxDyn>) -> bool {
        view.iter().any(|&x| x)
    }

    /// All elements are true - zero-copy for borrowed storage.
    ///
    /// Returns `true` for an empty view.
    pub fn all_view(view: ArrayView<'_, bool, IxDyn>) -> bool {
        view.iter().all(|&x| x)
    }
}

/// Selects whether `arg` / `arg_view` search for the minimum or the maximum.
enum CmpType {
    Min,
    Max,
}

/// Owned-tensor entry point for argmin/argmax; forwards a view of `tensor` to `arg_view`.
fn arg(
    tensor: SharedArray,
    dim: usize,
    cmp: CmpType,
) -> SharedArray {
    arg_view(tensor.view(), dim, cmp)
}

/// View-based argmax/argmin - zero-copy for borrowed storage.
///
/// For each lane along `dim`, returns the index of the minimum or maximum element. The
/// comparison is strict, so the first occurrence wins on ties. The output keeps `dim` as a
/// dimension of size 1.
///
/// The fold is seeded with `arr[0]`, so the axis must not be empty.
fn arg_view(
    view: ArrayView<'_, E, IxDyn>,
    dim: usize,
    cmp: CmpType,
) -> SharedArray {
    // Output shape: same as the input, with `dim` collapsed to 1.
    let mut reshape = view.shape().to_vec();
    reshape[dim] = 1;

    let output = view.map_axis(Axis(dim), |arr| {
        // Find the min/max value in the array, and return its index.
        let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| {
            let cmp = match cmp {
                CmpType::Min => e < &acc.0,
                CmpType::Max => e > &acc.0,
            };

            if cmp { (*e, idx) } else { acc }
        });

        (idx as i64).elem()
    });

    // `map_axis` removes `dim`; reshape to put it back with size 1.
    let output = output.to_shape(Dim(reshape.as_slice())).unwrap();

    output.into_shared()
}

#[cfg(test)]
mod tests {
    use burn_backend::TensorData;

    use crate::NdArrayTensor;

    use super::*;

    // Regression test: concatenating along the last axis must yield a standard (row-major)
    // layout with the expected strides.
    #[test]
    fn should_generate_row_major_layout_for_cat() {
        let expected_shape: &[usize] = &[4, 6, 2];
        let expected_strides: &[isize] = &[12, 2, 1];
        let NdArrayTensor::I32(expected_storage) = NdArrayTensor::from_data(TensorData::from([
            [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
            [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
            [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
            [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
        ])) else {
            panic!()
        };
        let expected_array = expected_storage.into_shared();
        let NdArrayTensor::I32(tensor_storage) = NdArrayTensor::from_data(TensorData::from([
            [1, 2, 3, 4, 5, 6],
            [7, 8, 9, 10, 11, 12],
            [13, 14, 15, 16, 17, 18],
            [19, 20, 21, 22, 23, 24],
        ])) else {
            panic!()
        };
        let tensor = tensor_storage.into_shared();

        // unsqueeze: append a size-1 axis at the end ([4, 6] -> [4, 6, 1])
        let array = NdArrayOps::reshape(tensor, Shape::from([4, 6, 1]));
        let NdArrayTensor::I32(zeros_storage) =
            NdArrayTensor::from_data(TensorData::zeros::([4, 6, 1]))
        else {
            panic!()
        };
        let zeros = zeros_storage.into_shared();
        // make `ndarray` concatenate the arrays along that last axis (axis 2)
        let array = NdArrayOps::cat([array, zeros].to_vec(), 2);

        assert!(array.is_standard_layout());
        assert_eq!(array.shape(), expected_shape);
        assert_eq!(array.strides(), expected_strides);
        assert_eq!(
            array.into_iter().collect::>(),
            expected_array.into_iter().collect::>(),
        );
    }
}