feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,221 @@
use crate::rand::NdArrayRng;
use crate::{NdArrayQTensor, NdArrayTensor};
use crate::{
SharedArray,
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
};
use alloc::string::String;
use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use burn_backend::{Backend, DType, DeviceId, DeviceOps};
use burn_ir::{BackendIr, HandleKind, TensorHandle};
use burn_std::stub::Mutex;
use core::marker::PhantomData;
use rand::SeedableRng;
/// Crate-wide RNG slot for the ndarray backend.
///
/// Starts out as `None`; `Backend::seed` stores a generator built from the requested seed here.
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
/// The device type for the ndarray backend.
///
/// `Cpu` is the only variant and is also the `Default` value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum NdArrayDevice {
/// The CPU device.
#[default]
Cpu,
}
// Marker implementation: the body is empty, so no `DeviceOps` items are overridden here.
impl DeviceOps for NdArrayDevice {}
impl burn_backend::Device for NdArrayDevice {
fn from_id(_device_id: DeviceId) -> Self {
Self::Cpu
}
fn to_id(&self) -> DeviceId {
DeviceId {
type_id: 0,
index_id: 0,
}
}
fn device_count(_type_id: u16) -> usize {
1
}
}
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
///
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
/// `wasm`, `arm`, and `x86`.
///
/// The type parameters select the float (`E`), int (`I`) and quantized (`Q`) element types and
/// default to `f32`, `i64` and `i8`. They are only carried as `PhantomData` markers.
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArray<E = f32, I = i64, Q = i8>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
// Zero-sized markers tying the struct to its element types.
_e: PhantomData<E>,
_i: PhantomData<I>,
_q: PhantomData<Q>,
}
// `Backend` implementation: declares the tensor primitives / element types and answers
// capability queries (autodiff, name, seeding, dtype support).
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
type Device = NdArrayDevice;
// Float, int and bool tensors all use the same `NdArrayTensor` primitive.
type FloatTensorPrimitive = NdArrayTensor;
type FloatElem = E;
type IntTensorPrimitive = NdArrayTensor;
type IntElem = I;
type BoolTensorPrimitive = NdArrayTensor;
type BoolElem = bool;
type QuantizedTensorPrimitive = NdArrayQTensor;
/// Always `false` for this backend.
fn ad_enabled(_device: &Self::Device) -> bool {
false
}
/// Returns the fixed backend name `"ndarray"`.
fn name(_device: &Self::Device) -> String {
String::from("ndarray")
}
/// Replaces the generator stored in [`SEED`] with one seeded from `seed`.
///
/// # Panics
///
/// Panics if `SEED.lock()` returns an error (the result is unwrapped).
fn seed(_device: &Self::Device, seed: u64) {
let rng = NdArrayRng::seed_from_u64(seed);
// Shadows the `seed` argument with the lock guard.
let mut seed = SEED.lock().unwrap();
*seed = Some(rng);
}
/// Reports how `dtype` can be used on this backend.
///
/// The listed float/int/uint/bool dtypes get `DTypeUsage::general()`; `F16` and `BF16` get an
/// empty set. Quantized dtypes are accepted only for the scheme shape matched below.
fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
match dtype {
DType::F64
| DType::F32
| DType::Flex32
| DType::I64
| DType::I32
| DType::I16
| DType::I8
| DType::U64
| DType::U32
| DType::U16
| DType::U8
| DType::Bool => burn_backend::DTypeUsage::general(),
DType::F16 | DType::BF16 => burn_backend::DTypeUsageSet::empty(),
DType::QFloat(scheme) => {
match scheme {
// Accepted: per-tensor or per-block level, symmetric mode, native store, and one of
// the value kinds selected by the cfg'd `value` pattern below.
QuantScheme {
level: QuantLevel::Tensor | QuantLevel::Block(_),
mode: QuantMode::Symmetric,
#[cfg(not(feature = "export_tests"))]
value: QuantValue::Q8F | QuantValue::Q8S,
// For tests, "native" sub-byte quant serves as a reference for value equality.
// Values are stored as i8 regardless.
#[cfg(feature = "export_tests")]
value:
QuantValue::Q8F
| QuantValue::Q8S
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S,
store: QuantStore::Native,
..
} => burn_backend::DTypeUsage::general(),
// Every other scheme is reported as unusable.
_scheme => burn_backend::DTypeUsageSet::empty(),
}
}
}
}
}
// Conversions between IR tensor handles and this backend's tensor primitives.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    type Handle = HandleKind<Self>;

    /// Extracts the float tensor stored in `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle holds another kind of tensor.
    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        match handle.handle {
            HandleKind::Float(tensor) => tensor,
            other => panic!("Expected float handle, got {}", other.name()),
        }
    }

    /// Extracts the int tensor stored in `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle holds another kind of tensor.
    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        match handle.handle {
            HandleKind::Int(tensor) => tensor,
            other => panic!("Expected int handle, got {}", other.name()),
        }
    }

    /// Extracts the bool tensor stored in `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle holds another kind of tensor.
    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        match handle.handle {
            HandleKind::Bool(tensor) => tensor,
            other => panic!("Expected bool handle, got {}", other.name()),
        }
    }

    /// Extracts the quantized tensor stored in `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle holds another kind of tensor.
    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        match handle.handle {
            HandleKind::Quantized(tensor) => tensor,
            other => panic!("Expected quantized handle, got {}", other.name()),
        }
    }

    /// Wraps a float tensor into a handle.
    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
        Self::Handle::Float(tensor)
    }

    /// Wraps an int tensor into a handle.
    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
        Self::Handle::Int(tensor)
    }

    /// Wraps a bool tensor into a handle.
    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
        Self::Handle::Bool(tensor)
    }

    /// Wraps a quantized tensor into a handle.
    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
        Self::Handle::Quantized(tensor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::QTensorPrimitive;

    /// Checks `supports_dtype` against the expected accept / reject lists.
    #[test]
    fn should_support_dtypes() {
        type B = NdArray<f32>;
        let device = Default::default();

        let supported = [
            DType::F64,
            DType::F32,
            DType::Flex32,
            DType::I64,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::U64,
            DType::U32,
            DType::U16,
            DType::U8,
            DType::Bool,
            DType::QFloat(NdArrayQTensor::default_scheme()),
        ];
        for dtype in supported {
            assert!(B::supports_dtype(&device, dtype));
        }

        let unsupported = [
            DType::F16,
            DType::BF16,
            // QuantStore::U32 not supported
            DType::QFloat(QuantScheme::default()),
        ];
        for dtype in unsupported {
            assert!(!B::supports_dtype(&device, dtype));
        }
    }
}

View File

@@ -0,0 +1,207 @@
use burn_backend::Element;
use num_traits::Signed;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use num_traits::Pow;
use libm::{log1p, log1pf};
/// A float element for ndarray backend.
///
/// Requires [`NdArrayElement`] plus `Signed` and `PartialOrd`; the trait itself declares no items.
pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd<Self>
where
Self: Sized,
{
}
/// An int element for ndarray backend.
///
/// Requires [`NdArrayElement`] plus `PartialOrd`; the trait itself declares no items.
pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd<Self> {}
/// A general element for ndarray backend.
///
/// This is a bound bundle with no items of its own: it collects the Burn `Element` trait, the
/// ndarray scalar traits, the crate's [`ExpElement`] / [`AddAssignElement`] helpers and a few
/// `core` / `num_traits` operator traits under a single name.
pub trait NdArrayElement:
Element
+ ndarray::LinalgScalar
+ ndarray::ScalarOperand
+ ExpElement
+ AddAssignElement
+ num_traits::FromPrimitive
+ core::ops::AddAssign
+ core::cmp::PartialEq
+ core::ops::Rem<Output = Self>
{
}
/// An element for the ndarray backend that supports exponential-style math ops.
pub trait ExpElement {
/// Exponential of `self`.
fn exp_elem(self) -> Self;
/// Natural logarithm of `self`.
fn log_elem(self) -> Self;
/// `ln(1 + self)`.
fn log1p_elem(self) -> Self;
/// `self` raised to the float power `value`.
fn powf_elem(self, value: f32) -> Self;
/// `self` raised to the integer power `value`.
fn powi_elem(self, value: i32) -> Self;
/// Square root of `self`.
fn sqrt_elem(self) -> Self;
/// Absolute value of `self`.
fn abs_elem(self) -> Self;
}
/// The addition assignment operator implemented for ndarray elements.
///
/// A crate-local counterpart of `core::ops::AddAssign`, which lets `bool` take part with
/// logical-OR semantics.
pub trait AddAssignElement<Rhs = Self> {
/// Performs the addition assignment operation.
///
/// For `bool`, this corresponds to logical OR assignment.
fn add_assign(&mut self, rhs: Rhs);
}
// Blanket implementation: every `NdArrayElement` forwards to its `core::ops::AddAssign` impl.
impl<E: NdArrayElement> AddAssignElement for E {
    fn add_assign(&mut self, rhs: Self) {
        <E as core::ops::AddAssign>::add_assign(self, rhs);
    }
}
// `bool` has no arithmetic addition; "adding" two flags means OR-ing them.
impl AddAssignElement for bool {
    fn add_assign(&mut self, rhs: Self) {
        // Both operands are already evaluated, so `|=` matches `*self || rhs`.
        *self |= rhs;
    }
}
/// A quantized element for the ndarray backend.
///
/// Marker trait on top of [`NdArrayElement`]; `i8` is the implementor declared right below.
pub trait QuantElement: NdArrayElement {}
impl QuantElement for i8 {}
// Marker impls: which primitive types may be used as float elements...
impl FloatNdArrayElement for f64 {}
impl FloatNdArrayElement for f32 {}
// ...and which may be used as int elements (signed and unsigned).
impl IntNdArrayElement for i64 {}
impl IntNdArrayElement for i32 {}
impl IntNdArrayElement for i16 {}
impl IntNdArrayElement for i8 {}
impl IntNdArrayElement for u64 {}
impl IntNdArrayElement for u32 {}
impl IntNdArrayElement for u16 {}
impl IntNdArrayElement for u8 {}
// Implements `NdArrayElement` and `ExpElement` for a float type.
//
// `$ty` is the float type and `$log1p` the function used for `log1p_elem`
// (the invocations in this file pass libm's `log1p` / `log1pf`).
macro_rules! make_float {
(
$ty:ty,
$log1p:expr
) => {
impl NdArrayElement for $ty {}
#[allow(clippy::cast_abs_to_unsigned)]
impl ExpElement for $ty {
#[inline(always)]
fn exp_elem(self) -> Self {
self.exp()
}
#[inline(always)]
fn log_elem(self) -> Self {
self.ln()
}
#[inline(always)]
fn log1p_elem(self) -> Self {
$log1p(self)
}
#[inline(always)]
fn powf_elem(self, value: f32) -> Self {
self.pow(value)
}
#[inline(always)]
fn powi_elem(self, value: i32) -> Self {
// With `std` the inherent `powi` is used; without it the exponent is converted to
// `f32` and routed through `powf_elem`.
#[cfg(feature = "std")]
let val = self.powi(value);
#[cfg(not(feature = "std"))]
let val = Self::powf_elem(self, value as f32);
val
}
#[inline(always)]
fn sqrt_elem(self) -> Self {
self.sqrt()
}
#[inline(always)]
fn abs_elem(self) -> Self {
self.abs()
}
}
};
}
// Implements `NdArrayElement` and `ExpElement` for an integer type.
//
// `$ty` is the integer type and `$abs` the function used for `abs_elem`. Every other operation
// converts the value to `f32`, applies the float operation and casts the result back with `as`.
macro_rules! make_int {
(
$ty:ty,
$abs:expr
) => {
impl NdArrayElement for $ty {}
#[allow(clippy::cast_abs_to_unsigned)]
impl ExpElement for $ty {
#[inline(always)]
fn exp_elem(self) -> Self {
(self as f32).exp() as $ty
}
#[inline(always)]
fn log_elem(self) -> Self {
(self as f32).ln() as $ty
}
#[inline(always)]
fn log1p_elem(self) -> Self {
log1pf(self as f32) as $ty
}
#[inline(always)]
fn powf_elem(self, value: f32) -> Self {
(self as f32).pow(value) as $ty
}
#[inline(always)]
fn powi_elem(self, value: i32) -> Self {
// With `std` use `f32::powi`; without it fall back to `powf_elem` with the exponent
// converted to `f32`.
#[cfg(feature = "std")]
let val = f32::powi(self as f32, value) as $ty;
#[cfg(not(feature = "std"))]
let val = Self::powf_elem(self, value as f32);
val
}
#[inline(always)]
fn sqrt_elem(self) -> Self {
(self as f32).sqrt() as $ty
}
#[inline(always)]
fn abs_elem(self) -> Self {
$abs(self)
}
}
};
}
// Floats: pair each type with the libm `log1p` variant of matching width.
make_float!(f64, log1p);
make_float!(f32, log1pf);
// Signed ints: `wrapping_abs` is used for the absolute value.
make_int!(i64, i64::wrapping_abs);
make_int!(i32, i32::wrapping_abs);
make_int!(i16, i16::wrapping_abs);
make_int!(i8, i8::wrapping_abs);
// Unsigned ints: the absolute value is the identity.
make_int!(u64, |x| x);
make_int!(u32, |x| x);
make_int!(u16, |x| x);
make_int!(u8, |x| x);

View File

@@ -0,0 +1,29 @@
// Build as `no_std` unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! Burn ndarray backend.
// Link the BLAS source crate only when one of the BLAS features is selected.
#[cfg(any(
feature = "blas-netlib",
feature = "blas-openblas",
feature = "blas-openblas-system",
))]
extern crate blas_src;
// Internal modules.
mod backend;
mod element;
mod ops;
mod parallel;
mod rand;
mod sharing;
mod storage;
mod tensor;
// Public API: the backend, element and tensor items. `sharing` and `storage` are re-exported
// for use inside the crate only.
pub use backend::*;
pub use element::*;
pub(crate) use sharing::*;
pub(crate) use storage::*;
pub use tensor::*;
// `alloc` provides the heap-allocated types (e.g. `String`, `Vec`) used across the crate.
extern crate alloc;

View File

@@ -0,0 +1,18 @@
use crate::{
NdArray, NdArrayTensor, SharedArray,
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
execute_with_numeric_dtype,
ops::NdArrayMathOps,
};
use burn_backend::{ElementConversion, TensorMetadata, ops::ActivationOps, tensor::FloatTensor};
// Activation ops for the ndarray backend. Only `relu` is overridden here.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
/// ReLU, expressed as a clamp with a minimum of `0`.
///
/// NOTE(review): `execute_with_numeric_dtype!` presumably dispatches on the tensor's runtime
/// dtype before calling `clamp_min` — confirm against the macro definition.
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem()))
}
}

View File

@@ -0,0 +1,103 @@
use crate::{
SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
};
use burn_backend::ElementConversion;
use ndarray::Array4;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
/// Adaptive 2D average pooling over a 4-D input indexed as `[batch, channel, height, width]`.
///
/// Each output cell `(h, w)` is the mean of the input window given by `start_index` /
/// `end_index` along each spatial axis. The result has shape
/// `[batch, channel, output_size[0], output_size[1]]`.
///
/// # Panics
///
/// Panics if `x` does not have exactly four dimensions.
pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    output_size: [usize; 2],
) -> SharedArray<E> {
    let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap();
    let [output_height, output_width] = output_size;

    let mut output = Array4::from_elem(
        (batch_size, channels, output_height, output_width),
        0.elem(),
    );
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // One work item per (batch, channel) plane; each item writes only to its own plane.
        iter_range_par!(0, batch_size * channels).for_each(|plane| unsafe {
            let (b, c) = (plane / channels, plane % channels);
            let output = unsafe_shared_out.get();

            for h in 0..output_height {
                // The row window depends only on `h`, so compute it once per output row.
                let ih_start = start_index(h, output_height, input_height);
                let ih_end = end_index(h, output_height, input_height);

                for w in 0..output_width {
                    let iw_start = start_index(w, output_width, input_width);
                    let iw_end = end_index(w, output_width, input_width);

                    let mut sum_val: E = 0.elem();
                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            sum_val += x[[b, c, ih, iw]];
                        }
                    }

                    let window_len = (ih_end - ih_start) * (iw_end - iw_start);
                    let count: E = (window_len as i32).elem();
                    output[[b, c, h, w]] = sum_val / count.elem();
                }
            }
        })
    });

    output.into_dyn().into_shared()
}
/// Backward pass of adaptive 2D average pooling.
///
/// `x` supplies only the input height and width; `grad` is the gradient with respect to the
/// pooled output. Every output gradient is divided by its window size and added to each input
/// position of that window. The result has shape `[batch, channel, input_height, input_width]`,
/// with batch and channel counts taken from `grad`.
///
/// # Panics
///
/// Panics if `x` or `grad` does not have exactly four dimensions.
pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    grad: SharedArray<E>,
) -> SharedArray<E> {
    let [_, _, input_height, input_width] = x.shape().try_into().unwrap();
    let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap();

    let mut output_grad =
        Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);

    run_par!(|| {
        // One work item per (batch, channel) plane; each item writes only to its own plane.
        iter_range_par!(0, batch_size * channels).for_each(|plane| unsafe {
            let (b, c) = (plane / channels, plane % channels);
            let output_grad = unsafe_shared_out.get();

            for oh in 0..output_height {
                // The row window depends only on `oh`, so compute it once per output row.
                let ih_start = start_index(oh, output_height, input_height);
                let ih_end = end_index(oh, output_height, input_height);

                for ow in 0..output_width {
                    let iw_start = start_index(ow, output_width, input_width);
                    let iw_end = end_index(ow, output_width, input_width);

                    let window_len = (ih_end - ih_start) * (iw_end - iw_start);
                    let count: E = (window_len as i32).elem();

                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / count.elem();
                        }
                    }
                }
            }
        })
    });

    output_grad.into_dyn().into_shared()
}
/// First input index (inclusive) of the adaptive pooling window for output position
/// `output_size_index`: `floor(output_size_index * input_size / output_size)`.
///
/// Computed with integer arithmetic so the result is exact for every size. The earlier `f32`
/// formulation loses precision once `output_size_index * input_size` no longer fits in f32's
/// 24-bit mantissa.
///
/// # Panics
///
/// Panics if `output_size` is zero. The pooling loops in this file only call it for indices in
/// `0..output_size`, so that case does not arise there.
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    // Widen to u64 so the product cannot overflow on 32-bit targets.
    let scaled = output_size_index as u64 * input_size as u64;
    // For `output_size_index < output_size` the quotient is below `input_size`, so it fits.
    (scaled / output_size as u64) as usize
}
/// Last input index (exclusive) of the adaptive pooling window for output position
/// `output_size_index`: `ceil((output_size_index + 1) * input_size / output_size)`, capped at
/// `input_size`.
///
/// Computed with integer arithmetic so the result is exact for every size. The earlier `f32`
/// formulation could drop trailing input elements once the product no longer fits in f32's
/// 24-bit mantissa.
///
/// # Panics
///
/// Panics if `output_size` is zero. The pooling loops in this file only call it for indices in
/// `0..output_size`, so that case does not arise there.
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    // Widen to u64 so the product cannot overflow on 32-bit targets.
    let scaled = (output_size_index as u64 + 1) * input_size as u64;
    let index = scaled.div_ceil(output_size as u64);
    // Cap before narrowing: the capped value is at most `input_size`, so it fits in `usize`.
    index.min(input_size as u64) as usize
}

View File

@@ -0,0 +1,172 @@
use crate::{
SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
};
use burn_backend::ElementConversion;
use burn_backend::ops::conv::calculate_pool_output_size;
use ndarray::Array4;
/// 2D average pooling over a 4-D input indexed as `[batch, channel, height, width]`.
///
/// Window coordinates are expressed in the padded input frame. Positions outside the padded
/// frame (window overhang) are never counted. Positions inside the padding contribute no value
/// and are counted in the divisor only when `count_include_pad` is set.
///
/// # Panics
///
/// Panics if `x` does not have exactly four dimensions.
pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    ceil_mode: bool,
) -> SharedArray<E> {
    let [kernel_height, kernel_width] = kernel_size;
    let [stride_height, stride_width] = stride;
    let [padding_height, padding_width] = padding;
    let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();

    let out_height = calculate_pool_output_size(
        kernel_height,
        stride_height,
        padding_height,
        1,
        x_height,
        ceil_mode,
    );
    let out_width = calculate_pool_output_size(
        kernel_width,
        stride_width,
        padding_width,
        1,
        x_width,
        ceil_mode,
    );

    // Extent of the padded input; anything at or past these bounds is window overhang.
    let padded_height = x_height + 2 * padding_height;
    let padded_width = x_width + 2 * padding_width;

    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem());
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // One work item per (batch, channel) plane; each item writes only to its own plane.
        iter_range_par!(0, batch_size * channels).for_each(|plane| unsafe {
            let (b, c) = (plane / channels, plane % channels);
            let output = unsafe_shared_out.get();

            for oh in 0..out_height {
                let row_origin = oh * stride_height;

                for ow in 0..out_width {
                    let col_origin = ow * stride_width;

                    let mut sum_val: E = 0.elem();
                    let mut valid_count = 0usize;
                    let mut padded_count = 0usize;

                    for ih in row_origin..row_origin + kernel_height {
                        // Rows past the padded frame contribute nothing at all.
                        if ih >= padded_height {
                            continue;
                        }
                        let row_in_input = ih >= padding_height && ih < x_height + padding_height;

                        for iw in col_origin..col_origin + kernel_width {
                            if iw >= padded_width {
                                continue;
                            }
                            padded_count += 1;

                            let col_in_input =
                                iw >= padding_width && iw < x_width + padding_width;
                            if row_in_input && col_in_input {
                                sum_val += x[[b, c, ih - padding_height, iw - padding_width]];
                                valid_count += 1;
                            }
                        }
                    }

                    // Divisor: padded-frame positions, or only real input positions.
                    let divisor = if count_include_pad {
                        padded_count
                    } else {
                        valid_count
                    };
                    let count: E = (divisor as i32).elem();
                    output[[b, c, oh, ow]] = sum_val / count;
                }
            }
        })
    });

    output.into_dyn().into_shared()
}
/// Backward pass of 2D average pooling.
///
/// For every output position, the incoming gradient is divided by the same divisor the forward
/// pass used (see `count_include_pad`) and added to each real input position covered by the
/// window. The result has the shape of `x`.
///
/// `_ceil_mode` is unused: the output extent is read from the shape of `grad`.
///
/// Window extents are computed with `saturating_sub`. A window lying entirely inside the padding
/// (for example kernel 1 with padding 2) makes the clipped end smaller than the clipped start,
/// and a plain subtraction would panic in overflow-checked builds. Such a window covers no input
/// position, so its count is never used.
///
/// # Panics
///
/// Panics if `x` or `grad` does not have exactly four dimensions.
pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    grad: SharedArray<E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    _ceil_mode: bool,
) -> SharedArray<E> {
    let [kernel_height, kernel_width] = kernel_size;
    let [stride_height, stride_width] = stride;
    let [padding_height, padding_width] = padding;
    let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
    let [_batch_size, _channels, out_height, out_width] = grad.shape().try_into().unwrap();

    // Padded input bounds (for count_include_pad calculation)
    let padded_height = x_height + 2 * padding_height;
    let padded_width = x_width + 2 * padding_width;

    let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem());
    let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);

    run_par!(|| {
        // One work item per (batch, channel) plane; each item writes only to its own plane.
        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
            let b = k / channels;
            let c = k % channels;
            let output_grad = unsafe_shared_grad.get();

            for oh in 0..out_height {
                for ow in 0..out_width {
                    // Window in padded coordinates.
                    let ih_start_kernel = oh * stride_height;
                    let iw_start_kernel = ow * stride_width;
                    let ih_end_kernel = ih_start_kernel + kernel_height;
                    let iw_end_kernel = iw_start_kernel + kernel_width;

                    // Clip to valid input bounds (for gradient distribution)
                    let ih_start = usize::max(ih_start_kernel, padding_height);
                    let iw_start = usize::max(iw_start_kernel, padding_width);
                    let ih_end = usize::min(ih_end_kernel, x_height + padding_height);
                    let iw_end = usize::min(iw_end_kernel, x_width + padding_width);

                    // Calculate count based on count_include_pad
                    let count = if count_include_pad {
                        // Count positions within padded bounds (not ceil_mode extensions)
                        let ih_end_padded = usize::min(ih_end_kernel, padded_height);
                        let iw_end_padded = usize::min(iw_end_kernel, padded_width);
                        ih_end_padded.saturating_sub(ih_start_kernel)
                            * iw_end_padded.saturating_sub(iw_start_kernel)
                    } else {
                        // Count only valid (non-padding) positions
                        ih_end.saturating_sub(ih_start) * iw_end.saturating_sub(iw_start)
                    };

                    // Empty when the window covers no real input position.
                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            // Back to unpadded input coordinates.
                            let ih = ih - padding_height;
                            let iw = iw - padding_width;
                            output_grad[[b, c, ih, iw]] +=
                                grad[[b, c, oh, ow]] / (count as i32).elem();
                        }
                    }
                }
            }
        })
    });

    output_grad.into_dyn().into_shared()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,220 @@
// Language
use alloc::vec;
use alloc::vec::Vec;
use burn_backend::Scalar;
use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
use burn_backend::{
backend::ExecutionError,
ops::BoolTensorOps,
tensor::{BoolTensor, IntTensor},
};
use ndarray::IntoDimension;
// Current crate
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
use crate::{NdArrayDevice, SharedArray, slice};
// Workspace crates
use burn_backend::{Shape, TensorData, backend::Backend};
use super::{NdArrayBoolOps, NdArrayOps};
// Boolean tensor operations for the ndarray backend. Most methods unwrap the boolean array with
// `tensor.bool()`, delegate to `NdArrayOps` / `NdArrayBoolOps`, and wrap the result back into an
// `NdArrayTensor`.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
/// Builds a bool tensor from `data`.
///
/// # Panics
///
/// Hits `unimplemented!` when `data.dtype` is not a bool dtype.
fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
if !data.dtype.is_bool() {
unimplemented!("Unsupported dtype for `bool_from_data`")
}
NdArrayTensor::from_data(data)
}
/// Converts the tensor into `TensorData`; this implementation always returns `Ok`.
async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
Ok(tensor.into_data())
}
/// Returns the tensor unchanged; the target device is ignored.
fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
tensor
}
/// Reshapes the tensor to `shape`.
fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
NdArrayOps::reshape(tensor.bool(), shape).into()
}
/// Slices the tensor with `slices` through the crate's `slice!` macro.
fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
slice!(tensor, slices)
}
/// Converts each bool into the backend's int element type `I`.
fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {
// Use mapv directly instead of collecting to Vec and going through TensorData
let int_array: SharedArray<I> = tensor.bool().mapv(|b| b.elem()).into_shared();
int_array.into()
}
/// Always returns the CPU device.
fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
NdArrayDevice::Cpu
}
/// Delegates to `bool_zeros`, so an "empty" tensor is filled with `false`.
fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
Self::bool_zeros(shape, _device)
}
/// Creates a tensor of the given shape filled with `false`.
fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
let values = vec![false; shape.num_elements()];
NdArrayTensor::from_data(TensorData::new(values, shape))
}
/// Creates a tensor of the given shape filled with `true`.
fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
let values = vec![true; shape.num_elements()];
NdArrayTensor::from_data(TensorData::new(values, shape))
}
/// Writes `value` into the region of `tensor` selected by `slices`.
fn bool_slice_assign(
tensor: NdArrayTensor,
slices: &[burn_backend::Slice],
value: NdArrayTensor,
) -> NdArrayTensor {
NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
}
/// Concatenates the tensors along `dim`.
fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
}
/// Element-wise equality of two bool tensors.
fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
}
/// Element-wise logical NOT.
fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
tensor.bool().mapv(|a| !a).into_shared().into()
}
/// Element-wise logical AND.
fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
}
/// Element-wise logical OR.
fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
}
/// Converts each bool into the backend's float element type `E`.
fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
let arr: SharedArray<E> = tensor.bool().mapv(|b| b.elem()).into_shared();
arr.into()
}
/// Swaps dimensions `dim1` and `dim2`.
fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
}
/// Permutes the axes of the tensor according to `axes`.
fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
tensor.bool().permuted_axes(axes.into_dimension()).into()
}
/// Expands the tensor to `shape`.
fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
NdArrayOps::expand(tensor.bool(), shape).into()
}
/// Selects the entries of `tensor` along `dim` at the given `indices`.
///
/// Indices are converted to `i64` and then cast to `usize` with `as`.
fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
let tensor_bool = tensor.bool();
let indices_vec: Vec<usize> = indices
.into_iter()
.map(|i| i.elem::<i64>() as usize)
.collect();
let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
selected.into_shared().into()
})
}
/// OR-combines slices of `value` into `tensor` along `dim`.
///
/// The `n`-th index picks the destination position along `dim`; the `n`-th slice of `value`
/// along `dim` is OR-ed into it.
fn bool_select_or(
tensor: NdArrayTensor,
dim: usize,
indices: NdArrayTensor,
value: NdArrayTensor,
) -> NdArrayTensor {
execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
let mut output_array = tensor.bool().into_owned();
let value_bool = value.bool();
// `index_value` is the position inside `value`; `index` is the destination in the output.
for (index_value, index) in indices.into_iter().enumerate() {
let index_usize = index.elem::<i64>() as usize;
let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
// For boolean tensors, select_assign should use logical OR operation
view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
}
output_array.into_shared().into()
})
}
/// Reverses the tensor along each axis in `axes`.
fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
NdArrayOps::flip(tensor.bool(), axes).into()
}
/// Delegates to `NdArrayOps::unfold` with the given `dim`, `size` and `step`.
fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
}
/// Delegates to `NdArrayOps::mask_where` with the tensor, mask and replacement values.
fn bool_mask_where(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
}
/// Delegates to `NdArrayOps::mask_fill`, converting the scalar `value` to a bool element.
fn bool_mask_fill(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> BoolTensor<Self> {
NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
}
/// Gathers values from `tensor` along `dim` at `indices` via `NdArrayOps::gather`.
fn bool_gather(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
) -> BoolTensor<Self> {
execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
dim,
tensor.bool(),
indices
))
}
/// Scatters `value` into `tensor` along `dim` at `indices` via `NdArrayOps::scatter`.
///
/// NOTE(review): the OR combination presumably comes from the `bool` element's add-assign
/// inside `NdArrayOps::scatter` — confirm there.
fn bool_scatter_or(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
dim,
tensor.bool(),
indices,
value.bool()
))
}
/// Element-wise equality against a scalar converted to a bool element.
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into()
}
/// Reduces the whole tensor with `any`; the result is a tensor of shape `[1]`.
fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
// Use view() for zero-copy on borrowed storage with short-circuit evaluation
let result = NdArrayBoolOps::any_view(tensor.bool().view());
NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
}
/// Reduces the whole tensor with `all`; the result is a tensor of shape `[1]`.
fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
// Use view() for zero-copy on borrowed storage with short-circuit evaluation
let result = NdArrayBoolOps::all_view(tensor.bool().view());
NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
}
}

View File

@@ -0,0 +1,574 @@
use burn_backend::{
ElementConversion,
ops::{
ConvOptions, ConvTransposeOptions,
conv::{calculate_conv_output_size, calculate_conv_transpose_output_size},
},
};
use ndarray::{
Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, s,
};
use crate::{
NdArrayElement, SharedArray, iter_par, iter_range_par,
ops::padding::{apply_padding_4d, apply_padding_5d},
run_par,
sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
};
/// Multiply-accumulate of one kernel tap `k` into a 2-D output plane.
///
/// For every output position `(oh, ow)` this adds `x[ih, iw] * k`, with
/// `ih = oh * stride_height + kh * dilation_height` and
/// `iw = ow * stride_width + kw * dilation_width`.
///
/// Note the tuple orders: `k_xy` is `(kh, kw)`, while `out_xy`, `stride` and `dilation` are
/// `(width, height)`.
///
/// # Panics
///
/// Panics if an input or output row is not available as a contiguous slice, or if a computed
/// index falls outside `x` / `output`.
#[inline(always)]
fn conv2d_mad_inner<E: NdArrayElement>(
mut output: ArrayViewMut2<E>,
x: ArrayView2<E>,
k: E,
k_xy: (usize, usize),
out_xy: (usize, usize),
stride: (usize, usize),
dilation: (usize, usize),
) {
let (kh, kw) = k_xy;
let (out_width, out_height) = out_xy;
let (stride_width, stride_height) = stride;
let (dilation_width, dilation_height) = dilation;
for oh in 0..out_height {
// Construct a sub-slice view of the input row.
// This is done upfront so that rustc does not have to emit bounds checks
// in the hot loop below.
let ir = x
.row(oh * stride_height + kh * dilation_height)
.to_slice()
.unwrap();
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
// the bounds upfront as 0..out_width so that rustc can make the assumption
// that all accesses are in-bounds in the below loop.
let mut or = output.row_mut(oh);
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
#[allow(clippy::needless_range_loop)]
for ow in 0..out_width {
let iw = ow * stride_width + kw * dilation_width;
or[ow] += ir[iw] * k;
}
}
}
/// Multiply-accumulate of one kernel tap `k` into a 3-D output volume.
///
/// For every output position `(od, oh, ow)` this adds `x[id, ih, iw] * k`, where each input
/// coordinate is `o * stride + k_index * dilation` along its axis.
///
/// Note the tuple orders: `k_xyz` is `(kd, kh, kw)`, while `out_xyz`, `stride` and `dilation` are
/// `(width, height, depth)`.
///
/// # Panics
///
/// Panics if an input or output row is not available as a contiguous slice, or if a computed
/// index falls outside `x` / `output`.
#[inline(always)]
fn conv3d_mad_inner<E: NdArrayElement>(
mut output: ArrayViewMut3<E>,
x: ArrayView3<E>,
k: E,
k_xyz: (usize, usize, usize),
out_xyz: (usize, usize, usize),
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
) {
let (kd, kh, kw) = k_xyz;
let (out_width, out_height, out_depth) = out_xyz;
let (stride_width, stride_height, stride_depth) = stride;
let (dilation_width, dilation_height, dilation_depth) = dilation;
for od in 0..out_depth {
let id = od * stride_depth + kd * dilation_depth;
for oh in 0..out_height {
let ih = oh * stride_height + kh * dilation_height;
// Construct a sub-slice view of the input row.
// This is done upfront so that rustc does not have to emit bounds checks
// in the hot loop below.
let ir = x.slice(s![id, ih, ..]).to_slice().unwrap();
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
// the bounds upfront as 0..out_width so that rustc can make the assumption
// that all accesses are in-bounds in the below loop.
let or = &mut output
.slice_mut(s![od, oh, 0..out_width])
.into_slice()
.unwrap()[0..out_width];
#[allow(clippy::needless_range_loop)]
for ow in 0..out_width {
let iw = ow * stride_width + kw * dilation_width;
or[ow] += ir[iw] * k;
}
}
}
}
/// 2D convolution of `x` with `weight`, plus an optional per-output-channel `bias`.
///
/// `x` is indexed as `[batch, channel, height, width]` and `weight` as
/// `[out_channel, in_channel, kernel_height, kernel_width]`. The `in_channels` value read from
/// `weight` is the per-group channel count: group `g` reads input channels
/// `in_channels * g .. in_channels * (g + 1)`.
///
/// The input is padded up front with `apply_padding_4d`, and the work is parallelised over the
/// `batch * out_channels` output planes.
///
/// # Panics
///
/// Panics if `x` or `weight` is not 4-dimensional, or if a row of the padded input / output is
/// not contiguous (see `conv2d_mad_inner`).
pub(crate) fn conv2d<E: NdArrayElement>(
x: SharedArray<E>,
weight: SharedArray<E>,
bias: Option<SharedArray<E>>,
options: ConvOptions<2>,
) -> SharedArray<E>
where
NdArrayTensor: From<SharedArray<E>>,
{
let [dilation_height, dilation_width] = options.dilation;
let [padding_height, padding_width] = options.padding;
let [stride_height, stride_width] = options.stride;
let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();
let [out_channels, in_channels, kernel_height, kernel_width] =
weight.shape().try_into().unwrap();
// Number of output channels handled by each group.
let channels_per_group = out_channels / options.groups;
let out_height = calculate_conv_output_size(
kernel_height,
stride_height,
padding_height,
dilation_height,
in_height,
);
let out_width = calculate_conv_output_size(
kernel_width,
stride_width,
padding_width,
dilation_width,
in_width,
);
// Pad with zeros so the inner loops never need to test for out-of-range input coordinates.
let x = apply_padding_4d::<E>(x, options.padding, 0i32.elem());
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
let weights = weight.into_dimensionality::<ndarray::Ix4>().unwrap();
// Batch and output channel are fused into axis 0 so each parallel item owns one output plane.
let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));
run_par!(|| {
iter_par!(output.axis_iter_mut(Axis(0)))
.enumerate()
.for_each(
#[inline(never)]
|(k, mut output)| {
// Recover (batch, output channel, group) from the fused plane index.
let b = k / out_channels;
let oc = k % out_channels;
let g = oc / channels_per_group;
for ic in (in_channels * g)..(in_channels * (g + 1)) {
// Channel index inside the group, used to index `weights`.
let weight_ic = ic - (g * in_channels);
let x = x.slice(s![b, ic, .., ..]);
let k = weights.slice(s![oc, weight_ic, .., ..]);
for kh in 0..kernel_height {
for kw in 0..kernel_width {
let k = k[[kh, kw]];
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
// in the case that the stride/dilation is 1.
#[allow(clippy::if_same_then_else)]
if (1, 1, 1, 1)
== (
stride_width,
stride_height,
dilation_width,
dilation_height,
)
{
conv2d_mad_inner(
output.view_mut(),
x.view(),
k,
(kh, kw),
(out_width, out_height),
(stride_width, stride_height),
(dilation_width, dilation_height),
);
} else {
conv2d_mad_inner(
output.view_mut(),
x.view(),
k,
(kh, kw),
(out_width, out_height),
(stride_width, stride_height),
(dilation_width, dilation_height),
);
}
}
}
}
// Add the bias of this output channel to every element of the plane.
if let Some(bias) = &bias {
let bias = bias[oc];
for oh in 0..out_height {
// Get a mutable slice reference to the row we're looping over.
// We explicitly define the bounds to 0..out_width so that rustc can make
// the assumption that all accesses are in-bounds.
let mut or = output.row_mut(oh);
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
#[allow(clippy::needless_range_loop)]
for ow in 0..out_width {
or[ow] += bias;
}
}
}
},
);
});
// Split the fused axis back into [batch, out_channels].
output
.to_shape([batch_size, out_channels, out_height, out_width])
.unwrap()
.into_dyn()
.into_shared()
}
/// 2D transposed convolution of `x` with `weight`, plus an optional `bias`.
///
/// `x` is indexed as `[batch, channel, height, width]` and `weight` as
/// `[in_channel, out_channel, kernel_height, kernel_width]`, where `out_channel` is the
/// per-group count; the output has `out_channels * groups` channels. Every input element is
/// scattered into the output positions reached by the kernel taps; taps landing in the padding
/// border are skipped.
///
/// # Panics
///
/// Panics if `x` or `weight` is not 4-dimensional.
pub(crate) fn conv_transpose2d<E: NdArrayElement>(
    x: SharedArray<E>,
    weight: SharedArray<E>,
    bias: Option<SharedArray<E>>,
    options: ConvTransposeOptions<2>,
) -> SharedArray<E> {
    let [dilation_height, dilation_width] = options.dilation;
    let [padding_height, padding_width] = options.padding;
    let [stride_height, stride_width] = options.stride;
    let [out_padding_height, out_padding_width] = options.padding_out;
    let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();
    let [in_channels, out_channels, kernel_height, kernel_width] =
        weight.shape().try_into().unwrap();

    let out_height = calculate_conv_transpose_output_size(
        kernel_height,
        stride_height,
        padding_height,
        out_padding_height,
        dilation_height,
        in_height,
    );
    let out_width = calculate_conv_transpose_output_size(
        kernel_width,
        stride_width,
        padding_width,
        out_padding_width,
        dilation_width,
        in_width,
    );

    let groups = options.groups;
    let total_out_channels = out_channels * groups;

    let mut output = Array4::zeros(Dim([
        batch_size,
        total_out_channels,
        out_height,
        out_width,
    ]));
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // One work item per (batch, group, per-group output channel); each item writes only to
        // its own output plane `[b, oc_out, .., ..]`.
        iter_range_par!(0, batch_size * total_out_channels).for_each(|k| unsafe {
            let b = k / total_out_channels;
            let oc = k % out_channels;
            let g = (k / out_channels) % groups;
            let output = unsafe_shared_out.get();
            let oc_out = oc + (out_channels * g);

            // Input channels belonging to this group.
            let group_in_channels = in_channels / groups;
            let ic_start = g * group_in_channels;
            let ic_end = ic_start + group_in_channels;

            for ic in ic_start..ic_end {
                for ih in 0..in_height {
                    for iw in 0..in_width {
                        for kh in 0..kernel_height {
                            // Row in padded output coordinates; skip rows inside the border.
                            let oh_padded = ih * stride_height + kh * dilation_height;
                            if oh_padded < padding_height
                                || oh_padded >= out_height + padding_height
                            {
                                continue;
                            }
                            let oh = oh_padded - padding_height;

                            for kw in 0..kernel_width {
                                let ow_padded = iw * stride_width + kw * dilation_width;
                                if ow_padded < padding_width
                                    || ow_padded >= out_width + padding_width
                                {
                                    continue;
                                }
                                let ow = ow_padded - padding_width;

                                output[[b, oc_out, oh, ow]] +=
                                    x[[b, ic, ih, iw]] * weight[[ic, oc, kh, kw]];
                            }
                        }
                    }
                }
            }

            if let Some(bias) = bias.as_ref() {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        output[[b, oc_out, oh, ow]] += bias[oc_out];
                    }
                }
            }
        });
    });

    output.into_dyn().into_shared()
}
/// 3D convolution over an input indexed as `[batch, channel, depth, height, width]`.
///
/// `weight` is destructured as `[out_channels, in_channels, kernel_depth, kernel_height,
/// kernel_width]`, where `in_channels` is the number of input channels consumed by a single
/// group (group `g` reads input channels `in_channels * g .. in_channels * (g + 1)`).
/// `bias`, when present, is indexed by output channel and added to every element of that
/// channel's output volume.
///
/// Returns an array of shape `[batch_size, out_channels, out_depth, out_height, out_width]`.
///
/// # Panics
///
/// Panics if `x` or `weight` is not 5-dimensional (the `try_into().unwrap()` /
/// `into_dimensionality().unwrap()` calls).
pub(crate) fn conv3d<E: NdArrayElement>(
x: SharedArray<E>,
weight: SharedArray<E>,
bias: Option<SharedArray<E>>,
options: ConvOptions<3>,
) -> SharedArray<E>
where
NdArrayTensor: From<SharedArray<E>>,
{
let [dilation_depth, dilation_height, dilation_width] = options.dilation;
let [padding_depth, padding_height, padding_width] = options.padding;
let [stride_depth, stride_height, stride_width] = options.stride;
let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap();
let [
out_channels,
in_channels,
kernel_depth,
kernel_height,
kernel_width,
] = weight.shape().try_into().unwrap();
// Number of output channels produced by each group.
let out_c_per_group = out_channels / options.groups;
let out_depth = calculate_conv_output_size(
kernel_depth,
stride_depth,
padding_depth,
dilation_depth,
in_depth,
);
let out_height = calculate_conv_output_size(
kernel_height,
stride_height,
padding_height,
dilation_height,
in_height,
);
let out_width = calculate_conv_output_size(
kernel_width,
stride_width,
padding_width,
dilation_width,
in_width,
);
// Pad the input with the value 0 up front; the inner kernel is not given the padding.
let x = apply_padding_5d::<E>(x, options.padding, 0i32.elem());
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();
let weights = weight.into_dimensionality::<ndarray::Ix5>().unwrap();
// Batch and output channel are merged into axis 0 so that each iteration over that axis
// receives one `(batch, out_channel)` volume; the 5D shape is restored at the end.
let mut output = Array4::zeros(Dim([
batch_size * out_channels,
out_depth,
out_height,
out_width,
]));
run_par!(|| {
iter_par!(output.axis_iter_mut(Axis(0)))
.enumerate()
.for_each(
#[inline(never)]
|(k, mut output)| {
// Recover the batch index, output channel and group from the merged index.
let b = k / out_channels;
let oc = k % out_channels;
let g = oc / out_c_per_group;
for ic in (in_channels * g)..(in_channels * (g + 1)) {
// Channel index relative to the start of this group, as used by `weights`.
let weight_ic = ic - (g * in_channels);
let x = x.slice(s![b, ic, .., .., ..]);
let k = weights.slice(s![oc, weight_ic, .., .., ..]);
for kd in 0..kernel_depth {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
// Scalar kernel weight accumulated into the output by the inner helper.
let k = k[[kd, kh, kw]];
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
// in the case that the stride/dilation is 1.
#[allow(clippy::if_same_then_else)]
if (1, 1, 1, 1, 1, 1)
== (
stride_width,
stride_height,
stride_depth,
dilation_width,
dilation_height,
dilation_depth,
)
{
conv3d_mad_inner(
output.view_mut(),
x.view(),
k,
(kd, kh, kw),
(out_width, out_height, out_depth),
(stride_width, stride_height, stride_depth),
(dilation_width, dilation_height, dilation_depth),
);
} else {
conv3d_mad_inner(
output.view_mut(),
x.view(),
k,
(kd, kh, kw),
(out_width, out_height, out_depth),
(stride_width, stride_height, stride_depth),
(dilation_width, dilation_height, dilation_depth),
);
}
}
}
}
}
if let Some(bias) = &bias {
// One bias value per output channel.
let bias = bias[oc];
// Get a mutable iterator to the row we're looping over.
let orows = output.rows_mut();
for mut or in orows {
// We explicitly define the bounds to 0..out_width so that rustc can make
// the assumption that all accesses are in-bounds.
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
#[allow(clippy::needless_range_loop)]
for ow in 0..out_width {
or[ow] += bias;
}
}
}
},
);
});
// Split the merged leading axis back into `[batch_size, out_channels]`.
output
.to_shape([batch_size, out_channels, out_depth, out_height, out_width])
.unwrap()
.into_dyn()
.into_shared()
}
/// Transposed 3D convolution, computed by scattering every input element into the output.
///
/// `x` is indexed as `[batch, in_channel, depth, height, width]` and `weight` as
/// `[in_channel, out_channel_within_group, kernel_depth, kernel_height, kernel_width]`.
/// The output has `weight.shape()[1] * options.groups` channels. When `bias` is present it
/// is indexed by the final (group-adjusted) output channel.
pub(crate) fn conv_transpose3d<E: NdArrayElement>(
    x: SharedArray<E>,
    weight: SharedArray<E>,
    bias: Option<SharedArray<E>>,
    options: ConvTransposeOptions<3>,
) -> SharedArray<E> {
    let [dil_d, dil_h, dil_w] = options.dilation;
    let [pad_d, pad_h, pad_w] = options.padding;
    let [step_d, step_h, step_w] = options.stride;
    let [extra_d, extra_h, extra_w] = options.padding_out;
    let [batch_size, _in_channels, in_d, in_h, in_w] = x.shape().try_into().unwrap();
    let [in_channels, out_channels, kernel_d, kernel_h, kernel_w] =
        weight.shape().try_into().unwrap();

    let out_d = calculate_conv_transpose_output_size(kernel_d, step_d, pad_d, extra_d, dil_d, in_d);
    let out_h = calculate_conv_transpose_output_size(kernel_h, step_h, pad_h, extra_h, dil_h, in_h);
    let out_w = calculate_conv_transpose_output_size(kernel_w, step_w, pad_w, extra_w, dil_w, in_w);

    let total_channels = out_channels * options.groups;
    let mut output = Array5::zeros(Dim([batch_size, total_channels, out_d, out_h, out_w]));
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // Each task owns exactly one `(batch, output channel)` volume: `b` and `oc_out` below
        // are a unique pair per task index, so the shared mutable access never overlaps.
        iter_range_par!(0, batch_size * total_channels).for_each(|task| unsafe {
            let b = task / total_channels;
            let oc = task % out_channels;
            let g = (task / out_channels) % options.groups;
            let oc_out = oc + out_channels * g;

            let in_per_group = in_channels / options.groups;
            let ic_first = g * in_per_group;

            let output = unsafe_shared_out.get();

            // Un-cropped output coordinates that survive the padding crop.
            let kept_planes = pad_d..out_d + pad_d;
            let kept_rows = pad_h..out_h + pad_h;
            let kept_cols = pad_w..out_w + pad_w;

            for ic in ic_first..ic_first + in_per_group {
                for id in 0..in_d {
                    for ih in 0..in_h {
                        for iw in 0..in_w {
                            for kd in 0..kernel_d {
                                let od = id * step_d + kd * dil_d;
                                if !kept_planes.contains(&od) {
                                    continue;
                                }
                                for kh in 0..kernel_h {
                                    let oh = ih * step_h + kh * dil_h;
                                    if !kept_rows.contains(&oh) {
                                        continue;
                                    }
                                    for kw in 0..kernel_w {
                                        let ow = iw * step_w + kw * dil_w;
                                        if !kept_cols.contains(&ow) {
                                            continue;
                                        }
                                        output[[b, oc_out, od - pad_d, oh - pad_h, ow - pad_w]] +=
                                            x[[b, ic, id, ih, iw]] * weight[[ic, oc, kd, kh, kw]];
                                    }
                                }
                            }
                        }
                    }
                }
            }

            if let Some(bias) = &bias {
                for od in 0..out_d {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            output[[b, oc_out, od, oh, ow]] += bias[oc_out];
                        }
                    }
                }
            }
        });
    });

    output.into_dyn().into_shared()
}

View File

@@ -0,0 +1,662 @@
use burn_backend::ops::{DeformConvOptions, conv::calculate_conv_output_size};
use core::ops::AddAssign;
use ndarray::{
Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, Ix4,
Zip, s,
};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use crate::{FloatNdArrayElement, NdArrayTensor, ShapeOps, SharedArray, iter_par, run_par};
use super::matmul::matmul;
/// Fills the `[kernel_h, kernel_w]` column patch for one input channel at one output position.
///
/// For every kernel tap the sampling location is the regular convolution location
/// (`out * stride + kernel * dilation - padding`) shifted by the learned offset, where
/// `offset[[ky, kx, 0]]` is the vertical and `offset[[ky, kx, 1]]` the horizontal shift.
/// The bilinearly interpolated input value is scaled by the mask entry for that tap
/// (or by 1 when no mask is given) and written to `columns[[ky, kx]]`.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn deform_im2col_kernel<F: FloatNdArrayElement>(
    out_y: usize,
    out_x: usize,
    input: ArrayView2<F>,
    offset: ArrayView3<F>,
    mask: Option<ArrayView2<F>>,
    mut columns: ArrayViewMut2<F>,
    args: DeformConvOptions<2>,
    (kernel_h, kernel_w): (usize, usize),
) {
    // position shape: [in_channels, batch_size, out_h, out_w]
    // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
    let (height, width) = input.dim();

    for kernel_y in 0..kernel_h {
        for kernel_x in 0..kernel_w {
            let modulation = mask
                .as_ref()
                .map_or_else(|| F::from_elem(1.0), |m| m[[kernel_y, kernel_x]]);
            let shift = offset.slice(s![kernel_y, kernel_x, ..]);

            let base_y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]);
            let base_x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]);
            let y = base_y - F::from_elem(args.padding[0]) + shift[0];
            let x = base_x - F::from_elem(args.padding[1]) + shift[1];

            columns[[kernel_y, kernel_x]] =
                modulation * bilinear_interpolate(input, height, width, y, x);
        }
    }
}
/// Bilinearly samples `input` at the fractional location `(y, x)`.
///
/// Locations at or beyond one pixel outside the image (`y <= -1`, `y >= height`, and
/// likewise for `x`) yield zero. Inside that band, any of the four neighbouring pixels that
/// falls outside the image contributes zero. The arithmetic on coordinates is done in `f32`.
fn bilinear_interpolate<F: FloatNdArrayElement>(
    input: ArrayView2<F>,
    height: usize,
    width: usize,
    y: F,
    x: F,
) -> F {
    // To simplify code
    let (y, x) = (y.to_f32(), x.to_f32());

    let samplable = y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x;
    if !samplable {
        return F::from_elem(0.0);
    }

    let (y_floor, x_floor) = (f32::floor(y), f32::floor(x));
    let y_next = (y_floor + 1.) as usize;
    let x_next = (x_floor + 1.) as usize;

    let zero = F::from_elem(0.0);
    // Reads a neighbour when it lies inside the image, otherwise contributes zero.
    let neighbour = |inside: bool, row: usize, col: usize| -> F {
        if inside { input[[row, col]] } else { zero }
    };

    let top_ok = y_floor >= 0.;
    let left_ok = x_floor >= 0.;
    let bottom_ok = y_next < height;
    let right_ok = x_next < width;

    let v1 = neighbour(top_ok && left_ok, y_floor as usize, x_floor as usize);
    let v2 = neighbour(top_ok && right_ok, y_floor as usize, x_next);
    let v3 = neighbour(bottom_ok && left_ok, y_next, x_floor as usize);
    let v4 = neighbour(bottom_ok && right_ok, y_next, x_next);

    // Fractional distances to the low corner and their complements.
    let frac_y = y - y_floor;
    let frac_x = x - x_floor;
    let rest_y = 1.0 - frac_y;
    let rest_x = 1.0 - frac_x;

    let w1 = F::from_elem(rest_y * rest_x);
    let w2 = F::from_elem(rest_y * frac_x);
    let w3 = F::from_elem(frac_y * rest_x);
    let w4 = F::from_elem(frac_y * frac_x);

    w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
}
/// Forward pass of the deformable 2D convolution.
///
/// The input is unfolded into a column matrix by [`deform_im2col`] (applying the offsets and
/// the optional mask), multiplied group-wise with the flattened weights, reshaped to
/// `[batch, out_channels, out_h, out_w]`, and finally the optional per-channel bias is added.
pub(crate) fn deform_conv2d<F: FloatNdArrayElement>(
    input: SharedArray<F>,
    offset: SharedArray<F>,
    weight: SharedArray<F>,
    mask: Option<SharedArray<F>>,
    bias: Option<SharedArray<F>>,
    args: DeformConvOptions<2>,
) -> SharedArray<F>
where
    NdArrayTensor: From<SharedArray<F>>,
{
    let [batch_size, _, in_height, in_width] = input.shape().dims();
    let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims();
    let weight_groups = args.weight_groups;

    let weight = weight.as_standard_layout();

    // Output extent along one spatial axis (0 = height, 1 = width).
    let output_extent = |axis: usize, kernel: usize, extent: usize| {
        calculate_conv_output_size(
            kernel,
            args.stride[axis],
            args.padding[axis],
            args.dilation[axis],
            extent,
        )
    };
    let out_h = output_extent(0, kernel_h, in_height);
    let out_w = output_extent(1, kernel_w, in_width);

    let input = input.into_dimensionality::<Ix4>().unwrap();
    let offset = offset.into_dimensionality::<Ix4>().unwrap();
    let mask = mask.as_ref().map(|it| {
        let mask_shape = (
            batch_size,
            args.offset_groups,
            kernel_h,
            kernel_w,
            out_h,
            out_w,
        );
        it.to_shape(mask_shape).unwrap()
    });

    let unfolded = deform_im2col(
        input.view(),
        offset.view(),
        mask.as_ref().map(|it| it.view()),
        args,
        (out_h, out_w),
        (kernel_h, kernel_w),
    );

    let (unfolded_rows, unfolded_cols) = unfolded.dim();
    let rows_per_group = unfolded_rows / weight_groups;
    let out_c_per_group = out_channels / weight_groups;

    let grouped_weight = weight
        .to_shape((weight_groups, out_c_per_group, rows_per_group))
        .unwrap();
    let grouped_columns = unfolded
        .to_shape((weight_groups, rows_per_group, unfolded_cols))
        .unwrap();

    let product = matmul(
        grouped_weight.to_owned().into_dyn().into_shared(),
        grouped_columns.to_owned().into_dyn().into_shared(),
    );

    // The product is laid out channel-major; swap to batch-major afterwards.
    let mut out = product
        .into_shape_with_order((out_channels, batch_size, out_h, out_w))
        .unwrap();
    out.swap_axes(0, 1);

    if let Some(bias) = bias {
        let broadcast_bias = bias.to_shape((1, out_channels, 1, 1)).unwrap();
        out.add_assign(&broadcast_bias);
    }

    out.into_dyn().into_shared()
}
/// Unfolds `input` into a column matrix for deformable convolution.
///
/// Every `(batch, out_y, out_x)` position is processed independently: its offsets are
/// reshaped to `[offset_groups, kernel_h, kernel_w, 2]`, and each input channel is sampled
/// with the offsets (and mask, if any) of the offset group that channel belongs to.
///
/// Returns a matrix of shape
/// `[in_channels * kernel_h * kernel_w, batch_size * out_h * out_w]`.
pub(crate) fn deform_im2col<F: FloatNdArrayElement>(
    input: ArrayView4<F>,
    offset: ArrayView4<F>,
    mask: Option<ArrayView6<F>>,
    args: DeformConvOptions<2>,
    out_dims: (usize, usize),
    kernel_dims: (usize, usize),
) -> Array2<F> {
    let (batch_size, in_channels, _, _) = input.dim();
    let (kernel_h, kernel_w) = kernel_dims;
    let (out_h, out_w) = out_dims;
    let channels_per_offset_group = in_channels / args.offset_groups;
    let positions = batch_size * out_h * out_w;

    let mut columns = Array4::zeros(Dim([in_channels, kernel_h, kernel_w, positions]));

    let offset_groups = args.offset_groups;

    run_par!(|| {
        iter_par!(columns.axis_iter_mut(Axis(3)))
            .enumerate()
            .for_each(|(position, mut position_columns)| {
                // Decompose the flat position into (batch, out_y, out_x).
                let out_x = position % out_w;
                let out_y = (position / out_w) % out_h;
                let batch = (position / (out_w * out_h)) % batch_size;

                let position_offset = offset.slice(s![batch, .., out_y, out_x]);
                let position_offset = position_offset
                    .to_shape((offset_groups, kernel_h, kernel_w, 2))
                    .unwrap();
                let position_mask = mask
                    .as_ref()
                    .map(|it| it.slice(s![batch, .., .., .., out_y, out_x]));

                for (in_channel, mut channel_columns) in
                    position_columns.axis_iter_mut(Axis(0)).enumerate()
                {
                    let group_index = in_channel / channels_per_offset_group;
                    deform_im2col_kernel(
                        out_y,
                        out_x,
                        input.slice(s![batch, in_channel, .., ..]),
                        position_offset.slice(s![group_index, .., .., ..]),
                        position_mask
                            .as_ref()
                            .map(|it| it.slice(s![group_index, .., ..])),
                        channel_columns.view_mut(),
                        args.clone(),
                        kernel_dims,
                    );
                }
            });
    });

    // Columns is created here, so we know it's contiguous
    columns
        .into_shape_with_order((in_channels * kernel_h * kernel_w, positions))
        .unwrap()
}
pub mod backward {
#[cfg(target_has_atomic = "32")]
use core::sync::atomic::Ordering;
use atomic_float::AtomicF32;
use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};
use super::*;
/// Gradients returned by `deform_conv2d_backward`, in this order:
/// input gradient, offset gradient, weight gradient, mask gradient (present only when a
/// mask was supplied) and bias gradient (present only when a bias was supplied).
pub(crate) type DeformConv2dBackward<F> = (
SharedArray<F>,
SharedArray<F>,
SharedArray<F>,
Option<SharedArray<F>>,
Option<SharedArray<F>>,
);
/// Backward pass of the deformable 2D convolution.
///
/// Returns, in order: the input gradient, the offset gradient, the weight gradient, the mask
/// gradient (only when `mask` is given) and the bias gradient (only when `bias` is given).
/// The bias gradient is `out_grad` summed over the batch and both spatial axes.
pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement>(
    input: SharedArray<F>,
    offset: SharedArray<F>,
    weight: SharedArray<F>,
    mask: Option<SharedArray<F>>,
    bias: Option<SharedArray<F>>,
    out_grad: SharedArray<F>,
    args: DeformConvOptions<2>,
) -> DeformConv2dBackward<F> {
    let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims();
    let [_, _, kernel_h, kernel_w] = weight.shape().dims();
    let kernel_dims = (kernel_h, kernel_w);
    let weight_groups = args.weight_groups;
    let out_c_per_group = out_channels / weight_groups;
    let positions = batch_size * out_h * out_w;

    let mut out_grad = out_grad.into_dimensionality::<Ix4>().unwrap();

    // Only computed when a bias took part in the forward pass.
    let gradient_bias = bias.map(|_| {
        out_grad
            .clone()
            .sum_axis(Axis(0))
            .sum_axis(Axis(1))
            .sum_axis(Axis(1))
            .into_dyn()
            .into_shared()
    });

    // Bring channels to the front, then split them into weight groups.
    out_grad.swap_axes(0, 1);
    let grouped_out_grad = out_grad
        .to_shape((weight_groups, out_c_per_group, positions))
        .unwrap();

    let input = input.into_dimensionality::<Ix4>().unwrap();
    let offset = offset.into_dimensionality::<Ix4>().unwrap();
    let mask = mask.map(|it| {
        let mask_shape = (
            batch_size,
            args.offset_groups,
            kernel_h,
            kernel_w,
            out_h,
            out_w,
        );
        it.into_shape_with_order(mask_shape).unwrap()
    });

    let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
        input.view(),
        weight,
        offset.view(),
        mask.as_ref().map(|it| it.view()),
        grouped_out_grad.view(),
        &args,
        kernel_dims,
    );

    let weight_gradient = compute_weight_grad(
        input.view(),
        offset.view(),
        mask.as_ref().map(|it| it.view()),
        grouped_out_grad.view(),
        args,
        kernel_dims,
        (out_h, out_w),
    );

    (
        input_gradient,
        offset_gradient,
        weight_gradient,
        mask_gradient,
        gradient_bias,
    )
}
/// Computes the weight gradient of the deformable convolution.
///
/// The input is unfolded again with [`deform_im2col`]; the grouped output gradient
/// (`[groups, out_c_per_group, positions]`) is then multiplied with the transposed columns,
/// and the product is reshaped to
/// `[out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w]`.
fn compute_weight_grad<F: FloatNdArrayElement>(
    input: ArrayView4<F>,
    offset: ArrayView4<F>,
    mask: Option<ArrayView6<F>>,
    out_grad: ArrayView3<F>,
    options: DeformConvOptions<2>,
    kernel_dims: (usize, usize),
    out_dims: (usize, usize),
) -> SharedArray<F> {
    let (_, in_channels, _, _) = input.dim();
    let (groups, out_c_per_group, _) = out_grad.dim();
    let (kernel_h, kernel_w) = kernel_dims;
    let in_c_per_group = in_channels / groups;

    let unfolded = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);
    let (unfolded_rows, unfolded_cols) = unfolded.dim();
    let rows_per_group = unfolded_rows / groups;

    // Per group: [rows, positions] -> [positions, rows] so it can be right-multiplied.
    let mut grouped_columns = unfolded
        .to_shape((groups, rows_per_group, unfolded_cols))
        .unwrap();
    grouped_columns.swap_axes(1, 2);

    let lhs = out_grad.to_owned().into_dyn().into_shared();
    let rhs = grouped_columns.to_owned().into_dyn().into_shared();

    matmul(lhs, rhs)
        .into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w))
        .unwrap()
        .into_dyn()
        .into_shared()
}
/// `(input gradient, offset gradient, optional mask gradient)`.
type InputGradients<F> = (SharedArray<F>, SharedArray<F>, Option<SharedArray<F>>);

/// Computes the gradients with respect to the input image, the offsets and the mask.
///
/// The column gradient is obtained by multiplying the transposed, grouped weights with the
/// grouped output gradient. It is then viewed as
/// `[in_channels, kernel_h, kernel_w, batch, out_h, out_w]` and handed to the dedicated
/// offset/mask and input gradient routines.
fn backward_gradient_inputs<F: FloatNdArrayElement>(
    image: ArrayView4<F>,
    weight: SharedArray<F>,
    offset: ArrayView4<F>,
    mask: Option<ArrayView6<F>>,
    out_grad: ArrayView3<F>,
    args: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
) -> InputGradients<F> {
    let input_shape = image.dim();
    let (_, in_channels, _, _) = input_shape;
    let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims();
    let (batch_size, _, out_h, out_w) = offset.dim();

    let weight_groups = args.weight_groups;
    let out_c_per_group = out_channels / weight_groups;
    let rows_per_group = in_c_per_group * kernel_h * kernel_w;

    // Per group: [out_c, rows] -> [rows, out_c].
    let mut grouped_weight = weight
        .to_shape((weight_groups, out_c_per_group, rows_per_group))
        .unwrap();
    grouped_weight.swap_axes(1, 2);

    let column_grad = matmul(
        grouped_weight.to_owned().into_dyn().into_shared(),
        out_grad.to_owned().into_dyn().into_shared(),
    );
    let column_grad = column_grad
        .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w))
        .unwrap();

    let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
        column_grad.view(),
        image.view(),
        offset,
        mask,
        args,
        kernel_dims,
    );

    let input_gradient = compute_input_grad(
        column_grad.view(),
        offset,
        mask,
        args,
        kernel_dims,
        input_shape,
    );

    (input_gradient, offset_gradient, mask_gradient)
}
/// Computes the gradients with respect to the sampling offsets and, when a mask is supplied,
/// the modulation mask.
///
/// * `columns` - column gradient, indexed `[in_channel, kernel_y, kernel_x, batch, out_y, out_x]`.
/// * `image` - forward input, indexed `[batch, in_channel, y, x]`.
/// * `offset` - forward offsets, indexed `[batch, offset_channel, out_y, out_x]`; the channel
///   axis is reinterpreted as `[offset_groups, kernel_h, kernel_w, 2]` with `(y, x)` last.
/// * `mask` - optional forward mask, indexed
///   `[batch, offset_group, kernel_y, kernel_x, out_y, out_x]`.
///
/// Returns the offset gradient (same layout as `offset`) and, only when `mask` is `Some`,
/// the mask gradient with `offset_channels / 2` channels.
///
/// Each offset group owns the input channels
/// `group * channels_per_offset_group .. (group + 1) * channels_per_offset_group`.
/// Both the image and the column gradient are restricted to that range before
/// accumulating, so every group pairs its own channels with its own column gradients.
fn compute_offset_and_mask_gradient<F: FloatNdArrayElement>(
    columns: ArrayView6<F>,
    image: ArrayView4<F>,
    offset: ArrayView4<F>,
    mask: Option<ArrayView6<F>>,
    args: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
) -> (SharedArray<F>, Option<SharedArray<F>>) {
    let (kernel_h, kernel_w) = kernel_dims;
    let (_, in_channels, height, width) = image.dim();
    let (batch_size, offset_channels, out_h, out_w) = offset.dim();
    let offs_groups = args.offset_groups;
    let channels_per_offset_group = in_channels / args.offset_groups;

    let mut grad_offset = Array5::zeros((
        offs_groups,
        kernel_h,
        kernel_w,
        2,
        batch_size * out_h * out_w,
    ));
    let mut grad_mask =
        Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w));

    grad_mask
        .axis_iter_mut(Axis(3))
        .zip(grad_offset.axis_iter_mut(Axis(4)))
        .enumerate()
        .for_each(|(index, (mut grad_mask, mut grad_offset))| {
            // Decompose the flat position into (batch, out_y, out_x).
            let out_x = index % out_w;
            let out_y = (index / out_w) % out_h;
            let batch = index / (out_w * out_h);

            let offset = offset.slice(s![batch, .., out_y, out_x]);
            let offset = offset
                .to_shape((offs_groups, kernel_h, kernel_w, 2))
                .unwrap();
            let mask: Option<ArrayView3<F>> = mask
                .as_ref()
                .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x]));
            let columns = columns.slice(s![.., .., .., batch, out_y, out_x]);
            let image = image.slice(s![batch, .., .., ..]);

            for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() {
                let grad_mask: &mut F = grad_mask;
                let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]);
                let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
                let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]);

                // Restrict BOTH the column gradient and the image to this offset group's
                // channels, so that local channel `c` refers to the same absolute input
                // channel in each of them.
                let group_offset = group * channels_per_offset_group;
                let columns = columns.slice(s![group_offset.., kernel_y, kernel_x]);
                let image = image.slice(s![group_offset.., .., ..]);

                // Sampling location used by the forward pass for this kernel tap.
                let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
                    - F::from_elem(args.padding[0])
                    + offset[0];
                let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
                    - F::from_elem(args.padding[1])
                    + offset[1];

                // `grad_offset` holds two entries per tap: index 0 is the y offset, 1 the x offset.
                for (i, grad_offset) in grad_offset.iter_mut().enumerate() {
                    let is_y_direction = i % 2 == 0;
                    let use_mask = mask.is_some();

                    for channel in 0..channels_per_offset_group {
                        let mask = mask.unwrap_or_else(|| F::one());
                        let image = image.index_axis(Axis(0), channel);
                        let weight =
                            get_coordinate_weight(image, height, width, y, x, is_y_direction);
                        *grad_offset += mask * weight * columns[channel];
                        // The mask gradient does not depend on the direction, so it is
                        // accumulated during only one of the two direction passes.
                        if use_mask && is_y_direction {
                            *grad_mask += columns[channel]
                                * bilinear_interpolate(image, height, width, y, x);
                        }
                    }
                }
            }
        });

    let mask_gradient = mask.map(|_| {
        let mut grad_mask = grad_mask
            .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w))
            .unwrap();
        grad_mask.swap_axes(0, 1);
        grad_mask.into_dyn().into_shared()
    });

    let mut grad_offset = grad_offset
        .into_shape_with_order((offset_channels, batch_size, out_h, out_w))
        .unwrap();
    grad_offset.swap_axes(0, 1);
    let offset_gradient = grad_offset.into_dyn().into_shared();

    (offset_gradient, mask_gradient)
}
/// Derivative of the bilinear sample of `input` at `(y, x)` with respect to one coordinate.
///
/// With `is_y_direction` the result is the derivative along `y`, otherwise along `x`.
/// Neighbouring pixels outside the image are treated as zero. Coordinates are handled in `f32`.
fn get_coordinate_weight<F: FloatNdArrayElement>(
    input: ArrayView2<F>,
    height: usize,
    width: usize,
    y: F,
    x: F,
    is_y_direction: bool,
) -> F {
    let (y, x) = (y.to_f32(), x.to_f32());
    let (y_low, x_low) = (f32::floor(y), f32::floor(x));
    let (y_high, x_high) = (y_low + 1., x_low + 1.);

    // Reads the pixel at integral coordinates, or zero when it lies outside the image.
    let pixel = |row: f32, col: f32| -> F {
        let row_inside = row >= 0. && row < height as f32;
        let col_inside = col >= 0. && col < width as f32;
        if row_inside && col_inside {
            input[[row as usize, col as usize]]
        } else {
            F::zero()
        }
    };

    // "bottom" is the low row, "top" is the high row.
    let bottom_left = pixel(y_low, x_low);
    let bottom_right = pixel(y_low, x_high);
    let top_left = pixel(y_high, x_low);
    let top_right = pixel(y_high, x_high);

    if is_y_direction {
        let delta_x = F::from_elem(x - x_low);
        delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left)
    } else {
        let delta_y = F::from_elem(y - y_low);
        delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left)
    }
}
/// Computes the gradient with respect to the input image by scattering the column gradient
/// back onto the image.
///
/// `columns` is indexed `[in_channel, kernel_y, kernel_x, batch, out_y, out_x]`. For every
/// element the forward sampling location is recomputed from `offset` and `args`, and the
/// value (scaled by the mask entry, if any) is distributed onto the neighbouring pixels by
/// `deform_col2img_kernel`.
///
/// Accumulation goes through an array of `AtomicF32`, so all intermediate values are
/// converted to `f32` regardless of `F`. The result has shape `input_shape`.
fn compute_input_grad<F: FloatNdArrayElement>(
columns: ArrayView6<F>,
offset: ArrayView4<F>,
mask: Option<ArrayView6<F>>,
args: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
input_shape: (usize, usize, usize, usize),
) -> SharedArray<F> {
let (batch_size, in_channels, height, width) = input_shape;
let (kernel_h, kernel_w) = kernel_dims;
let offs_groups = args.offset_groups;
let channels_per_offset_group = in_channels / offs_groups;
// Atomic accumulator: several column elements can contribute to the same input pixel.
let grad_in =
Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || {
AtomicF32::new(0.0)
});
// Handles a single column element, identified by its 6D index.
let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| {
// Offset group that this input channel belongs to.
let group = in_channel / channels_per_offset_group;
let offset = offset.slice(s![batch, .., out_y, out_x]);
let offset = offset
.to_shape((offs_groups, kernel_h, kernel_w, 2))
.unwrap();
let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
// [y offset, x offset]
let offset = [offset[0], offset[1]];
let mask = mask
.as_ref()
.map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
// Sampling location used by the forward pass for this kernel tap.
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
- F::from_elem(args.padding[0])
+ offset[0];
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
- F::from_elem(args.padding[1])
+ offset[1];
let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
};
// `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise
#[cfg(feature = "multi-threads")]
run_par!(|| {
iter_par!(Zip::indexed(columns))
.for_each(|(args0, args1)| compute_for_each(args0, args1))
});
#[cfg(not(feature = "multi-threads"))]
run_par!(|| { iter_par!(Zip::indexed(columns)).for_each(&compute_for_each) });
// Unwrap the atomics into plain `F` values, then restore the 4D shape.
let grad_in: Array1<F> = grad_in
.into_iter()
.map(|it| F::from_elem(it.into_inner()))
.collect();
let grad_in = grad_in
.into_shape_with_order((batch_size, in_channels, height, width))
.unwrap();
grad_in.into_dyn().into_shared()
}
/// Distributes one column-gradient value onto the input-gradient pixels around `(y, x)`.
///
/// Every pixel in the 3x3 neighbourhood of `(floor(y), floor(x))` that lies inside the image
/// and closer than one pixel to `(y, x)` on both axes receives
/// `mask * (1 - |y - yp|) * (1 - |x - xp|) * col`, added atomically. A missing mask counts as 1.
///
/// # Panics
///
/// Panics on targets without 32-bit atomics as soon as a pixel would be updated.
fn deform_col2img_kernel(
    y: f32,
    x: f32,
    mask: Option<f32>,
    col: f32,
    grad_input: ArrayView2<AtomicF32>,
) {
    let (height, width) = grad_input.dim();
    let mask_value = mask.unwrap_or(1.0);
    let (anchor_y, anchor_x) = (f32::floor(y), f32::floor(x));

    for dy in -1..=1 {
        let yp = anchor_y + dy as f32;
        let dist_y = f32::abs(y - yp);

        for dx in -1..=1 {
            let xp = anchor_x + dx as f32;
            let dist_x = f32::abs(x - xp);

            let inside = yp >= 0.0 && yp < height as f32 && xp >= 0.0 && xp < width as f32;
            let near = dist_y < 1.0 && dist_x < 1.0;
            if !(inside && near) {
                continue;
            }

            let weight = (1.0 - dist_y) * (1.0 - dist_x);

            #[cfg_attr(not(target_has_atomic = "32"), allow(unused))]
            let value = mask_value * weight * col;

            #[cfg(target_has_atomic = "32")]
            grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel);

            #[cfg(not(target_has_atomic = "32"))]
            panic!("Can't use deformable convolution backwards pass without atomics");
        }
    }
}
}

View File

@@ -0,0 +1,214 @@
use burn_backend::ElementConversion;
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use ndarray::Array4;
use crate::SharedArray;
use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par};
/// Samples `tensor` at the locations given by `grid` using bilinear interpolation.
///
/// # Arguments
///
/// * `tensor` - The source tensor with shape (N, C, H_in, W_in)
/// * `grid` - Sampling locations with shape (N, H_out, W_out, 2); the last axis holds
///   `[x, y]` in normalized coordinates, where -1 and 1 address the two opposite
///   borders of the source
/// * `options` - Grid sampling options (mode, padding_mode, align_corners)
///
/// # Returns
///
/// A tensor with shape (N, C, H_out, W_out)
///
/// # Panics
///
/// Panics for any interpolation mode other than bilinear, when the batch sizes of `tensor`
/// and `grid` differ, or when the last axis of `grid` is not of length 2.
pub(crate) fn grid_sample_2d<E: FloatNdArrayElement>(
    tensor: SharedArray<E>,
    grid: SharedArray<E>,
    options: GridSampleOptions,
) -> SharedArray<E> {
    if !matches!(options.mode, InterpolateMode::Bilinear) {
        todo!(
            "grid_sample_2d with {:?} mode is not implemented",
            options.mode
        );
    }

    let tensor = tensor.into_dimensionality::<ndarray::Ix4>().unwrap();
    let grid = grid.into_dimensionality::<ndarray::Ix4>().unwrap();

    let (batch_size, channels, height_in, width_in) = tensor.dim();
    let (b, height_out, width_out, d) = grid.dim();
    assert!(batch_size == b);
    assert!(2 == d);

    let mut output = Array4::zeros((batch_size, channels, height_out, width_out));
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    // Element counts used to decompose a flat output index.
    let per_plane = height_out * width_out;
    let per_batch = channels * per_plane;
    let sample_count = batch_size * per_batch;

    let align = options.align_corners;
    let pad_mode = options.padding_mode;

    run_par!(|| {
        iter_range_par!(0, sample_count).for_each(|id| {
            let b = id / per_batch;
            let c = id % per_batch / per_plane;
            let y = id % per_plane / width_out;
            let x = id % width_out;

            // Convert a normalized grid coordinate in [-1, 1] to a pixel coordinate.
            let to_pixel = |normalized: f64, size: usize| -> f64 {
                if align {
                    // align_corners=true: maps -1 to 0 and 1 to size - 1
                    (normalized + 1.0) * ((size - 1) as f64) / 2.0
                } else {
                    // align_corners=false: maps -1 to -0.5 and 1 to size - 0.5
                    (normalized + 1.0) * (size as f64) / 2.0 - 0.5
                }
            };

            let px = to_pixel(grid[(b, y, x, 0)].elem::<f64>(), width_in);
            let py = to_pixel(grid[(b, y, x, 1)].elem::<f64>(), height_in);

            // Bilinear interpolation with the specified padding mode
            let val =
                bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align);

            // Every flat index maps to a distinct output element.
            unsafe {
                let output = unsafe_shared_out.get();
                output[(b, c, y, x)] = val.elem();
            }
        });
    });

    output.into_dyn().into_shared()
}
/// Bilinearly samples `source[b, c, .., ..]` at pixel coordinates `(x, y)`.
///
/// How coordinates outside the image are treated depends on `padding_mode`:
/// zeros contribute 0, border clamps the coordinate into the image, and reflection folds it
/// back with [`reflect_coordinate`]. Non-finite coordinates yield 0, except in border mode
/// where the centre pixel is returned.
#[allow(clippy::too_many_arguments)]
fn bilinear_interpolate<E, S>(
    source: &ndarray::ArrayBase<S, ndarray::Dim<[usize; 4]>>,
    b: usize,
    c: usize,
    x: f64,
    y: f64,
    width: usize,
    height: usize,
    padding_mode: GridSamplePaddingMode,
    align_corners: bool,
) -> f64
where
    E: FloatNdArrayElement,
    S: ndarray::Data<Elem = E>,
{
    // Handle inf/nan coordinates
    if !(x.is_finite() && y.is_finite()) {
        return match padding_mode {
            GridSamplePaddingMode::Border => {
                // Clamp to center of image for inf/nan
                let center_x = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64);
                let center_y = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64);
                source[(b, c, center_y as usize, center_x as usize)].elem::<f64>()
            }
            // Reflection is simplified: inf/nan is treated like zeros padding.
            GridSamplePaddingMode::Zeros | GridSamplePaddingMode::Reflection => 0.0,
        };
    }

    // Apply padding mode to get actual sampling coordinates
    let (x, y) = match padding_mode {
        // Out-of-bounds reads are handled when fetching the corners.
        GridSamplePaddingMode::Zeros => (x, y),
        // Clamp coordinates to valid range [0, size-1]
        GridSamplePaddingMode::Border => (
            x.clamp(0.0, (width - 1) as f64),
            y.clamp(0.0, (height - 1) as f64),
        ),
        // Reflect coordinates at boundaries
        GridSamplePaddingMode::Reflection => (
            reflect_coordinate(x, width, align_corners),
            reflect_coordinate(y, height, align_corners),
        ),
    };

    // Low corner, high corner and the fractional position between them.
    let (floor_x, floor_y) = (x.floor(), y.floor());
    let (x0, y0) = (floor_x as i64, floor_y as i64);
    let (x1, y1) = (x0.saturating_add(1), y0.saturating_add(1));
    let x_frac = x - floor_x;
    let y_frac = y - floor_y;

    // Reads one corner according to the padding mode.
    let corner = |xi: i64, yi: i64| -> f64 {
        match padding_mode {
            GridSamplePaddingMode::Zeros => {
                let inside = xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64;
                if inside {
                    source[(b, c, yi as usize, xi as usize)].elem::<f64>()
                } else {
                    0.0
                }
            }
            GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => {
                // Coordinates should already be in valid range after clamping/reflection
                let col = xi.clamp(0, (width - 1) as i64) as usize;
                let row = yi.clamp(0, (height - 1) as i64) as usize;
                source[(b, c, row, col)].elem::<f64>()
            }
        }
    };

    let v00 = corner(x0, y0);
    let v01 = corner(x0, y1);
    let v10 = corner(x1, y0);
    let v11 = corner(x1, y1);

    let x_rest = 1.0 - x_frac;
    let y_rest = 1.0 - y_frac;

    v00 * (x_rest * y_rest) + v01 * (x_rest * y_frac) + v10 * (x_frac * y_rest)
        + v11 * (x_frac * y_frac)
}
/// Folds a coordinate back into the valid interval by mirroring it at the interval's ends
/// (a triangle wave).
///
/// The interval is `[0, size - 1]` for `align_corners = true` and `[-0.5, size - 0.5]`
/// otherwise. When the interval has no extent, its lower bound is returned.
fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 {
    let extent = size as f64;
    let lower = if align_corners { 0.0 } else { -0.5 };
    let upper = if align_corners { extent - 1.0 } else { extent - 0.5 };

    let span = upper - lower;
    if span <= 0.0 {
        return lower;
    }

    // Triangle wave formula: span - |((x mod 2*span) - span)|
    let period = 2.0 * span;
    let distance = (coord - lower).abs();
    let phase = distance - (distance / period).floor() * period;
    let folded = span - (phase - span).abs();
    folded + lower
}

View File

@@ -0,0 +1,497 @@
// Language
use crate::rand::get_seeded_rng;
use alloc::vec::Vec;
use burn_backend::backend::ExecutionError;
use burn_backend::ops::IntTensorOps;
use burn_backend::tensor::{FloatTensor, IntTensor};
use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
use burn_backend::ElementConversion;
// Current crate
use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
use crate::{NdArrayDevice, SEED, slice};
use crate::{SharedArray, element::QuantElement};
use crate::{cat_with_dtype, execute_with_float_dtype};
use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
use crate::{element::IntNdArrayElement, execute_with_int_dtype};
// Workspace crates
use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
use burn_backend::{DType, Shape, TensorData, backend::Backend};
/// Integer tensor operations for the ndarray backend.
///
/// Almost every method forwards to `NdArrayOps`, `NdArrayMathOps` or `NdArrayBitOps` through
/// the `execute_with_int_dtype!` family of macros, handing them the tensor(s) and the function
/// or closure to apply.
// NOTE(review): the macros are defined elsewhere; presumably they match on the tensor's integer
// dtype and run the operation on the corresponding typed array. Confirm against their definition.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
/// Builds a tensor from `data` when its dtype is a signed or unsigned integer; any other dtype
/// hits `unimplemented!`. The device argument is ignored.
fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
if data.dtype.is_int() || data.dtype.is_uint() {
NdArrayTensor::from_data(data)
} else {
unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
}
}
/// Converts the tensor into `TensorData`; this implementation always returns `Ok`.
async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
Ok(tensor.into_data())
}
/// Returns the tensor unchanged; the device argument is ignored.
fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
tensor
}
/// Reshapes the tensor to `shape` via `NdArrayOps::reshape`.
fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
}
/// Slices the tensor with the crate's `slice!` macro.
fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
slice!(tensor, slices)
}
/// Always reports `NdArrayDevice::Cpu`.
fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
NdArrayDevice::Cpu
}
/// Delegates to `int_zeros`, so the result is whatever `int_zeros` produces rather than
/// uninitialized memory.
fn int_empty(
shape: Shape,
device: &<NdArray<E> as Backend>::Device,
dtype: IntDType,
) -> NdArrayTensor {
Self::int_zeros(shape, device, dtype)
}
/// Matrix product of two integer tensors using the crate's `matmul`.
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
execute_with_int_dtype!((lhs, rhs), matmul)
}
/// Forwards to `NdArrayOps::mask_where` with `tensor`, `mask.bool()` and `source`.
fn int_mask_where(
tensor: NdArrayTensor,
mask: NdArrayTensor,
source: NdArrayTensor,
) -> NdArrayTensor {
execute_with_int_dtype!((tensor, source), |tensor, source| {
NdArrayOps::mask_where(tensor, mask.bool(), source)
})
}
/// Forwards to `NdArrayOps::mask_fill` with `mask.bool()` and the scalar converted by `elem()`.
fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
array,
mask.bool(),
value.elem()
))
}
/// Forwards to `NdArrayOps::slice_assign` with `tensor`, `slices` and `value`.
fn int_slice_assign(
tensor: NdArrayTensor,
slices: &[burn_backend::Slice],
value: NdArrayTensor,
) -> NdArrayTensor {
execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
tensor, slices, value
))
}
/// Concatenates `tensors` along `dim`; the bracketed list names the integer dtype variants
/// handed to `cat_with_dtype!`.
fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
}
// --- Comparisons: tensor/tensor variants take both operands, `_elem` variants convert the
// scalar right-hand side with `elem()` before forwarding to `NdArrayMathOps`. ---
fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
}
fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
}
fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
}
fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
}
fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
}
fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
array,
rhs.elem()
))
}
fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
}
fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
}
fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
}
fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
array,
rhs.elem()
))
}
// --- Arithmetic: same tensor/tensor and tensor/scalar split as the comparisons above. ---
fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
}
fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
}
fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
}
fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
}
fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
}
fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
}
fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
}
fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
}
fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
}
fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
array,
rhs.elem()
))
}
// --- Reductions ---
// NOTE(review): in the `*_view` reductions below, `E` is passed to the macro as an identifier and
// then used as the closure's element type. Presumably the macro binds that identifier to the
// concrete integer type of the matched dtype, shadowing this impl's float parameter `E` — confirm
// against the macro definition.
/// Sums all elements via `NdArrayMathOps::sum_view` on a view of the array.
fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
array.view()
))
}
/// Sums along `dim` via `NdArrayMathOps::sum_dim`.
fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
}
/// Multiplies all elements via `NdArrayMathOps::prod_view` on a view of the array.
fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(
tensor,
E,
|array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
)
}
/// Multiplies along `dim` via `NdArrayMathOps::prod_dim`.
fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
}
/// Mean of all elements via `NdArrayMathOps::mean_view` on a view of the array.
fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(
tensor,
E,
|array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
)
}
/// Mean along `dim` via `NdArrayMathOps::mean_dim`.
fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
}
/// Maximum over all elements via `NdArrayMathOps::max_view` on a view of the array.
fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
array.view()
))
}
/// Minimum over all elements via `NdArrayMathOps::min_view` on a view of the array.
fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
array.view()
))
}
// --- Cumulative operations along `dim` ---
fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
}
fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
}
fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
}
fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
}
// --- Indexing: the outer macro call dispatches on the data tensor, the inner one on the
// `indices` tensor, so data and indices may carry different integer dtypes. ---
/// Forwards to `NdArrayOps::gather(dim, array, idx_array)`.
fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
dim, array, idx_array
))
})
}
/// Forwards to `NdArrayOps::scatter(dim, tensor, idx_array, value)`.
fn int_scatter_add(
dim: usize,
tensor: NdArrayTensor,
indices: NdArrayTensor,
value: NdArrayTensor,
) -> NdArrayTensor {
execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
dim, tensor, idx_array, value
))
})
}
/// Forwards to `NdArrayMathOps::select(array, dim, idx_array)`.
fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
array, dim, idx_array
))
})
}
/// Forwards to `NdArrayMathOps::select_assign(tensor, dim, idx_array, value)`.
fn int_select_add(
tensor: NdArrayTensor,
dim: usize,
indices: NdArrayTensor,
value: NdArrayTensor,
) -> NdArrayTensor {
execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
tensor, dim, idx_array, value
))
})
}
/// Index of the maximum along `dim`; the result element type is requested as `I`.
fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
NdArrayMathOps::argmax_view::<I>(array.view(), dim)
})
}
/// Index of the minimum along `dim`; the result element type is requested as `I`.
fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
NdArrayMathOps::argmin_view::<I>(array.view(), dim)
})
}
// --- Clamping: bounds are scalars converted with `elem()`. ---
fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
}
fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
}
fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
array,
min.elem(),
max.elem()
))
}
/// Absolute value. Signed dtypes go through `NdArrayMathOps::abs`; unsigned tensors are
/// returned unchanged; any other dtype panics.
fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
match tensor.dtype() {
DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
I64 => i64, I32 => i32, I16 => i16, I8 => i8
])
}
// Already unsigned
DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
other => panic!("Unsupported dtype: {other:?}"),
}
}
/// Converts every element to the backend float type `E` with `elem::<E>()`.
fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| array
.mapv(|a: IntElem| a.elem::<E>())
.into_shared())
}
/// Swaps two dimensions via `NdArrayOps::swap_dims`.
fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
}
/// Creates a random integer tensor of element type `I`.
///
/// The RNG is taken out of the global `SEED` slot (or created with `get_seeded_rng` when the
/// slot is empty), used to generate the data, and stored back afterwards. The lock on `SEED`
/// is held for the whole call. `Distribution::Default` is replaced by `Uniform(0.0, 255.0)`.
fn int_random(
shape: Shape,
distribution: Distribution,
device: &NdArrayDevice,
) -> NdArrayTensor {
let mut seed = SEED.lock().unwrap();
let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
let effective_distribution = if distribution == Distribution::Default {
Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant
} else {
distribution
};
let tensor = Self::int_from_data(
TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
device,
);
*seed = Some(rng);
tensor
}
/// Element-wise integer power: the base is converted to `i64`, the exponent to `u32`, and the
/// `i64::pow` result is converted back with `elem()`.
fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
lhs,
rhs,
|a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
))
}
/// Element-wise power with a float exponent tensor. The exponent is converted with `as u32`
/// (fractional part dropped, out-of-range values saturated) before calling `i64::pow`.
fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
execute_with_int_dtype!(lhs, I, |array| -> NdArrayTensor {
execute_with_float_dtype!(rhs, E, |rhs_array| {
NdArrayMathOps::elementwise_op(array, rhs_array, |a: &I, b: &E| {
(a.elem::<i64>().pow(*b as u32)).elem()
})
})
})
}
/// Element-wise power with a scalar exponent; the scalar is converted by `elem()` to the
/// exponent type `i64::pow` expects.
fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, I, |array| {
NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
(a.elem::<i64>().pow(rhs.elem())).elem()
})
})
}
/// Permutes the dimensions according to `axes`.
fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
}
/// Flips the tensor along `axes`.
fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
}
/// Sign of each element. Signed dtypes use `NdArrayMathOps::sign_op`; unsigned dtypes are
/// routed through `int_greater_elem(tensor, 0)`; any other dtype panics.
// NOTE(review): the unsigned branch returns whatever `int_greater_elem` produces. Verify that
// this is an integer tensor of the input dtype and not a boolean mask.
fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
match tensor.dtype() {
DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
I64 => i64, I32 => i32, I16 => i16, I8 => i8
])
}
DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
Self::int_greater_elem(tensor, 0.into())
}
other => panic!("Unsupported dtype: {other:?}"),
}
}
/// Expands the tensor to `shape` via `NdArrayOps::expand`.
fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
}
// --- Bitwise operations: and/or/xor/not forward to `NdArrayBitOps`. ---
fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
}
fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
}
fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
}
fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
}
fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
}
fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
}
fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
}
// --- Shifts: each element is converted to `i64`, shifted by a `u32` amount, and converted
// back with `elem()`. ---
fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
(a.elem::<i64>() << (b.elem::<u32>())).elem()
})
})
}
fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, I, |array| {
NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
(a.elem::<i64>() << rhs.elem::<u32>()).elem()
})
})
}
fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
(a.elem::<i64>() >> (b.elem::<u32>())).elem()
})
})
}
fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
execute_with_int_dtype!(lhs, I, |array| {
NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
(a.elem::<i64>() >> rhs.elem::<u32>()).elem()
})
})
}
/// Casts the tensor to the requested integer dtype via `cast_to_dtype`.
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
}
/// Forwards to `NdArrayOps::unfold(array, dim, size, step)`.
fn int_unfold(
tensor: IntTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> IntTensor<Self> {
execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
}
}

View File

@@ -0,0 +1,302 @@
use burn_backend::ElementConversion;
use ndarray::{Array4, ArrayBase, DataOwned};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use crate::{FloatNdArrayElement, ShapeOps, SharedArray, UnsafeSharedRef, iter_range_par, run_par};
/// Resize the two trailing dimensions of a 4-D tensor with nearest-neighbour sampling.
///
/// * `x` - Input tensor; it is converted to four dimensions, read here as
///   (batch, channels, height, width).
/// * `output_size` - Target `[height, width]`.
///
/// Each output pixel `(h, w)` copies the input pixel at
/// `(floor(h * in_height / out_height), floor(w * in_width / out_width))`.
///
/// # Panics
///
/// Panics if `x` cannot be converted to exactly four dimensions.
pub(crate) fn nearest_interpolate<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    output_size: [usize; 2],
) -> SharedArray<E> {
    let input = x.into_dimensionality::<ndarray::Ix4>().unwrap();
    let (batch_size, channels, in_height, in_width) = input.dim();
    let [out_height, out_width] = output_size;

    // Input step per output pixel along each axis.
    let scale_y = (in_height as f64) / (out_height as f64);
    let scale_x = (in_width as f64) / (out_width as f64);

    // Element counts used to split a flat output index into (b, c, h, w).
    let total = batch_size * channels * out_height * out_width;
    let per_batch = channels * out_height * out_width;
    let per_channel = out_height * out_width;
    let per_row = out_width;

    let mut output = Array4::zeros((batch_size, channels, out_height, out_width));
    let shared_output = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        iter_range_par!(0, total).for_each(|flat| {
            let b = flat / per_batch;
            let c = flat % per_batch / per_channel;
            let h = flat % per_channel / per_row;
            let w = flat % per_row;

            let src_y = (scale_y * h as f64).floor() as usize;
            let src_x = (scale_x * w as f64).floor() as usize;

            // SAFETY: every flat index decomposes to a distinct (b, c, h, w), so no two
            // iterations write the same output element.
            unsafe {
                let out = shared_output.get();
                out[(b, c, h, w)] = input[(b, c, src_y, src_x)];
            }
        });
    });

    output.into_dyn().into_shared()
}
/// Backward pass of nearest-neighbour interpolation.
///
/// * `x` - The forward input; only its shape (batch, channels, height, width) is read.
/// * `grad` - Gradient with respect to the interpolated output.
/// * `output_size` - `[height, width]` of the interpolated output.
///
/// Returns a tensor shaped like `x` in which every output gradient has been added to the
/// input cell selected by `start_index` for its row and column.
pub(crate) fn nearest_interpolate_backward<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    grad: SharedArray<E>,
    output_size: [usize; 2],
) -> SharedArray<E> {
    let [batch_size, channels, input_height, input_width] = x.shape().dims();
    let [output_height, output_width] = output_size;

    let mut input_grad =
        Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
    let shared_grad = UnsafeSharedRef::new(&mut input_grad);

    run_par!(|| {
        iter_range_par!(0, batch_size * channels).for_each(|plane| {
            let b = plane / channels;
            let c = plane % channels;
            // SAFETY: each (b, c) plane is produced by exactly one `plane` value and the
            // writes below stay inside that plane.
            let accum = unsafe { shared_grad.get() };
            for oh in 0..output_height {
                // The source row depends only on `oh`, so compute it once per row.
                let ih = start_index(oh, output_height, input_height);
                for ow in 0..output_width {
                    let iw = start_index(ow, output_width, input_width);
                    accum[[b, c, ih, iw]] += grad[[b, c, oh, ow]];
                }
            }
        })
    });

    input_grad.into_dyn().into_shared()
}
/// Map an output index to the input index it samples from:
/// `floor(output_size_index * input_size / output_size)`.
///
/// The computation uses integer arithmetic, which is exact. A floating-point quotient can
/// round up to the next index once `output_size_index * input_size` no longer fits the float
/// mantissa, sending the value to the wrong (or an out-of-range) input cell.
///
/// * `output_size_index` - Position along the output axis.
/// * `output_size` - Length of the output axis.
/// * `input_size` - Length of the input axis.
///
/// Returns 0 when `output_size` is 0 (there is nothing to map) instead of dividing by zero.
/// A quotient that does not fit in `usize` saturates at `usize::MAX`.
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    if output_size == 0 {
        return 0;
    }
    match output_size_index.checked_mul(input_size) {
        // Common case: the product fits in a machine word.
        Some(scaled) => scaled / output_size,
        // Rare case: widen so the intermediate product cannot overflow.
        None => {
            let wide = (output_size_index as u128 * input_size as u128) / output_size as u128;
            if wide > usize::MAX as u128 {
                usize::MAX
            } else {
                wide as usize
            }
        }
    }
}
/// Round `frac` up to the next whole number, capped at `max`.
///
/// The cap keeps the result within bounds when floating-point imprecision pushes `frac`
/// slightly past the last valid position.
pub(crate) fn ceil_clamp(frac: f64, max: usize) -> f64 {
    let upper = max as f64;
    let rounded = frac.ceil();
    f64::min(rounded, upper)
}
/// Resize the two trailing dimensions of a 4-D tensor with bilinear interpolation.
///
/// * `x` - Input tensor; it is converted to four dimensions, read here as
///   (batch, channels, height, width).
/// * `output_size` - Target `[height, width]`.
/// * `align_corners` - When true, output index `i` maps to `i * (in - 1) / max(out - 1, 1)`.
///   When false, it maps to `(i + 0.5) * in / out - 0.5`, clamped to `[0, in - 1]`.
///
/// # Panics
///
/// Panics if `x` cannot be converted to exactly four dimensions.
pub(crate) fn bilinear_interpolate<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    output_size: [usize; 2],
    align_corners: bool,
) -> SharedArray<E> {
    let input = x.into_dimensionality::<ndarray::Ix4>().unwrap();
    let (batch_size, channels, in_height, in_width) = input.dim();
    let [out_height, out_width] = output_size;

    // Element counts used to split a flat output index into (b, c, h, w).
    let total = batch_size * channels * out_height * out_width;
    let per_batch = channels * out_height * out_width;
    let per_channel = out_height * out_width;
    let per_row = out_width;

    let mut output = Array4::zeros((batch_size, channels, out_height, out_width));
    let shared_output = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        iter_range_par!(0, total).for_each(|flat| {
            let b = flat / per_batch;
            let c = flat % per_batch / per_channel;
            let h = flat % per_channel / per_row;
            let w = flat % per_row;

            // Fractional source position for this output pixel.
            let (src_y, src_x) = if align_corners {
                let scale_y =
                    ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);
                let scale_x = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);
                (scale_y * h as f64, scale_x * w as f64)
            } else {
                let raw_y = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;
                let raw_x = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;
                (
                    raw_y.clamp(0.0, (in_height - 1) as f64),
                    raw_x.clamp(0.0, (in_width - 1) as f64),
                )
            };

            let sample = bilinear_interpolate_single(
                &input,
                b,
                c,
                src_x,
                src_y,
                in_width - 1,
                in_height - 1,
            );

            // SAFETY: every flat index decomposes to a distinct (b, c, h, w), so no two
            // iterations write the same output element.
            unsafe {
                let out = shared_output.get();
                out[(b, c, h, w)] = sample.elem();
            }
        });
    });

    output.into_dyn().into_shared()
}
/// Resize the two trailing dimensions of a 4-D tensor with bicubic interpolation.
///
/// * `x` - Input tensor; it is converted to four dimensions, read here as
///   (batch, channels, height, width).
/// * `output_size` - Target `[height, width]`.
/// * `align_corners` - When true, output index `i` maps to `i * (in - 1) / max(out - 1, 1)`.
///   When false, it maps to `(i + 0.5) * in / out - 0.5` (not clamped; the neighbour indices
///   are clamped instead).
///
/// Each output value is computed from a 4x4 input neighbourhood: four horizontal cubic
/// interpolations (one per row) followed by one vertical interpolation of their results.
///
/// # Panics
///
/// Panics if `x` cannot be converted to exactly four dimensions.
pub(crate) fn bicubic_interpolate<E: FloatNdArrayElement>(
x: SharedArray<E>,
output_size: [usize; 2],
align_corners: bool,
) -> SharedArray<E> {
// Weighted sum of four consecutive samples; `t` is the fractional offset past `x1`.
fn cubic_interp1d(x0: f64, x1: f64, x2: f64, x3: f64, t: f64) -> f64 {
// Cubic convolution weight `(a + 2)x^3 - (a + 3)x^2 + 1`, applied to the two inner samples.
fn cubic_convolution1(x: f64, a: f64) -> f64 {
((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
}
// Cubic convolution weight `a x^3 - 5a x^2 + 8a x - 4a`, applied to the two outer samples.
fn cubic_convolution2(x: f64, a: f64) -> f64 {
((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a
}
// Arguments are the distances from the sampling point to x0, x1, x2 and x3; a = -0.75.
let coeffs = [
cubic_convolution2(t + 1.0, -0.75),
cubic_convolution1(t, -0.75),
cubic_convolution1(1.0 - t, -0.75),
cubic_convolution2(2.0 - t, -0.75),
];
x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]
}
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
let (batch_size, channels, in_height, in_width) = x.dim();
let [out_height, out_width] = output_size;
let out_element_num = batch_size * channels * out_height * out_width;
// Element counts per batch, per channel and per row, used to decode a flat output index.
let strides = (
channels * out_height * out_width,
out_height * out_width,
out_width,
);
let mut output = Array4::zeros((batch_size, channels, out_height, out_width));
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_range_par!(0, out_element_num).for_each(|id| {
// Decode the flat index into (batch, channel, row, column).
let (b, c, h, w) = (
id / strides.0,
id % strides.0 / strides.1,
id % strides.1 / strides.2,
id % strides.2,
);
// Fractional source position for this output pixel.
let (y_frac, x_frac) = if align_corners {
let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);
let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);
(y_ratio * h as f64, x_ratio * w as f64)
} else {
let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;
let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;
(y_frac, x_frac)
};
// Split each coordinate into an integer base and a fractional weight.
let y0 = y_frac.floor();
let yw = y_frac - y0;
let y_in = y0 as isize;
let x0 = x_frac.floor();
let xw = x_frac - x0;
let x_in = x0 as isize;
let max_h = (in_height - 1) as isize;
let max_w = (in_width - 1) as isize;
// Rows base-1 ..= base+2, clamped so out-of-range neighbours reuse the border row.
let ys_in = [
(y_in - 1).clamp(0, max_h) as usize,
y_in.clamp(0, max_h) as usize,
(y_in + 1).clamp(0, max_h) as usize,
(y_in + 2).clamp(0, max_h) as usize,
];
// Columns base-1 ..= base+2, clamped the same way.
let xs_in = [
(x_in - 1).clamp(0, max_w) as usize,
x_in.clamp(0, max_w) as usize,
(x_in + 1).clamp(0, max_w) as usize,
(x_in + 2).clamp(0, max_w) as usize,
];
// Interpolate horizontally within each of the four rows.
let coefficients = ys_in.map(|y| {
cubic_interp1d(
x[(b, c, y, xs_in[0])].elem(),
x[(b, c, y, xs_in[1])].elem(),
x[(b, c, y, xs_in[2])].elem(),
x[(b, c, y, xs_in[3])].elem(),
xw,
)
});
// Interpolate vertically across the four row results.
let result = cubic_interp1d(
coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
yw,
)
.elem();
// Each flat index maps to a distinct (b, c, h, w), so iterations write disjoint elements.
unsafe {
let output = unsafe_shared_out.get();
output[(b, c, h, w)] = result;
}
});
});
output.into_dyn().into_shared()
}
/// Read one bilinearly interpolated value from a 4-D array.
///
/// * `source` - Array to sample, indexed as (batch, channel, row, column)
/// * `b` - Batch index
/// * `c` - Channel index
/// * `x` - Fractional column position
/// * `y` - Fractional row position
/// * `x_max` - Largest valid column index (inclusive)
/// * `y_max` - Largest valid row index (inclusive)
///
/// # Returns
///
/// The weighted sum of the four samples surrounding `(x, y)`. The lower neighbours come from
/// `floor`, the upper ones from [`ceil_clamp`], so they never exceed `x_max` / `y_max`.
pub(crate) fn bilinear_interpolate_single<E, S>(
    source: &ArrayBase<S, ndarray::Dim<[usize; 4]>>,
    b: usize,
    c: usize,
    x: f64,
    y: f64,
    x_max: usize,
    y_max: usize,
) -> f64
where
    E: FloatNdArrayElement,
    S: DataOwned<Elem = E>,
{
    // Surrounding rows and the fractional distance from the upper one.
    let row_lo = y.floor();
    let row_hi = ceil_clamp(y, y_max);
    let wy = y - row_lo;

    // Surrounding columns and the fractional distance from the left one.
    let col_lo = x.floor();
    let col_hi = ceil_clamp(x, x_max);
    let wx = x - col_lo;

    let (col0, col1) = (col_lo as usize, col_hi as usize);
    let (row0, row1) = (row_lo as usize, row_hi as usize);

    let top_left = source[(b, c, row0, col0)].elem::<f64>() * (1.0 - wx) * (1.0 - wy);
    let top_right = source[(b, c, row0, col1)].elem::<f64>() * wx * (1.0 - wy);
    let bottom_left = source[(b, c, row1, col0)].elem::<f64>() * (1.0 - wx) * wy;
    let bottom_right = source[(b, c, row1, col1)].elem::<f64>() * wx * wy;

    top_left + top_right + bottom_left + bottom_right
}

View File

@@ -0,0 +1,107 @@
// Reduce `$self` along `$dim` and reshape the result so the reduced dimension is kept with
// length 1.
//
// Invocation: `keepdim!(dim, tensor, mean | sum | prod)`. The third token selects which of
// `mean_dim`, `sum_dim` or `prod_dim` performs the reduction.
//
// The expansion names `SharedArray<E>`, `NdArrayOps` and the `*_dim` function without a
// `$crate::` path, so those names (and a type `E`) must be in scope where the macro is used.
// `$self` appears twice in each arm (once for `.shape()`, once moved into the reduction).
macro_rules! keepdim {
(
$dim:expr,
$self:expr,
mean
) => {{
// Get shape first (via reference), then pass ownership to avoid clone
let mut shape = $self.shape().into_shape();
// The reduced dimension is kept with length 1.
shape[$dim] = 1;
let tensor: SharedArray<E> = mean_dim($self, $dim);
NdArrayOps::reshape(tensor, shape)
}};
(
$dim:expr,
$self:expr,
sum
) => {{
// Get shape first (via reference), then pass ownership to avoid clone
let mut shape = $self.shape().into_shape();
// The reduced dimension is kept with length 1.
shape[$dim] = 1;
let tensor: SharedArray<E> = sum_dim($self, $dim);
NdArrayOps::reshape(tensor, shape)
}};
(
$dim:expr,
$self:expr,
prod
) => {{
// Get shape first (via reference), then pass ownership to avoid clone
let mut shape = $self.shape().into_shape();
// The reduced dimension is kept with length 1.
shape[$dim] = 1;
let tensor: SharedArray<E> = prod_dim($self, $dim);
NdArrayOps::reshape(tensor, shape)
}};
}
use burn_backend::ElementConversion;
pub(crate) use keepdim;
use ndarray::{Axis, Zip};
use crate::{SharedArray, element::NdArrayElement};
/// Average `tensor` along `dim`; the reduced dimension is removed from the result.
///
/// # Panics
///
/// Panics if `mean_axis` returns `None`.
pub(crate) fn mean_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
    let axis = Axis(dim);
    let averaged = tensor.mean_axis(axis).unwrap();
    averaged.into_shared()
}
/// Sum `tensor` along `dim`; the reduced dimension is removed from the result.
pub(crate) fn sum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
    let axis = Axis(dim);
    let summed = tensor.sum_axis(axis);
    summed.into_shared()
}
/// Multiply the elements of `tensor` along `dim`; the reduced dimension is removed from the
/// result. The fold starts from `1` converted to `E`.
pub(crate) fn prod_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
    let axis = Axis(dim);
    let identity = 1.elem::<E>();
    let product = tensor.fold_axis(axis, identity, |running, &item| running.mul(item.elem()));
    product.into_shared()
}
/// Run a cumulative (prefix) operation along `dim`.
///
/// Starting from the second position on the axis, `op(current, previous)` is called for every
/// element pair, where `previous` is the already-accumulated value one step earlier on the
/// axis and `current` is updated in place.
pub(crate) fn cumulative_with_op<E, F>(tensor: SharedArray<E>, dim: usize, op: F) -> SharedArray<E>
where
    E: NdArrayElement,
    F: Fn(&mut E, &E),
{
    let axis = Axis(dim);
    let axis_len = tensor.shape()[dim];
    // `into_owned` only copies when the storage is shared.
    let mut accumulated = tensor.into_owned();
    for position in 1..axis_len {
        // Snapshot the previous slice so it can be read while the current one is mutated.
        let previous = accumulated.index_axis(axis, position - 1).to_owned();
        let mut current = accumulated.index_axis_mut(axis, position);
        Zip::from(&mut current).and(&previous).for_each(&op);
    }
    accumulated.into_shared()
}
// Define all cumulative operation functions using the generic function
/// Cumulative sum along `dim`.
pub(crate) fn cumsum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
    cumulative_with_op(tensor, dim, |current, &previous| {
        *current = current.add(previous.elem())
    })
}
/// Cumulative product along `dim`.
pub(crate) fn cumprod_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
    cumulative_with_op(tensor, dim, |current, &previous| {
        *current = current.mul(previous.elem())
    })
}
/// Cumulative minimum along `dim`: each element is replaced by the running minimum whenever
/// the running value compares strictly smaller.
pub(crate) fn cummin_dim<E: NdArrayElement + core::cmp::PartialOrd<E>>(
    tensor: SharedArray<E>,
    dim: usize,
) -> SharedArray<E> {
    cumulative_with_op(tensor, dim, |current, &running| {
        if running < *current {
            *current = running;
        }
    })
}
/// Cumulative maximum along `dim`: each element is replaced by the running maximum whenever
/// the running value compares strictly greater.
pub(crate) fn cummax_dim<E: NdArrayElement + core::cmp::PartialOrd<E>>(
    tensor: SharedArray<E>,
    dim: usize,
) -> SharedArray<E> {
    cumulative_with_op(tensor, dim, |current, &running| {
        if running > *current {
            *current = running;
        }
    })
}

View File

@@ -0,0 +1,362 @@
use crate::UnsafeSharedRef;
use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};
use alloc::{vec, vec::Vec};
use burn_backend::ElementConversion;
use burn_backend::Shape;
use ndarray::{IxDyn, s};
/// Batched matrix multiplication with broadcasting over the leading (batch) dimensions.
///
/// The last two dimensions of each operand are the matrix dimensions. Both operands are
/// flattened to `[batches, rows, cols]`, every output batch is multiplied with
/// `general_mat_mul`, and the result is reshaped to the broadcast output shape.
///
/// # Panics
///
/// * On the conditions listed for `output_shape`. The matrix dimensions are read before
///   `output_shape` runs, so operands with fewer than two dimensions fail on the
///   `ndims - 2` index computation first.
/// * On division by zero when any of `m * k`, `k * n` or `m * n` is zero.
pub(crate) fn matmul<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> SharedArray<E> {
    let lhs_dims = lhs.shape();
    let rhs_dims = rhs.shape();
    let rank = lhs_dims.num_dims();

    let rows = lhs_dims[rank - 2]; // rows of the left matrix
    let inner = rhs_dims[rank - 2]; // shared dimension: left cols / right rows
    let cols = rhs_dims[rank - 1]; // cols of the right matrix

    let (out_shape, lhs_strides, rhs_strides, out_strides) = output_shape(lhs_dims, rhs_dims);

    // Number of stacked matrices in each array = total elements / elements per matrix.
    let lhs_matrix_len = rows * inner;
    let rhs_matrix_len = inner * cols;
    let out_matrix_len = rows * cols;
    let lhs_batches = lhs_dims.num_elements() / lhs_matrix_len;
    let rhs_batches = rhs_dims.num_elements() / rhs_matrix_len;
    let out_batches = out_shape.num_elements() / out_matrix_len;

    let lhs_stack = NdArrayOps::reshape(lhs, Shape::new([lhs_batches, rows, inner]));
    let rhs_stack = NdArrayOps::reshape(rhs, Shape::new([rhs_batches, inner, cols]));

    // out = 1 * (lhs x rhs) + 0 * out
    let scale_product: E = 1.0.elem();
    let scale_existing: E = 0.0.elem();

    let product = run_par!(|| {
        let mut stacked = ndarray::Array3::<E>::zeros((out_batches, rows, cols));
        let shared_stacked = UnsafeSharedRef::new(&mut stacked);

        iter_range_par!(0, out_batches).for_each(|batch| {
            // Turn the flat output batch into per-dimension coordinates, then use each
            // operand's batch strides (0 on broadcast dimensions) to find its source batch.
            let coords = out_strides.unflatten(batch);
            let lhs_batch = lhs_strides.flatten(&coords);
            let rhs_batch = rhs_strides.flatten(&coords);

            let lhs_matrix = lhs_stack.slice(s!(lhs_batch, .., ..));
            let rhs_matrix = rhs_stack.slice(s!(rhs_batch, .., ..));

            // SAFETY: each iteration writes only the output slice for its own `batch`.
            unsafe {
                let mut out_matrix = shared_stacked.get().slice_mut(s!(batch, .., ..));
                ndarray::linalg::general_mat_mul(
                    scale_product,
                    &lhs_matrix,
                    &rhs_matrix,
                    scale_existing,
                    &mut out_matrix,
                )
            }
        });

        stacked.into_shared().into_dyn()
    });

    NdArrayOps::reshape(product, out_shape)
}
/// Per-dimension strides used to convert between a flat batch index and per-dimension
/// batch coordinates.
#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}

impl Strides {
    /// Wrap a list of strides, ordered from the outermost dimension to the innermost.
    fn new(strides: Vec<usize>) -> Self {
        Strides { strides }
    }

    /// Split `linear_index` into one coordinate per stride.
    ///
    /// Each coordinate is the quotient by its stride; the remainder carries over to the next
    /// dimension.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if any stride is 0.
    fn unflatten(&self, linear_index: usize) -> Vec<usize> {
        let mut rem = linear_index;
        self.strides
            .iter()
            .map(|&stride| {
                let coord = rem / stride;
                rem %= stride;
                coord
            })
            .collect()
    }

    /// Combine per-dimension coordinates into a flat index: the sum of `stride * coordinate`.
    ///
    /// Accepts any slice of coordinates (a `&Vec<usize>` coerces automatically).
    ///
    /// # Panics
    ///
    /// Panics if `index` does not have exactly one entry per stride.
    fn flatten(&self, index: &[usize]) -> usize {
        assert_eq!(self.strides.len(), index.len());
        self.strides
            .iter()
            .zip(index)
            .map(|(stride, coord)| stride * coord)
            .sum()
    }
}
/// Work out the broadcast output shape of a matrix multiplication together with the batch
/// strides of the left operand, the right operand and the output.
///
/// The last two dimensions are the matrix dimensions; all earlier dimensions are batch
/// dimensions. A batch dimension of length 1 that is paired with a different length is
/// broadcast and receives stride 0.
///
/// # Arguments
/// * `lsh`: Shape of the left-hand operand.
/// * `rsh`: Shape of the right-hand operand.
///
/// # Returns
/// `(output shape, left batch strides, right batch strides, output batch strides)`. Each
/// stride list has one entry per batch dimension, outermost first.
///
/// # Panics
/// * If `lsh` has fewer than 2 dimensions.
/// * If the inner matrix dimensions (left columns, right rows) differ.
/// * If a pair of batch dimensions differs and neither of them is 1.
fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
    let ndims = lsh.num_dims();
    if ndims < 2 {
        panic!("Matrix multiplication requires an array with at least 2 dimensions.");
    }

    // Matrix dimensions and their compatibility.
    let (l_rows, l_cols) = (lsh[ndims - 2], lsh[ndims - 1]);
    let (r_rows, r_cols) = (rsh[ndims - 2], rsh[ndims - 1]);
    if l_cols != r_rows {
        panic!("Dimensions are incompatible for matrix multiplication.");
    }

    let batch_dims = ndims - 2;
    let mut out_dims = vec![0; ndims];
    out_dims[ndims - 2] = l_rows;
    out_dims[ndims - 1] = r_cols;

    // One stride slot per batch dimension, filled from the innermost dimension outwards.
    let mut l_strides = vec![0; batch_dims];
    let mut r_strides = vec![0; batch_dims];
    let mut o_strides = vec![0; batch_dims];
    let mut l_step: usize = 1;
    let mut r_step: usize = 1;
    let mut o_step: usize = 1;

    for axis in (0..batch_dims).rev() {
        let l_dim = lsh[axis];
        let r_dim = rsh[axis];
        // (output length, left stride, right stride) for this batch dimension.
        let (o_dim, l_stride, r_stride) = match (l_dim, r_dim) {
            // Same length on both sides: no broadcasting.
            (l, r) if l == r => (l, l_step, r_step),
            // Left has length 1: broadcast the left operand.
            (1, r) => (r, 0, r_step),
            // Right has length 1: broadcast the right operand.
            (l, 1) => (l, l_step, 0),
            _ => panic!("Dimensions differ and cannot be broadcasted."),
        };
        out_dims[axis] = o_dim;
        l_strides[axis] = l_stride;
        r_strides[axis] = r_stride;
        o_strides[axis] = o_step;
        o_step *= o_dim;
        l_step *= l_dim;
        r_step *= r_dim;
    }

    (
        Shape::from(out_dims),
        Strides::new(l_strides),
        Strides::new(r_strides),
        Strides::new(o_strides),
    )
}
/// Cross product of `lhs` and `rhs` along dimension `dim`, broadcasting every other dimension.
///
/// Steps: broadcast both operands to a common shape, move `dim` to the last position, flatten
/// to rows of three components, compute the product row by row, then undo the reshape and the
/// permutation.
///
/// The length of `dim` is taken from `lhs` and is not validated here; the code assumes it is 3.
///
/// # Panics
///
/// Panics if a dimension other than `dim` differs between the operands and neither length is 1.
pub(crate) fn cross<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
    dim: usize,
) -> SharedArray<E> {
    let lhs_dims = lhs.shape();
    let rhs_dims = rhs.shape();
    let rank = lhs_dims.num_dims();

    // Common shape: `dim` keeps the left-hand length, the rest follow broadcasting rules.
    let mut target = vec![0; rank];
    for axis in 0..rank {
        target[axis] = if axis == dim {
            lhs_dims[axis]
        } else {
            match (lhs_dims[axis], rhs_dims[axis]) {
                (l, r) if l == r => l,
                (1, r) => r,
                (l, 1) => l,
                _ => panic!("Tensors are not broadcastable along dimension {}", axis),
            }
        };
    }

    // Expand only the operands that do not already have the common shape.
    let lhs_wide = if lhs_dims == target.as_slice() {
        lhs
    } else {
        NdArrayOps::expand(lhs, Shape::from(target.clone()))
    };
    let rhs_wide = if rhs_dims == target.as_slice() {
        rhs
    } else {
        NdArrayOps::expand(rhs, Shape::from(target.clone()))
    };

    // Permutation that moves `dim` to the end and keeps the other axes in order.
    let mut to_last = (0..rank).collect::<Vec<_>>();
    to_last.remove(dim);
    to_last.push(dim);
    let lhs_moved = NdArrayOps::permute(lhs_wide, &to_last);
    let rhs_moved = NdArrayOps::permute(rhs_wide, &to_last);

    // View both operands as a list of 3-component rows.
    let row_count = lhs_moved.shape().num_elements() / 3;
    let lhs_rows = NdArrayOps::reshape(lhs_moved, Shape::new([row_count, 3]));
    let rhs_rows = NdArrayOps::reshape(rhs_moved, Shape::new([row_count, 3]));

    let mut product = ndarray::ArrayD::<E>::zeros(IxDyn(&[row_count, 3]));
    for row in 0..row_count {
        let a = [
            lhs_rows[IxDyn(&[row, 0])],
            lhs_rows[IxDyn(&[row, 1])],
            lhs_rows[IxDyn(&[row, 2])],
        ];
        let b = [
            rhs_rows[IxDyn(&[row, 0])],
            rhs_rows[IxDyn(&[row, 1])],
            rhs_rows[IxDyn(&[row, 2])],
        ];
        product[IxDyn(&[row, 0])] = a[1].mul(b[2]).sub(a[2].mul(b[1]));
        product[IxDyn(&[row, 1])] = a[2].mul(b[0]).sub(a[0].mul(b[2]));
        product[IxDyn(&[row, 2])] = a[0].mul(b[1]).sub(a[1].mul(b[0]));
    }

    // Restore the broadcast shape, still with `dim` in the last position.
    let mut moved_shape = target;
    moved_shape.remove(dim);
    moved_shape.push(3);
    let unflattened = NdArrayOps::reshape(product.into_shared(), Shape::from(moved_shape));

    // Invert the permutation so `dim` returns to its original position.
    let mut from_last = vec![0; rank];
    for (position, &axis) in to_last.iter().enumerate() {
        from_last[axis] = position;
    }
    NdArrayOps::permute(unflattened, &from_last)
}
// Unit tests for `output_shape`: each case passes a (lhs, rhs) pair of shapes and checks the
// returned output shape together with the three stride sets (lhs, rhs, output).
#[cfg(test)]
mod tests {
use super::*;
// Test-only shorthand for a `Strides` value holding no entries, used for the rank-2 case.
impl Strides {
fn empty() -> Self {
Strides {
strides: Vec::with_capacity(0),
}
}
}
#[test]
fn test_output_shape() {
// plain matrix multiply
assert_eq!(
output_shape(&[5, 3], &[3, 7]),
(
Shape::from([5, 7]),
Strides::empty(),
Strides::empty(),
Strides::empty()
)
);
// matrix multiply with one extra stack dimension
assert_eq!(
output_shape(&[4, 5, 3], &[4, 3, 7]),
(
Shape::from([4, 5, 7]),
Strides::new(vec![1]),
Strides::new(vec![1]),
Strides::new(vec![1])
)
);
// rank 3, broadcast left
// (the left operand's size-1 leading axis is expected to get stride 0)
assert_eq!(
output_shape(&[1, 5, 3], &[4, 3, 7]),
(
Shape::from([4, 5, 7]),
Strides::new(vec![0]),
Strides::new(vec![1]),
Strides::new(vec![1])
)
);
// rank 3, broadcast right
// (the right operand's size-1 leading axis is expected to get stride 0)
assert_eq!(
output_shape(&[4, 5, 3], &[1, 3, 7]),
(
Shape::from([4, 5, 7]),
Strides::new(vec![1]),
Strides::new(vec![0]),
Strides::new(vec![1])
)
);
// rank 4, multi broadcast
assert_eq!(
output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]),
(
Shape::from([8, 4, 5, 7]),
Strides::new(vec![0, 1]),
Strides::new(vec![1, 0]),
Strides::new(vec![4, 1])
)
);
// rank 5, multi-broadcast
assert_eq!(
output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]),
(
Shape::from([8, 3, 4, 5, 7]),
Strides::new(vec![0, 4, 1]),
Strides::new(vec![3, 1, 0]),
Strides::new(vec![12, 4, 1])
)
)
}
// Rank-1 operands must be rejected with the exact message below.
#[test]
#[should_panic(
expected = "Matrix multiplication requires an array with at least 2 dimensions."
)]
fn test_output_shape_too_small() {
output_shape(&[4], &[4]);
}
// Inner dimensions (3 vs 4) do not match.
#[test]
#[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
fn test_output_shape_bad_matrix_dims() {
output_shape(&[5, 3], &[4, 7]);
}
// Leading dimensions differ (4 vs 2) and neither is 1.
#[test]
#[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
fn test_output_shape_non_broadcast() {
output_shape(&[4, 5, 3], &[2, 3, 7]);
}
}

View File

@@ -0,0 +1,247 @@
use crate::{
ShapeOps, SharedArray,
element::{FloatNdArrayElement, IntNdArrayElement},
iter_range_par,
ops::padding::apply_padding_4d,
run_par,
sharing::UnsafeSharedRef,
};
use burn_backend::ElementConversion;
use burn_backend::ops::conv::calculate_pool_output_size;
use ndarray::Array4;
/// 2D max pooling over a 4D input read as `[batch, channels, height, width]`.
///
/// The input is padded with negative infinity, so padded cells never win a comparison. The
/// output size per axis comes from `calculate_pool_output_size`; when the last window would
/// reach past the regularly padded input (possible with `ceil_mode`), additional padding is
/// added so that every window read stays in bounds.
pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> SharedArray<E> {
    let [kernel_h, kernel_w] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_h, stride_w] = stride;
    let [dilation_h, dilation_w] = dilation;
    let [batch_size, channels, x_height, x_width] = x.shape().dims();

    // Fill value for padding and the starting value of every running maximum.
    let neg_inf = (-f32::INFINITY).elem::<E>();

    let out_h =
        calculate_pool_output_size(kernel_h, stride_h, pad_h, dilation_h, x_height, ceil_mode);
    let out_w =
        calculate_pool_output_size(kernel_w, stride_w, pad_w, dilation_w, x_width, ceil_mode);

    // Furthest coordinate (in regularly padded space) touched by the last window on each axis:
    // (out_size - 1) * stride + (kernel_size - 1) * dilation.
    let last_h = out_h.saturating_sub(1) * stride_h + (kernel_h - 1) * dilation_h;
    let last_w = out_w.saturating_sub(1) * stride_w + (kernel_w - 1) * dilation_w;

    // Whatever exceeds the last valid padded coordinate must be covered by extra padding.
    let extra_h = last_h.saturating_sub((x_height + 2 * pad_h).saturating_sub(1));
    let extra_w = last_w.saturating_sub((x_width + 2 * pad_w).saturating_sub(1));

    let x = apply_padding_4d::<E>(x, [pad_h + extra_h, pad_w + extra_w], neg_inf);

    let mut output = Array4::from_elem((batch_size, channels, out_h, out_w), neg_inf);
    let shared_output = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // Each index `plane` maps to one (batch, channel) pair, and the closure writes only to
        // `output[[b, c, ..]]` for that pair.
        iter_range_par!(0, batch_size * channels).for_each(|plane| unsafe {
            let (b, c) = (plane / channels, plane % channels);
            let output = shared_output.get();

            for oh in 0..out_h {
                // The extra padding is applied on both sides, so window origins shift by it.
                let row_origin = extra_h + oh * stride_h;
                for ow in 0..out_w {
                    let col_origin = extra_w + ow * stride_w;
                    let mut best = neg_inf;

                    for kh in 0..kernel_h {
                        let ih = row_origin + kh * dilation_h;
                        for kw in 0..kernel_w {
                            let iw = col_origin + kw * dilation_w;
                            let candidate = x[[b, c, ih, iw]];
                            if candidate > best {
                                best = candidate;
                            }
                        }
                    }

                    output[[b, c, oh, ow]] = best;
                }
            }
        })
    });

    output.into_dyn().into_shared()
}
/// 2D max pooling that also reports where each maximum was found.
///
/// Returns `(output, indices)`. Each index is the flat position `row * x_width + col` of the
/// winning element inside the unpadded `[height, width]` plane of its (batch, channel) pair.
/// If no element of a window compares greater than negative infinity, the index stays 0.
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, I: IntNdArrayElement>(
    x: SharedArray<E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> (SharedArray<E>, SharedArray<I>) {
    let [kernel_h, kernel_w] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_h, stride_w] = stride;
    let [dilation_h, dilation_w] = dilation;
    let [batch_size, channels, x_height, x_width] = x.shape().dims();

    // Fill value for padding and the starting value of every running maximum.
    let neg_inf = (-f32::INFINITY).elem::<E>();

    let out_h =
        calculate_pool_output_size(kernel_h, stride_h, pad_h, dilation_h, x_height, ceil_mode);
    let out_w =
        calculate_pool_output_size(kernel_w, stride_w, pad_w, dilation_w, x_width, ceil_mode);

    // Extra padding so the last window on each axis stays inside the padded input.
    let last_h = out_h.saturating_sub(1) * stride_h + (kernel_h - 1) * dilation_h;
    let last_w = out_w.saturating_sub(1) * stride_w + (kernel_w - 1) * dilation_w;
    let extra_h = last_h.saturating_sub((x_height + 2 * pad_h).saturating_sub(1));
    let extra_w = last_w.saturating_sub((x_width + 2 * pad_w).saturating_sub(1));
    let total_padding = [pad_h + extra_h, pad_w + extra_w];

    let x = apply_padding_4d::<E>(x, total_padding, neg_inf);

    let mut output = Array4::from_elem((batch_size, channels, out_h, out_w), neg_inf);
    let mut indices = Array4::<I>::zeros((batch_size, channels, out_h, out_w));
    let shared_output = UnsafeSharedRef::new(&mut output);
    let shared_indices = UnsafeSharedRef::new(&mut indices);

    run_par!(|| {
        // Each index `plane` maps to one (batch, channel) pair, and the closure writes only to
        // the `[[b, c, ..]]` entries of `output` and `indices` for that pair.
        iter_range_par!(0, batch_size * channels).for_each(|plane| unsafe {
            let (b, c) = (plane / channels, plane % channels);
            let output = shared_output.get();
            let indices = shared_indices.get();

            for oh in 0..out_h {
                let row_origin = extra_h + oh * stride_h;
                for ow in 0..out_w {
                    let col_origin = extra_w + ow * stride_w;
                    let mut best = neg_inf;
                    // Padded coordinates of the current maximum, if any element beat `neg_inf`.
                    let mut best_at: Option<(usize, usize)> = None;

                    for kh in 0..kernel_h {
                        let ih = row_origin + kh * dilation_h;
                        for kw in 0..kernel_w {
                            let iw = col_origin + kw * dilation_w;
                            let candidate = x[[b, c, ih, iw]];
                            if candidate > best {
                                best = candidate;
                                best_at = Some((ih, iw));
                            }
                        }
                    }

                    let flat_index: i64 = match best_at {
                        Some((ih, iw)) => {
                            // Translate back to unpadded coordinates, clamped to the plane.
                            let row = (ih as i64 - (total_padding[0]) as i64)
                                .max(0)
                                .min(x_height as i64 - 1);
                            let col = (iw as i64 - (total_padding[1]) as i64)
                                .max(0)
                                .min(x_width as i64 - 1);
                            row * x_width as i64 + col
                        }
                        None => 0,
                    };

                    output[[b, c, oh, ow]] = best;
                    indices[[b, c, oh, ow]] = flat_index.elem();
                }
            }
        })
    });

    (
        output.into_dyn().into_shared(),
        indices.into_dyn().into_shared(),
    )
}
/// Backward pass of 2D max pooling.
///
/// Builds a zero array with the shape of `x` and, for every position of `output_grad`, adds
/// that gradient to the input cell named by the matching entry of `indices` (a flat
/// `row * width + col` offset into the plane of `x`). Gradients that point at the same cell
/// accumulate. The pooling hyper-parameters are accepted but not used here.
#[allow(clippy::too_many_arguments)]
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(
    x: SharedArray<E>,
    _kernel_size: [usize; 2],
    _stride: [usize; 2],
    _padding: [usize; 2],
    _dilation: [usize; 2],
    _ceil_mode: bool,
    output_grad: SharedArray<E>,
    indices: SharedArray<I>,
) -> SharedArray<E> {
    let [_, _, grad_height, grad_width] = output_grad.shape().dims();
    let [batch_size, channels, x_height, x_width] = x.shape().dims();

    let mut input_grad = Array4::zeros((batch_size, channels, x_height, x_width));
    let shared_input_grad = UnsafeSharedRef::new(&mut input_grad);

    run_par!(|| {
        // Each index `plane` maps to one (batch, channel) pair, and the closure writes only to
        // `input_grad[[b, c, ..]]` for that pair.
        iter_range_par!(0, batch_size * channels).for_each(|plane| unsafe {
            let (b, c) = (plane / channels, plane % channels);
            let input_grad = shared_input_grad.get();

            for oh in 0..grad_height {
                for ow in 0..grad_width {
                    let flat = indices[[b, c, oh, ow]].elem::<i64>() as usize;
                    let (ih, iw) = (flat / x_width, flat % x_width);
                    input_grad[[b, c, ih, iw]] += output_grad[[b, c, oh, ow]];
                }
            }
        });
    });

    input_grad.into_dyn().into_shared()
}

View File

@@ -0,0 +1,24 @@
// Private submodules (visible only inside this module).
mod activation;
mod base;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
// Only compiled when the `simd` feature is enabled.
#[cfg(feature = "simd")]
mod simd;
mod tensor;
mod transaction;
// Submodules exposed to the rest of the crate.
pub(crate) mod adaptive_avgpool;
pub(crate) mod avgpool;
pub(crate) mod conv;
pub(crate) mod deform_conv;
pub(crate) mod grid_sample;
pub(crate) mod interpolate;
pub(crate) mod macros;
pub(crate) mod matmul;
pub(crate) mod maxpool;
pub(crate) mod padding;
pub(crate) mod quantization;
// Re-export every item of `base` at this module's level, crate-wide.
pub(crate) use base::*;

View File

@@ -0,0 +1,367 @@
use super::{
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
avgpool::{avg_pool2d, avg_pool2d_backward},
conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
#[cfg(feature = "simd")]
use crate::ops::simd::{
avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
};
use crate::{
NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
tensor::NdArrayTensor,
};
use crate::{
element::{IntNdArrayElement, QuantElement},
ops::interpolate::nearest_interpolate_backward,
};
use burn_backend::{
ElementConversion, TensorMetadata,
ops::{attention::attention_fallback, *},
tensor::FloatTensor,
};
// Dispatches a module operation on the float dtype of its tensor inputs.
//
// Usage: `module_op!(inp(a, b, ...), opt(c, ...), E, op)`
// - `inp(...)`: one or more required `NdArrayTensor` inputs; all must hold the same float
//   variant (`F32` or `F64`), otherwise the macro panics with "Data type mismatch".
// - `opt(...)`: zero or more `Option<NdArrayTensor>` inputs; a `Some` whose variant differs
//   from the required inputs panics with "Optional argument type mismatch".
// - `E`: identifier that is aliased to `f32` or `f64` inside the selected arm, so `op` can
//   name the concrete element type.
// - `op`: callable invoked with the required inputs first, then the optional ones, each
//   already converted with `into_shared()`.
macro_rules! module_op {
// Module op with required inputs (inp) and optional inputs (opt).
// Converts NdArrayStorage to SharedArray for compatibility with existing operations.
(inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
#[allow(unused_parens, unreachable_patterns)]
match ($($x),+) {
($(NdArrayTensor::F32($x)),+) => {
type $element = f32;
$op(
$($x.into_shared()),+
$(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
)
}
($(NdArrayTensor::F64($x)),+) => {
type $element = f64;
$op(
$($x.into_shared()),+
$(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
)
}
_ => panic!("Data type mismatch"),
}
}};
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
/// 2D convolution.
///
/// `module_op!` selects the float dtype shared by `x` and `weight` (and `bias`, when present)
/// and panics if they disagree. With the `simd` feature enabled, `try_conv2d_simd` is tried
/// first; if it returns `Err`, the arguments it hands back go to the generic `conv2d` kernel.
fn conv2d(
x: NdArrayTensor,
weight: NdArrayTensor,
bias: Option<NdArrayTensor>,
options: ConvOptions<2>,
) -> NdArrayTensor {
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
// The `return` below leaves the closure handed to `module_op!`; its value becomes the
// result of the macro expression.
#[cfg(feature = "simd")]
let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
Ok(out) => return out.into(),
Err(args) => args,
};
conv2d::<E>(x, weight, bias, options).into()
})
}
/// 2D deformable convolution.
///
/// `x`, `offset` and `weight` must share one float dtype; `mask` and `bias` are optional and,
/// when present, must use that dtype too (a mismatch panics inside `module_op!`). The work is
/// done by the `deform_conv2d` kernel.
fn deform_conv2d(
x: FloatTensor<Self>,
offset: FloatTensor<Self>,
weight: FloatTensor<Self>,
mask: Option<FloatTensor<Self>>,
bias: Option<FloatTensor<Self>>,
options: DeformConvOptions<2>,
) -> FloatTensor<Self> {
module_op!(
inp(x, offset, weight),
opt(mask, bias),
E,
|x, offset, weight, mask, bias| deform_conv2d::<E>(
x, offset, weight, mask, bias, options
)
.into()
)
}
/// Backward pass of 2D deformable convolution.
///
/// Returns the gradients for `x`, `offset`, `weight`, and (when they were provided) `mask`
/// and `bias`, wrapped in `DeformConv2dBackward`.
fn deform_conv2d_backward(
x: FloatTensor<Self>,
offset: FloatTensor<Self>,
weight: FloatTensor<Self>,
mask: Option<FloatTensor<Self>>,
bias: Option<FloatTensor<Self>>,
output_grad: FloatTensor<Self>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
module_op!(
inp(x, offset, weight, output_grad),
opt(mask, bias),
E,
// Closure parameters follow the macro's call order: all required inputs first
// (including `output_grad`), then the optional ones.
|x, offset, weight, output_grad, mask, bias| {
let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
x,
offset,
weight,
mask,
bias,
output_grad,
options,
);
DeformConv2dBackward::new(
x.into(),
offset.into(),
weight.into(),
mask.map(|m| m.into()),
bias.map(|b| b.into()),
)
}
)
}
/// 2D transposed convolution, delegated to the `conv_transpose2d` kernel after dtype dispatch.
///
/// `x`, `weight` and the optional `bias` must share one float dtype.
fn conv_transpose2d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<Self> {
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
conv_transpose2d::<E>(x, weight, bias, options).into()
})
}
/// 2D average pooling.
///
/// With the `simd` feature enabled and `ceil_mode == false`, `try_avg_pool2d_simd` is tried
/// first; if it returns `Err` (or `ceil_mode` is set), the generic `avg_pool2d` kernel runs.
fn avg_pool2d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> FloatTensor<Self> {
module_op!(inp(x), opt(), E, |x| {
// The `return` below leaves the closure handed to `module_op!`.
#[cfg(feature = "simd")]
let x = match if ceil_mode {
// SIMD path doesn't support ceil_mode yet, skip it
Err(x)
} else {
try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
} {
Ok(out) => return out.into(),
Err(x) => x,
};
avg_pool2d::<E>(
x,
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
)
.into()
})
}
/// Backward pass of 2D average pooling, delegated to the `avg_pool2d_backward` kernel.
///
/// `x` and `grad` must share one float dtype.
fn avg_pool2d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> FloatTensor<Self> {
module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
x,
grad,
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode
)
.into())
}
/// 2D max pooling.
///
/// With the `simd` feature enabled and `ceil_mode == false`, `try_max_pool2d_simd` is tried
/// first; if it returns `Err` (or `ceil_mode` is set), the generic `max_pool2d` kernel runs.
fn max_pool2d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> FloatTensor<Self> {
module_op!(inp(x), opt(), E, |x| {
// The `return` below leaves the closure handed to `module_op!`.
#[cfg(feature = "simd")]
let x = match if ceil_mode {
// SIMD path doesn't support ceil_mode yet, skip it
Err(x)
} else {
try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
} {
Ok(out) => return out.into(),
Err(x) => x,
};
max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
})
}
/// 2D max pooling that also returns the indices of the selected elements.
///
/// The indices are produced with the backend's int element type `I`.
fn max_pool2d_with_indices(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
module_op!(inp(x), opt(), E, |x| {
let (output, indices) = max_pool2d_with_indices::<E, I>(
x,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
);
MaxPool2dWithIndices::new(output.into(), indices.into())
})
}
/// Backward pass of 2D max pooling using the indices recorded by the forward pass.
///
/// `indices` may arrive with any int dtype; it is converted element-wise to `I` before the
/// `max_pool2d_backward` kernel is called. `x` and `output_grad` must share one float dtype.
fn max_pool2d_with_indices_backward(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
output_grad: FloatTensor<Self>,
indices: NdArrayTensor,
) -> MaxPool2dBackward<NdArray<E, I, Q>> {
execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
// Convert indices from runtime dtype to the expected I type
// (pool indices are bounded by tensor dimensions, so conversion is safe)
let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
let output = max_pool2d_backward::<E, I>(
x,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
output_grad,
indices,
);
MaxPool2dBackward::new(output.into())
})
})
}
/// Adaptive 2D average pooling to `output_size`, delegated to the `adaptive_avg_pool2d`
/// kernel after dtype dispatch.
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
x,
output_size
)
.into())
}
/// Backward pass of adaptive 2D average pooling.
///
/// `x` and `grad` must share one float dtype.
fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
module_op!(inp(x, grad), opt(), E, |x, grad| {
adaptive_avg_pool2d_backward::<E>(x, grad).into()
})
}
/// Resizes `x` to `output_size` using the mode selected in `options`.
///
/// `Nearest` ignores `options.align_corners`; `Bilinear` and `Bicubic` forward it to their
/// kernels.
fn interpolate(
x: FloatTensor<Self>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self> {
match options.mode {
InterpolateMode::Nearest => {
module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
x,
output_size
)
.into())
}
InterpolateMode::Bilinear => {
let align_corners = options.align_corners;
module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
x,
output_size,
align_corners
)
.into())
}
InterpolateMode::Bicubic => {
let align_corners = options.align_corners;
module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
x,
output_size,
align_corners
)
.into())
}
}
}
/// Backward pass of `interpolate`.
///
/// # Panics
///
/// Only `InterpolateMode::Nearest` is implemented; `Bilinear` and `Bicubic` panic.
fn interpolate_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self> {
match options.mode {
InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
nearest_interpolate_backward::<E>(x, grad, output_size).into()
}),
InterpolateMode::Bilinear => {
panic!("bilinear interpolation backward is not supported for ndarray backend")
}
InterpolateMode::Bicubic => {
panic!("bicubic interpolation backward is not supported for ndarray backend")
}
}
}
/// 3D convolution, delegated to the `conv3d` kernel after dtype dispatch.
///
/// `x`, `weight` and the optional `bias` must share one float dtype.
fn conv3d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvOptions<3>,
) -> FloatTensor<Self> {
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
x, weight, bias, options
)
.into())
}
/// 3D transposed convolution, delegated to the `conv_transpose3d` kernel after dtype dispatch.
///
/// `x`, `weight` and the optional `bias` must share one float dtype.
fn conv_transpose3d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<Self> {
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
conv_transpose3d::<E>(x, weight, bias, options).into()
})
}
/// Attention; this backend has no dedicated kernel and forwards every argument unchanged to
/// `attention_fallback`.
fn attention(
query: FloatTensor<Self>,
key: FloatTensor<Self>,
value: FloatTensor<Self>,
mask: Option<burn_backend::tensor::BoolTensor<Self>>,
attn_bias: Option<FloatTensor<Self>>,
options: AttentionModuleOptions,
) -> FloatTensor<Self> {
attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
}
}

View File

@@ -0,0 +1,72 @@
use crate::{NdArrayElement, SharedArray};
use ndarray::{Array4, Array5};
use super::NdArrayOps;
/// Pads the last two axes (height, width) of a 4D array symmetrically with `elem`.
///
/// `padding` is `[height, width]`; each amount is added on both sides of its axis. The first
/// two axes are left untouched.
///
/// # Panics
///
/// Panics if `x` is not 4-dimensional.
pub(crate) fn apply_padding_4d<E: NdArrayElement>(
    x: SharedArray<E>,
    padding: [usize; 2],
    elem: E,
) -> SharedArray<E> {
    let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap();
    let [pad_h, pad_w] = padding;

    // Start from an array of the padded size filled entirely with `elem`.
    let canvas = Array4::from_elem(
        (
            batch_size,
            input_channels,
            height + 2 * pad_h,
            width + 2 * pad_w,
        ),
        elem,
    )
    .into_shared()
    .into_dyn();

    // Region of the canvas that receives the original data.
    let interior = [
        burn_backend::Slice::from(0..batch_size),
        burn_backend::Slice::from(0..input_channels),
        burn_backend::Slice::from(pad_h..height + pad_h),
        burn_backend::Slice::from(pad_w..width + pad_w),
    ];

    NdArrayOps::slice_assign(canvas, &interior, x)
}
/// Pads the last three axes (depth, height, width) of a 5D array symmetrically with `elem`.
///
/// `padding` is `[depth, height, width]`; each amount is added on both sides of its axis. The
/// first two axes are left untouched.
///
/// # Panics
///
/// Panics if `x` is not 5-dimensional.
pub(crate) fn apply_padding_5d<E: NdArrayElement>(
    x: SharedArray<E>,
    padding: [usize; 3],
    elem: E,
) -> SharedArray<E> {
    let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap();
    let [pad_d, pad_h, pad_w] = padding;

    // Start from an array of the padded size filled entirely with `elem`.
    let canvas = Array5::from_elem(
        (
            batch_size,
            input_channels,
            depth + 2 * pad_d,
            height + 2 * pad_h,
            width + 2 * pad_w,
        ),
        elem,
    )
    .into_shared()
    .into_dyn();

    // Region of the canvas that receives the original data.
    let interior = [
        burn_backend::Slice::from(0..batch_size),
        burn_backend::Slice::from(0..input_channels),
        burn_backend::Slice::from(pad_d..depth + pad_d),
        burn_backend::Slice::from(pad_h..height + pad_h),
        burn_backend::Slice::from(pad_w..width + pad_w),
    ];

    NdArrayOps::slice_assign(canvas, &interior, x)
}

View File

@@ -0,0 +1,346 @@
use alloc::{vec, vec::Vec};
use burn_backend::{
DType, ExecutionError, Shape, TensorData, TensorMetadata,
ops::QTensorOps,
quantization::{
QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
QuantizationParametersPrimitive, QuantizedBytes,
},
tensor::{FloatTensor, IntTensor, QuantizedTensor},
};
use crate::{
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
element::{IntNdArrayElement, QuantElement},
execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype, slice,
};
use super::quantization::{QuantizationStrategy, SymmetricQuantization};
use super::{NdArrayMathOps, NdArrayOps};
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
/// Builds a quantized tensor from `TensorData` whose dtype is `DType::QFloat`.
///
/// Supported: symmetric schemes with 8-bit values (`Q8F` / `Q8S`) at tensor or block level.
/// The values are unpacked to `i8` and the returned tensor's scheme is rewritten to
/// `QuantStore::Native`. The device argument is ignored.
///
/// # Panics
///
/// Panics if `data.dtype` is not `DType::QFloat`, or (via `unimplemented!`) for any of the
/// sub-byte / float quantized value types.
fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
match data.dtype {
DType::QFloat(scheme) => {
let shape = data.shape.clone();
let num_elements = data.num_elements();
let q_bytes = QuantizedBytes {
bytes: data.into_bytes(),
scheme,
num_elements,
};
match scheme {
QuantScheme {
level: QuantLevel::Tensor | QuantLevel::Block(_),
mode: QuantMode::Symmetric,
value: QuantValue::Q8F | QuantValue::Q8S,
..
} => {
// We can load QuantStore::U32 w/ QuantizedBytes impl
let (values, qparams) = q_bytes.into_vec_i8();
let data = TensorData::new(values, shape);
// Overwrite storage
let scheme = scheme.with_store(QuantStore::Native);
// One `QParams` per scale returned by `into_vec_i8`.
let qparams = qparams
.scales
.into_iter()
.map(|scales| QParams { scales })
.collect();
NdArrayQTensor {
qtensor: NdArrayTensor::from_data(data),
scheme,
qparams,
}
}
QuantScheme {
value:
QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S
| QuantValue::E2M1
| QuantValue::E4M3
| QuantValue::E5M2,
..
} => unimplemented!("from_data not supported for scheme {scheme:?}"),
}
}
_ => panic!(
"Invalid dtype (expected DType::QFloat, got {:?})",
data.dtype
),
}
}
/// Quantizes a float tensor with the given scheme and scales.
///
/// Supported: symmetric schemes with `QuantStore::Native`, at tensor level (the first scale
/// is used) or block level (one scale per block). Without the `export_tests` feature only
/// `Q8F` / `Q8S` values are accepted.
///
/// # Panics
///
/// Panics (via `unimplemented!`) for any other scheme, and if the float data or the scales
/// cannot be viewed as an `f32` slice.
fn quantize(
tensor: FloatTensor<Self>,
scheme: &QuantScheme,
qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
let shape = tensor.shape();
let data_f = tensor.into_data();
let scales = qparams.scales.into_data().convert::<f32>();
// Implement with ndarray instead of QuantizationStrategy?
let (data, qparams) = match scheme {
QuantScheme {
level: QuantLevel::Tensor,
mode: QuantMode::Symmetric,
#[cfg(not(feature = "export_tests"))]
value: QuantValue::Q8F | QuantValue::Q8S,
// For tests, "native" sub-byte quant serves as a reference for value equality.
// Values are stored as i8 regardless.
#[cfg(feature = "export_tests")]
value:
QuantValue::Q8F
| QuantValue::Q8S
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S,
store: QuantStore::Native,
..
} => {
// Per-tensor: a single scale applies to every element.
let scales = scales.iter().next().unwrap();
let strategy = QuantizationStrategy::PerTensorSymmetric(
SymmetricQuantization::init(scales, scheme.value),
);
let values = strategy.quantize(data_f.as_slice().unwrap());
(
TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
vec![QParams { scales }],
)
}
QuantScheme {
level: QuantLevel::Block(block_size),
mode: QuantMode::Symmetric,
#[cfg(not(feature = "export_tests"))]
value: QuantValue::Q8F | QuantValue::Q8S,
#[cfg(feature = "export_tests")]
value:
QuantValue::Q8F
| QuantValue::Q8S
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S,
store: QuantStore::Native,
..
} => {
// Per-block: build one quantizer and one `QParams` per scale.
let scales = scales.as_slice().unwrap();
let (strategy, qparams) = scales
.iter()
.map(|&s| {
(
SymmetricQuantization::init(s, scheme.value),
QParams { scales: s },
)
})
.unzip();
let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
let values = strategy.quantize(data_f.as_slice().unwrap());
(
TensorData::quantized(values, shape.clone(), *scheme, scales),
qparams,
)
}
scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
};
// Round-trip through `QuantizedBytes` to recover the plain i8 values, then store them
// with element type `Q`.
let num_elements = data.num_elements();
let q_bytes = QuantizedBytes {
bytes: data.into_bytes(),
scheme: *scheme,
num_elements,
};
let (values, _) = q_bytes.into_vec_i8();
let data = TensorData::new(values, shape).convert::<Q>();
NdArrayQTensor {
qtensor: NdArrayTensor::from_data(data),
scheme: *scheme,
qparams,
}
}
/// Converts a quantized tensor back to a float tensor.
///
/// The stored values must be in the `I8` variant; they are collected into a vector and
/// handed, together with the tensor's shape, scheme and strategy, to the free `dequantize`
/// helper.
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
    let strategy = tensor.strategy();
    let scheme = tensor.scheme;
    let shape = tensor.shape();

    let NdArrayTensor::I8(storage) = tensor.qtensor else {
        unreachable!()
    };
    let values = storage.into_shared().into_iter().collect();

    NdArrayTensor::from_data(dequantize(values, shape, scheme, &strategy))
}
/// Returns the device of a quantized tensor; this is always `NdArrayDevice::Cpu`.
fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
NdArrayDevice::Cpu
}
/// Moves a quantized tensor to `_device`; the tensor is returned unchanged and the device
/// argument is ignored.
fn q_to_device(
tensor: QuantizedTensor<Self>,
_device: &NdArrayDevice,
) -> QuantizedTensor<Self> {
tensor
}
/// Reshapes the stored quantized values to `shape`; `scheme` and `qparams` are carried over
/// unchanged.
// NOTE(review): for block-level schemes, confirm the unchanged `qparams` still line up with
// the values after a reshape.
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
NdArrayQTensor {
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayOps::reshape(array, shape)
}),
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Extracts the quantized tensor as `TensorData`.
///
/// The stored values are collected in iteration order and packed with
/// `TensorData::quantized`, using the tensor's scheme and one scale per entry of `qparams`.
/// This implementation always returns `Ok`.
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
let shape = tensor.qtensor.shape();
let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
Ok(execute_with_numeric_dtype!(
tensor.qtensor,
E,
|array: SharedArray<E>| {
let values = array.into_iter().collect();
TensorData::quantized(values, shape, tensor.scheme, &scales)
}
))
}
/// Swaps axes `dim1` and `dim2` of the stored quantized values; `scheme` and `qparams` are
/// carried over unchanged.
// NOTE(review): for block-level schemes, confirm the unchanged `qparams` still line up with
// the values after the axes are swapped.
fn q_swap_dims(
tensor: QuantizedTensor<Self>,
dim1: usize,
dim2: usize,
) -> QuantizedTensor<Self> {
NdArrayQTensor {
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayOps::swap_dims(array, dim1, dim2)
}),
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Permutes the axes of the stored quantized values according to `axes`; `scheme` and
/// `qparams` are carried over unchanged.
fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
NdArrayQTensor {
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayOps::permute(array, axes)
}),
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Flips the stored quantized values along `axes`; `scheme` and `qparams` are carried over
/// unchanged.
fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
NdArrayQTensor {
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayOps::flip(array, axes)
}),
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Gathers stored quantized values along `dim` using `indices`.
///
/// The outer macro dispatches on the int dtype of `indices`, the inner one on the dtype of
/// the stored values. `scheme` and `qparams` are carried over unchanged.
// NOTE(review): for block-level schemes, confirm the unchanged `qparams` remain valid for the
// gathered values.
fn q_gather(
dim: usize,
tensor: QuantizedTensor<Self>,
indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
IntElem,
>|
-> NdArrayTensor {
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayOps::gather(dim, array, idx_array)
})
});
NdArrayQTensor {
qtensor,
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Selects stored quantized values along `dim` at the positions given by `indices`.
///
/// The outer macro dispatches on the int dtype of `indices`, the inner one on the dtype of
/// the stored values. `scheme` and `qparams` are carried over unchanged.
// NOTE(review): for block-level schemes, confirm the unchanged `qparams` remain valid for the
// selected values.
fn q_select(
tensor: QuantizedTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
IntElem,
>|
-> NdArrayTensor {
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayMathOps::select(array, dim, idx_array)
})
});
NdArrayQTensor {
qtensor,
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Slices the stored quantized values with `slices` (via the `slice!` macro); `scheme` and
/// `qparams` are carried over unchanged.
// NOTE(review): for block-level schemes, confirm the unchanged `qparams` remain valid for a
// sub-region of the values.
fn q_slice(
tensor: QuantizedTensor<Self>,
slices: &[burn_backend::Slice],
) -> QuantizedTensor<Self> {
NdArrayQTensor {
qtensor: slice!(tensor.qtensor, slices),
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
/// Argmax along `dim`, computed directly on the stored quantized values (no dequantization)
/// and returned with int element type `I`.
fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayMathOps::argmax::<I>(array, dim)
})
}
/// Argmin along `dim`, computed directly on the stored quantized values (no dequantization)
/// and returned with int element type `I`.
fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayMathOps::argmin::<I>(array, dim)
})
}
/// Expands (broadcasts) the stored quantized values to `shape`; `scheme` and `qparams` are
/// carried over unchanged.
fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
NdArrayQTensor {
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
NdArrayOps::expand(array, shape)
}),
scheme: tensor.scheme,
qparams: tensor.qparams,
}
}
}
/// Dequantizes raw quantized values into float `TensorData` of the given `shape`.
///
/// The scales are read back from `strategy` (one for per-tensor, one per block otherwise) and
/// used to wrap `data` in `QuantizedBytes`, from which the values are unpacked as `i8` and
/// then dequantized by `strategy`.
fn dequantize<Q: QuantElement>(
    data: Vec<Q>,
    shape: Shape,
    scheme: QuantScheme,
    strategy: &QuantizationStrategy,
) -> TensorData {
    let scales: Vec<f32> = match strategy {
        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
        QuantizationStrategy::PerBlockSymmetric(blocks, _) => {
            blocks.iter().map(|block| block.scale).collect()
        }
    };

    // Only the unpacked values are needed; the scales returned alongside are discarded.
    let (unpacked, _) = QuantizedBytes::new(data, scheme, &scales).into_vec_i8();
    let floats = strategy.dequantize(&unpacked);

    TensorData::new(floats, shape)
}

View File

@@ -0,0 +1,218 @@
use alloc::vec::Vec;
use num_traits::{Float, PrimInt};
use burn_backend::quantization::{BlockSize, QuantValue};
// NOTE: this mainly serves as a simple reference implementation.
// The de/quantization ops should be refactored to use ndarray.
/// Quantization strategy.
///
/// Bundles the symmetric quantizer(s) used to convert between `f32` values and `i8` values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantizationStrategy {
/// Per-tensor symmetric quantization: a single quantizer for all values.
PerTensorSymmetric(SymmetricQuantization<f32>),
/// Per-block symmetric quantization: one quantizer per consecutive block of
/// `BlockSize::num_elements()` values.
PerBlockSymmetric(Vec<SymmetricQuantization<f32>>, BlockSize),
}
impl QuantizationStrategy {
    /// Quantize the values to a lower precision data type.
    ///
    /// For the per-block variant, `values` is split into consecutive blocks of
    /// `block_size.num_elements()` values and block `i` is quantized with quantizer `i`.
    ///
    /// # Panics
    ///
    /// For the per-block variant, panics if `values` does not consist of exactly one full
    /// block per quantizer, or if the block size is zero.
    pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
        match self {
            QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values),
            QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
                let block_elems = block_size.num_elements();
                let num_blocks = strategy.len();
                let numel = values.len();
                assert_eq!(
                    numel / block_elems,
                    num_blocks,
                    "Invalid per-block quantization with num blocks {num_blocks} and {numel} values"
                );
                // The check above uses integer division, so a trailing partial block would
                // slip through and have no quantizer to pair with. Reject it explicitly.
                assert_eq!(
                    numel % block_elems,
                    0,
                    "Invalid per-block quantization with num blocks {num_blocks} and {numel} values"
                );
                values
                    .chunks_exact(block_elems)
                    .zip(strategy)
                    .flat_map(|(block, quant)| quant.quantize(block))
                    .collect()
            }
        }
    }

    /// Dequantize the values to a higher precision data type.
    ///
    /// For the per-block variant, `values` is split into consecutive blocks of
    /// `block_size.num_elements()` values and block `i` is dequantized with quantizer `i`.
    ///
    /// # Panics
    ///
    /// For the per-block variant, panics if `values` does not consist of exactly one full
    /// block per quantizer, or if the block size is zero.
    pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
        match self {
            QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values),
            QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
                let block_elems = block_size.num_elements();
                let num_blocks = strategy.len();
                let numel = values.len();
                assert_eq!(
                    numel / block_elems,
                    num_blocks,
                    "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values"
                );
                // The check above uses integer division, so a trailing partial block would
                // slip through and have no quantizer to pair with. Reject it explicitly.
                assert_eq!(
                    numel % block_elems,
                    0,
                    "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values"
                );
                values
                    .chunks_exact(block_elems)
                    .zip(strategy)
                    .flat_map(|(block, quant)| quant.dequantize(block))
                    .collect()
            }
        }
    }
}
/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
/// data type `Q` and vice-versa.
///
/// `E` is fixed per implementation; `Q` is chosen per call through the generic parameter of
/// each conversion method.
pub trait Quantization<E: Float + Send + Sync> {
/// Returns the quantization range `[a, b]`.
fn range(&self) -> (E, E);
/// Convert the values to a lower precision data type.
fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q>;
/// Convert a single value to a lower precision data type.
fn quantize_one<Q: PrimInt>(&self, value: E) -> Q;
/// Convert the values back to a higher precision data type.
fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E>;
/// Convert a single value back to a higher precision data type.
fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E;
}
/// Returns `scale` unless it is zero, in which case `0.1` is returned instead.
///
/// A zero scale most likely comes from a tensor full of zeros; the replacement value is
/// arbitrary and only exists to avoid a division by zero during quantization.
fn valid_scale<E: Float>(scale: E) -> E {
    if scale == E::zero() {
        E::from(0.1).unwrap()
    } else {
        scale
    }
}
/// Symmetric quantization scheme.
///
/// Values are quantized as `clamp(round(x / scale), a, b)` and dequantized as `scale * x_q`,
/// where `[a, b]` is the range of the stored quantization value type.
#[derive(Debug, Clone, Copy)]
pub struct SymmetricQuantization<E: Float + Send + Sync> {
/// The scaling factor.
pub scale: E,
/// The quantization value data type; its `range()` supplies the clamp bounds.
value: QuantValue,
}
impl<E: Float + Send + Sync> SymmetricQuantization<E> {
    /// Initialize a symmetric quantization scheme with the given parameters.
    ///
    /// A zero `scale` is replaced by the fallback chosen in `valid_scale`.
    pub fn init(scale: E, value: QuantValue) -> Self {
        let scale = valid_scale(scale);
        Self { scale, value }
    }

    /// Create a new quantization scheme for an input range `[alpha, beta]`.
    ///
    /// The input range is made symmetric around zero using the larger magnitude of the two
    /// bounds, and the scale maps that symmetric range onto the quantized range of `value`.
    #[allow(dead_code)]
    fn new(alpha: E, beta: E, value: QuantValue) -> Self {
        let (q_min, q_max) = value.range();
        let q_min = E::from(q_min).unwrap();
        let q_max = E::from(q_max).unwrap();

        // Half-width of the symmetric input range `[-bound, bound]`.
        let bound = alpha.abs().max(beta.abs());

        // `init` applies the same zero-scale guard the scale needs here.
        Self::init((bound + bound) / (q_max - q_min), value)
    }
}
impl<E: Float + Send + Sync> Quantization<E> for SymmetricQuantization<E> {
    /// Quantizes every value with [`Self::quantize_one`], preserving order.
    fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q> {
        values
            .iter()
            .copied()
            .map(|x| self.quantize_one(x))
            .collect()
    }

    /// Dequantizes every value with [`Self::dequantize_one`], preserving order.
    fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E> {
        values
            .iter()
            .copied()
            .map(|x_q| self.dequantize_one(x_q))
            .collect()
    }

    /// Computes `x_q = clamp(round(x / scale), a, b)`.
    ///
    /// Panics if the clamped value cannot be represented by `Q`.
    fn quantize_one<Q: PrimInt>(&self, value: E) -> Q {
        let (lower, upper) = self.range();
        let rounded = (value / self.scale).round();
        Q::from(rounded.clamp(lower, upper)).unwrap()
    }

    /// Computes `x = scale * x_q`.
    ///
    /// Panics if `value` cannot be converted to `E`.
    fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E {
        let x_q = E::from(value).unwrap();
        self.scale * x_q
    }

    /// Returns the range of the quantized value type, converted to `E`.
    fn range(&self) -> (E, E) {
        let (lower, upper) = self.value.range();
        (E::from(lower).unwrap(), E::from(upper).unwrap())
    }
}
/// Two schemes are equal when they share both the scaling factor and the quantized value type.
///
/// Comparing the scale alone would make schemes that target different value types (and thus
/// different quantization ranges) compare equal even though they quantize differently.
impl<E: Float + Send + Sync> PartialEq for SymmetricQuantization<E> {
    fn eq(&self, other: &Self) -> bool {
        self.scale == other.scale && self.value == other.value
    }
}

// NOTE(review): `Eq` on a float scale is not reflexive for a NaN scale; kept because callers
// presumably rely on the marker trait.
impl<E: Float + Send + Sync> Eq for SymmetricQuantization<E> {}
#[cfg(test)]
mod tests {
    use burn_backend::TensorData;

    use super::*;
    use alloc::vec;

    /// Round-trips a small tensor through a per-tensor symmetric int8 scheme.
    #[test]
    fn test_int8_symmetric_quantization() {
        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35];
        let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);

        let q: Vec<i8> = symmetric.quantize(&x);
        assert_eq!(q, expected_q);

        let d = symmetric.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    /// Round-trips two identical blocks through the per-block strategy.
    #[test]
    fn test_int8_symmetric_quantization_per_block() {
        let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35];
        let expected_d = vec![
            -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063,
        ];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);
        let strategy = QuantizationStrategy::PerBlockSymmetric(
            vec![symmetric, symmetric],
            BlockSize::new([4]),
        );

        let q: Vec<i8> = strategy.quantize(&x);
        assert_eq!(q, expected_q);

        // Dequantize through the strategy (not the single scheme) so the per-block
        // dequantization path is the one under test.
        let d = strategy.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    /// Dequantizes int8 values with a fixed scale through the per-tensor strategy.
    #[test]
    fn should_support_dequantize() {
        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization {
            scale: 0.1,
            value: QuantValue::Q8S,
        });

        let output = strategy.dequantize(&[-127i8, -77, -26, 25, 76, 127]);

        let output = TensorData::new(output, [2, 3]);

        output.assert_approx_eq::<f32>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Default::default(),
        );
    }
}

View File

@@ -0,0 +1,443 @@
use core::{marker::PhantomData, mem::transmute};
use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};
use burn_backend::DType;
use burn_backend::{Element, ElementConversion};
use bytemuck::Zeroable;
use macerator::{Simd, VAdd, VDiv};
use ndarray::{Array4, s};
use nhwc::avg_pool_nhwc;
use super::should_use_simd;
/// Returns `true` when both vector addition and vector division are reported as accelerated for
/// `T` on the SIMD backend `S`.
///
/// The `PhantomData` argument only carries the element type `T`.
#[macerator::with_simd]
fn is_accelerated<S: Simd, T: VAdd + VDiv>(_x: PhantomData<T>) -> bool {
<T as VAdd>::is_accelerated::<S>() && <T as VDiv>::is_accelerated::<S>()
}
/// Attempts to run 2D average pooling through the SIMD kernel.
///
/// Returns `Ok` with the pooled tensor when the fast path applies. Otherwise the untouched input
/// is handed back as `Err` so the caller can fall back to another implementation.
///
/// The fast path requires a unit stride along axis 1, a size along axis 1 accepted by
/// `should_use_simd`, and an `f32`/`f64` element type whose add and divide are accelerated.
pub(crate) fn try_avg_pool2d_simd<E: Element>(
    x: SharedArray<E>,
    ksize: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    with_pad: bool,
) -> Result<SharedArray<E>, SharedArray<E>> {
    // Strides must be unit, dilation isn't supported, rows must be contiguous
    let unit_stride = x.strides()[1] == 1;
    if !(unit_stride && should_use_simd(x.shape()[1])) {
        return Err(x);
    }

    let dtype = E::dtype();
    if matches!(dtype, DType::F64) && is_accelerated::<f64>(PhantomData) {
        let pooled = avg_pool_nhwc::<f64>(cast(x), ksize, stride, padding, with_pad);
        return Ok(cast(pooled));
    }
    if matches!(dtype, DType::F32) && is_accelerated::<f32>(PhantomData) {
        let pooled = avg_pool_nhwc::<f32>(cast(x), ksize, stride, padding, with_pad);
        return Ok(cast(pooled));
    }
    Err(x)
}
/// Reinterprets a tensor of `T` as a tensor of `E` without converting any element.
///
/// NOTE(review): this is a plain `transmute`, so it is only sound when `T` and `E` are the same
/// type (or at least layout-identical). The caller visible in this file only reaches it after
/// matching `E::dtype()` against the concrete float type — confirm for any new call site.
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
// SAFETY: relies on the caller guaranteeing that `T` and `E` are the same element type.
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
}
mod nhwc {
use itertools::Itertools;
use macerator::{Simd, Vector, vload_unaligned, vstore_unaligned};
use ndarray::{ArrayView3, ArrayViewMut3};
use seq_macro::seq;
use crate::ops::simd::lanes;
use super::*;
// Until you can use associated constants as array size, we need to hardcode this.
// The most common config (x86-v3) has 16 registers, so use half of them for accumulators.
const BLOCK_REGISTERS: usize = 8;
/// Average pooling over a 4D input, computed on a channels-last view of the data.
///
/// The input is read as `[batch, channels, height, width]` and viewed with axes permuted to
/// `[batch, height, width, channels]`. The output is allocated in that channels-last order and
/// permuted back to `[batch, channels, height, width]` before being returned.
///
/// Channels are split into three disjoint ranges, each handled by its own loop:
/// blocks of `lanes * BLOCK_REGISTERS` channels (`loop_blocked`), single vectors of `lanes`
/// channels (`loop_unblocked`) and the remaining channels one at a time (`loop_scalar`).
///
/// When `with_pad` is `true`, border pixels are divided by the full kernel area rather than by
/// the number of in-bounds input pixels.
pub(crate) fn avg_pool_nhwc<E: Element + VAdd + VDiv>(
x: SharedArray<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
with_pad: bool,
) -> SharedArray<E> {
let [kernel_height, kernel_width] = kernel_size;
let [pad_h, pad_w] = padding;
let [stride_height, stride_width] = stride;
// Panics if the input is not 4-dimensional.
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
let lanes = lanes::<E>();
let ch_block = lanes * BLOCK_REGISTERS;
// NOTE(review): these subtractions are on `usize`; a kernel larger than the padded input would
// underflow — presumably validated by the caller, confirm.
let out_height = ((x_height + 2 * pad_h - (kernel_height - 1) - 1) / stride_height) + 1;
let out_width = ((x_width + 2 * pad_w - (kernel_width - 1) - 1) / stride_width) + 1;
// The output starts uninitialized; the three loops below are expected to write every element
// before it is read.
let mut output = unsafe {
Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()
};
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let x = x.view();
// View the input as `[batch, height, width, channels]`.
let x = x.permuted_axes(vec![0, 2, 3, 1]);
// Floor division ensures `blocks * lanes * blocking factor` is always `<= channels`.
// An exclusive loop will always have `lanes * blocking factor` elements in bounds.
let blocks = channels / ch_block;
let blocks_end = blocks * ch_block;
// Floor division means simd_end is always divisible by `lanes` and `<= channels`. An
// exclusive loop will always have `lanes` elements in bounds.
let simd_end = channels / lanes * lanes;
let num_simd_unblocked = (simd_end - blocks_end) / lanes;
let remainder = channels - simd_end;
run_par!(|| {
// Each work item `k` encodes a (batch, channel range) pair. When the per-batch count is zero
// the range is empty, so the `%` and `/` below are never evaluated with a zero divisor.
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {
let block = k % blocks;
let b = k / blocks;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_blocked(x, out, kernel_size, stride, padding, with_pad, block);
});
// SAFETY: See `loop_unblocked`
iter_range_par!(0, batch_size * num_simd_unblocked).for_each(|k| unsafe {
let ch = (k % num_simd_unblocked) * lanes + blocks_end;
let b = k / num_simd_unblocked;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_unblocked(x, out, kernel_size, stride, padding, with_pad, ch);
});
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {
let ch = (k % remainder) + simd_end;
let b = k / remainder;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_scalar(x, out, kernel_size, stride, padding, with_pad, ch);
});
});
// Back to `[batch, channels, height, width]` axis order (a view permutation, no copy here).
output = output.permuted_axes([0, 3, 1, 2]);
output.into_dyn().into_shared()
}
/// Execute the blocked (unrolled) portion of the pool.
///
/// Computes the average pool of one batch element for the `lanes * BLOCK_REGISTERS` channels
/// of channel block `block`. Eight vector accumulators are kept per output pixel, matching
/// `BLOCK_REGISTERS`.
///
/// * `x` - input of one batch element, indexed as `[height, width, channel]`.
/// * `out` - output of one batch element, same axis order.
/// * `with_pad` - when `true`, border pixels divide by the full kernel area instead of the
///   number of in-bounds input pixels.
///
/// Output pixels at least `padding` away from every edge never read outside the input and are
/// handled without bounds checks on the input coordinates; all other pixels go through the
/// checked border loop.
#[allow(
    clippy::too_many_arguments,
    clippy::erasing_op,
    clippy::identity_op,
    unused_mut
)]
#[macerator::with_simd]
fn loop_blocked<'a, S: Simd, E: Element + VAdd + VDiv>(
    x: ArrayView3<'a, E>,
    mut out: ArrayViewMut3<'a, E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    with_pad: bool,
    block: usize,
) where
    'a: 'a,
{
    let [kernel_height, kernel_width] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_height, stride_width] = stride;
    let (x_height, x_width, _) = x.dim();
    let (out_height, out_width, _) = out.dim();
    let lanes = E::lanes::<S>();
    let ch_block = lanes * BLOCK_REGISTERS;

    // If pixels are more than `padding` from the edges, the in pixel cannot be out of bounds
    for oh in pad_h..out_height.saturating_sub(pad_h) {
        for ow in pad_w..out_width.saturating_sub(pad_w) {
            seq!(N in 0..8 {
                let mut sum~N: Vector<S, E> = Zeroable::zeroed();
            });
            let ch = block * ch_block;
            let ch_end = ch + ch_block;
            let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw - pad_w;
                    let x = x.slice(s![ih, iw, ch..ch_end]);

                    seq!(N in 0..8 {
                        // SAFETY:
                        // Load a full vector from x[N * lanes]. This is bounds checked by the
                        // slice above.
                        sum~N += unsafe { vload_unaligned(&x[N * lanes]) };
                    });
                }
            }

            let count = kernel_height * kernel_width;
            let count = (count as u64).elem::<E>();
            let count_v = count.splat();

            seq!(N in 0..8 {
                let s~N = sum~N / count_v;
                // SAFETY:
                // Store a full vector to out[N * lanes]. This is bounds checked by the
                // slice above.
                unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
            });
        }
    }

    // Border pixels need bounds checks
    if (pad_h, pad_w) != (0, 0) {
        // The leading border is clamped to the output size: when the output is smaller than
        // the padding, `0..pad` would otherwise index rows/columns that do not exist.
        let v_borders = (0..core::cmp::min(pad_h, out_height))
            .chain(out_height.saturating_sub(pad_h)..out_height)
            .cartesian_product(0..out_width);
        let h_borders = (0..out_height).cartesian_product(
            (0..core::cmp::min(pad_w, out_width))
                .chain(out_width.saturating_sub(pad_w)..out_width),
        );

        // Corner pixels appear in both iterators and are simply computed twice.
        for (oh, ow) in v_borders.chain(h_borders) {
            seq!(N in 0..8 {
                let mut sum~N: Vector<S, E> = Zeroable::zeroed();
            });
            let mut count: usize = 0;
            let ch = block * ch_block;
            let ch_end = ch + ch_block;
            let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh;
                // Skip kernel rows that fall into the padding.
                if ih < pad_h || ih >= x_height + pad_h {
                    continue;
                }
                let ih = ih - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw;
                    // Skip kernel columns that fall into the padding.
                    if iw < pad_w || iw >= x_width + pad_w {
                        continue;
                    }
                    let iw = iw - pad_w;
                    count += 1;

                    let x = x.slice(s![ih, iw, ch..ch_end]);

                    seq!(N in 0..8 {
                        // SAFETY:
                        // Load a full vector from x[N * lanes]. This is bounds checked by the
                        // slice above.
                        sum~N += unsafe { vload_unaligned(&x[N * lanes]) };
                    });
                }
            }

            // With `with_pad`, padded positions count towards the divisor.
            if with_pad {
                count = kernel_height * kernel_width;
            }

            let count = (count as u64).elem::<E>();
            let count_v = count.splat();

            seq!(N in 0..8 {
                let s~N = sum~N / count_v;
                // SAFETY:
                // Store a full vector to out[N * lanes]. This is bounds checked by the
                // slice above.
                unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
            });
        }
    }
}
/// Execute the unblocked (not unrolled) portion of the pool.
///
/// Computes the average pool of one batch element for the single vector of channels starting
/// at `ch`.
///
/// * `x` - input of one batch element, indexed as `[height, width, channel]`.
/// * `out` - output of one batch element, same axis order.
/// * `with_pad` - when `true`, border pixels divide by the full kernel area instead of the
///   number of in-bounds input pixels.
///
/// # Safety
///
/// `ch + simd_lanes <= channels` must hold, because a full vector is loaded from and stored
/// to the element at channel `ch`. The loads and stores also assume that consecutive channels
/// are adjacent in memory.
#[allow(clippy::too_many_arguments, unused_mut)]
#[macerator::with_simd]
unsafe fn loop_unblocked<'a, S: Simd, E: Element + VAdd + VDiv>(
    x: ArrayView3<'a, E>,
    mut out: ArrayViewMut3<'a, E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    with_pad: bool,
    ch: usize,
) where
    'a: 'a,
{
    let [kernel_height, kernel_width] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_height, stride_width] = stride;
    let (x_height, x_width, _) = x.dim();
    let (out_height, out_width, _) = out.dim();

    // If pixels are not within padding range, bounds checks are always true.
    // `saturating_sub` keeps the range empty (instead of underflowing) when the output is
    // smaller than the padding, consistent with the blocked and scalar loops.
    for oh in pad_h..out_height.saturating_sub(pad_h) {
        for ow in pad_w..out_width.saturating_sub(pad_w) {
            let mut sum: Vector<S, E> = Zeroable::zeroed();

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw - pad_w;
                    // Load a full vector from `x`. In bounds as long as `channels >= ch + lanes`
                    let s0 = unsafe { vload_unaligned(&x[[ih, iw, ch]]) };
                    sum += s0;
                }
            }

            let count = kernel_height * kernel_width;
            let count: E = (count as u64).elem();
            let count_v = count.splat();
            let s0 = sum / count_v;
            // Store a full vector to `out`. In bounds as long as `channels >= ch + lanes`.
            unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) };
        }
    }

    // Border pixels need bounds checks
    if (pad_h, pad_w) != (0, 0) {
        // The leading border is clamped to the output size: when the output is smaller than
        // the padding, `0..pad` would otherwise index rows/columns that do not exist.
        let v_borders = (0..core::cmp::min(pad_h, out_height))
            .chain(out_height.saturating_sub(pad_h)..out_height)
            .cartesian_product(0..out_width);
        let h_borders = (0..out_height).cartesian_product(
            (0..core::cmp::min(pad_w, out_width))
                .chain(out_width.saturating_sub(pad_w)..out_width),
        );

        // Corner pixels appear in both iterators and are simply computed twice.
        for (oh, ow) in v_borders.chain(h_borders) {
            let mut sum: Vector<S, E> = Zeroable::zeroed();
            let mut count: usize = 0;

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh;
                // Skip kernel rows that fall into the padding.
                if ih < pad_h || ih >= x_height + pad_h {
                    continue;
                }
                let ih = ih - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw;
                    // Skip kernel columns that fall into the padding.
                    if iw < pad_w || iw >= x_width + pad_w {
                        continue;
                    }
                    let iw = iw - pad_w;
                    count += 1;

                    // Load a full vector from `x`. In bounds as long as `channels >= ch + lanes`
                    sum += unsafe { vload_unaligned(&x[[ih, iw, ch]]) };
                }
            }

            // With `with_pad`, padded positions count towards the divisor.
            if with_pad {
                count = kernel_height * kernel_width;
            }

            let count = (count as u64).elem::<E>();
            let count_v = count.splat();
            let s0 = sum / count_v;
            // Store a full vector to `out`. In bounds as long as `channels >= ch + lanes`.
            unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) };
        }
    }
}
/// Execute scalar portion of the pooling
///
/// Computes the average pool of one batch element for the single channel `ch`, one element at
/// a time.
///
/// * `x` - input of one batch element, indexed as `[height, width, channel]`.
/// * `out` - output of one batch element, same axis order.
/// * `with_pad` - when `true`, border pixels divide by the full kernel area instead of the
///   number of in-bounds input pixels.
#[allow(clippy::too_many_arguments)]
fn loop_scalar<E: Element + VAdd + VDiv>(
    x: ArrayView3<'_, E>,
    mut out: ArrayViewMut3<'_, E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    with_pad: bool,
    ch: usize,
) {
    let [kernel_height, kernel_width] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_height, stride_width] = stride;
    let (x_height, x_width, _) = x.dim();
    let (out_height, out_width, _) = out.dim();

    // If pixels are not within padding range, bounds checks are always true
    for oh in pad_h..out_height.saturating_sub(pad_h) {
        for ow in pad_w..out_width.saturating_sub(pad_w) {
            let mut sum: E = Zeroable::zeroed();

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw - pad_w;
                    sum = sum + x[[ih, iw, ch]];
                }
            }

            let count = (kernel_height * kernel_width) as u64;
            out[[oh, ow, ch]] = sum / count.elem();
        }
    }

    // Border pixels need bounds checks
    if (pad_h, pad_w) != (0, 0) {
        // The leading border is clamped to the output size: when the output is smaller than
        // the padding, `0..pad` would otherwise index rows/columns that do not exist.
        let v_borders = (0..core::cmp::min(pad_h, out_height))
            .chain(out_height.saturating_sub(pad_h)..out_height)
            .cartesian_product(0..out_width);
        let h_borders = (0..out_height).cartesian_product(
            (0..core::cmp::min(pad_w, out_width))
                .chain(out_width.saturating_sub(pad_w)..out_width),
        );

        // Corner pixels appear in both iterators and are simply computed twice.
        for (oh, ow) in v_borders.chain(h_borders) {
            let mut sum: E = Zeroable::zeroed();
            let mut count: usize = 0;

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh;
                // Skip kernel rows that fall into the padding.
                if ih < pad_h || ih >= x_height + pad_h {
                    continue;
                }
                let ih = ih - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw;
                    // Skip kernel columns that fall into the padding.
                    if iw < pad_w || iw >= x_width + pad_w {
                        continue;
                    }
                    let iw = iw - pad_w;
                    count += 1;

                    sum = sum + x[[ih, iw, ch]];
                }
            }

            // With `with_pad`, padded positions count towards the divisor.
            if with_pad {
                count = kernel_height * kernel_width;
            }

            out[[oh, ow, ch]] = sum / (count as u64).elem();
        }
    }
}
}

View File

@@ -0,0 +1,115 @@
use core::{marker::PhantomData, mem::MaybeUninit};
use macerator::{Arch, Scalar, Simd};
use ndarray::{ArcArray, ArrayD, IxDyn, ShapeBuilder};
/// Whether SIMD instructions are worth using
///
/// * In test builds this always returns `true`, whatever `len` is.
/// * On `x86`, `x86_64`, `aarch64`, `wasm32` and `loongarch64` it returns `true` for inputs of
///   at least 32 elements.
/// * On every other architecture it returns `false`.
pub fn should_use_simd(len: usize) -> bool {
    if cfg!(test) {
        return true;
    }
    let simd_arch = cfg!(any(
        target_arch = "x86",
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "wasm32",
        target_arch = "loongarch64"
    ));
    simd_arch && len >= 32
}
/// Returns the number of SIMD lanes for element type `E` on the SIMD backend selected by
/// `Arch::new().dispatch(..)`.
///
/// The dispatch is written out by hand: a local `WithSimd` implementor carries the element type
/// and forwards to `lanes_simd` with the concrete backend.
/// NOTE(review): this looks like a manual expansion of `#[macerator::with_simd]` — confirm.
pub(crate) fn lanes<E: Scalar>() -> usize {
#[allow(non_camel_case_types)]
struct lanes<__T0>(__T0);
impl<E: Scalar> ::macerator::WithSimd for lanes<PhantomData<E>> {
type Output = usize;
#[inline(always)]
fn with_simd<__S: ::macerator::Simd>(self) -> <Self as ::macerator::WithSimd>::Output {
let Self(__ty) = self;
#[allow(unused_unsafe)]
unsafe {
lanes_simd::<__S, E>(__ty)
}
}
}
(Arch::new()).dispatch(lanes(PhantomData::<E>))
}
/// Returns the number of lanes of a vector of `E` on the SIMD backend `S`.
///
/// The `PhantomData` argument only carries the element type.
fn lanes_simd<S: Simd, E: Scalar>(_ty: PhantomData<E>) -> usize {
E::lanes::<S>()
}
/// Allocates an array of `Out` with the same shape and strides as `reference`, without
/// initializing its elements.
///
/// The caller is expected to write every element before reading any of them.
pub(crate) fn uninit_array_like<In, Out>(reference: &ArcArray<In, IxDyn>) -> ArrayD<Out> {
let shape = reference.raw_dim();
let strides = reference.strides();
// NOTE(review): strides are signed in ndarray; this cast would wrap for negative strides —
// presumably callers only pass arrays with non-negative strides, confirm.
let strides = strides.iter().map(|it| *it as usize).collect::<Vec<_>>();
let shape_strides = shape.strides(IxDyn(&strides));
let size = reference.len();
let mut out_data: Vec<MaybeUninit<Out>> = Vec::with_capacity(size);
// SAFETY: the capacity is exactly `size` and `MaybeUninit` elements need no initialization.
unsafe { out_data.set_len(size) };
// SAFETY (as relied upon by this function): the buffer holds `reference.len()` elements and
// the shape/strides are copied from `reference`. `assume_init` exposes uninitialized memory
// as `Out`, so correctness depends on callers overwriting every element.
unsafe { ArrayD::from_shape_vec_unchecked(shape_strides, out_data).assume_init() }
}
/// Scalar minimum and maximum, implemented uniformly for the integer and float primitives
/// listed below.
pub trait MinMax {
    /// Returns the smaller of `self` and `other`.
    fn min(self, other: Self) -> Self;
    /// Returns the larger of `self` and `other`.
    fn max(self, other: Self) -> Self;
}

// Implements `MinMax` for every listed type by forwarding to `Ord`.
macro_rules! impl_minmax {
    ($($ty: ty),*) => {
        $(
            impl MinMax for $ty {
                fn min(self, other: Self) -> Self {
                    Ord::min(self, other)
                }
                fn max(self, other: Self) -> Self {
                    Ord::max(self, other)
                }
            }
        )*
    };
}

impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64);

// Floats are not `Ord`, so they forward to the inherent `min`/`max` of the primitive.
impl MinMax for f32 {
    fn min(self, other: Self) -> Self {
        f32::min(self, other)
    }
    fn max(self, other: Self) -> Self {
        f32::max(self, other)
    }
}

impl MinMax for f64 {
    fn min(self, other: Self) -> Self {
        f64::min(self, other)
    }
    fn max(self, other: Self) -> Self {
        f64::max(self, other)
    }
}

View File

@@ -0,0 +1,299 @@
use core::{marker::PhantomData, slice};
use burn_backend::Element;
use macerator::{
Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload_unaligned,
vstore_unaligned,
};
use ndarray::ArrayD;
use seq_macro::seq;
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
use super::{
MinMax,
binary_elemwise::{
VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecDiv, VecMax, VecMin, VecMul, VecSub,
},
should_use_simd,
};
/// A binary element-wise operation on two operands of type `T` producing `Out`, available both
/// as a vector operation and as a scalar fallback.
pub trait SimdBinop<T: Scalar, Out: Scalar> {
/// Applies the operation to two full vectors.
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, Out>;
/// Applies the operation to a single pair of scalars.
fn apply(lhs: T, rhs: T) -> Out;
/// Whether the vector form is accelerated on the SIMD backend `S`.
fn is_accelerated<S: Simd>() -> bool;
}
/// Element-wise addition: `lhs + rhs`.
impl<T: VAdd> SimdBinop<T, T> for VecAdd {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs + rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs + rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VAdd>::is_accelerated::<S>()
}
}
/// Element-wise division: `lhs / rhs`.
impl<T: VDiv> SimdBinop<T, T> for VecDiv {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs / rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs / rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VDiv>::is_accelerated::<S>()
}
}
/// Element-wise multiplication: `lhs * rhs`.
impl<T: VMul> SimdBinop<T, T> for VecMul {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs * rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs * rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VMul>::is_accelerated::<S>()
}
}
/// Element-wise subtraction: `lhs - rhs`.
impl<T: VSub> SimdBinop<T, T> for VecSub {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs - rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs - rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VSub>::is_accelerated::<S>()
}
}
/// Element-wise minimum of the two operands.
impl<T: VOrd + MinMax> SimdBinop<T, T> for VecMin {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs.min(rhs)
}
fn apply(lhs: T, rhs: T) -> T {
// Scalar fallback goes through the `MinMax` helper trait.
MinMax::min(lhs, rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
}
/// Element-wise maximum of the two operands.
impl<T: VOrd + MinMax> SimdBinop<T, T> for VecMax {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs.max(rhs)
}
fn apply(lhs: T, rhs: T) -> T {
// Scalar fallback goes through the `MinMax` helper trait.
MinMax::max(lhs, rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
}
/// Element-wise bitwise AND: `lhs & rhs`.
impl<T: VBitAnd> SimdBinop<T, T> for VecBitAnd {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs & rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs.bitand(rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VBitAnd>::is_accelerated::<S>()
}
}
/// Element-wise bitwise OR: `lhs | rhs`.
impl<T: VBitOr> SimdBinop<T, T> for VecBitOr {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs | rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs.bitor(rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VBitOr>::is_accelerated::<S>()
}
}
/// Element-wise bitwise XOR: `lhs ^ rhs`.
impl<T: VBitXor> SimdBinop<T, T> for VecBitXor {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
lhs ^ rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs.bitxor(rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VBitXor>::is_accelerated::<S>()
}
}
/// Returns whether `Op` reports itself as accelerated on the SIMD backend `S`.
///
/// The `PhantomData` argument only carries the type parameters.
#[macerator::with_simd]
fn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: SimdBinop<T, Out>>(
_x: PhantomData<(T, Out, Op)>,
) -> bool {
Op::is_accelerated::<S>()
}
/// Attempts to apply the binary operation `Op` to two tensors through the SIMD kernel.
///
/// Returns `Ok` with the result when the fast path applies. Otherwise both inputs are handed
/// back untouched as `Err` so the caller can fall back to another implementation.
///
/// The fast path requires: a length accepted by `should_use_simd`, both operands in standard
/// layout, identical shapes (no broadcasting) and an accelerated `Op`.
///
/// `E`/`EOut` are the element types seen by the caller, `T`/`Out` the concrete types the kernel
/// runs on; the arrays are transmuted between them, so the caller must pick `T == E` and
/// `Out == EOut`.
#[allow(clippy::result_large_err)]
pub fn try_binary_simd<
    E: Element,
    EOut: Element,
    T: NdArrayElement + Scalar,
    Out: NdArrayElement + Scalar,
    Op: SimdBinop<T, Out>,
>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> Result<SharedArray<EOut>, (SharedArray<E>, SharedArray<E>)> {
    let largest_len = lhs.len().max(rhs.len());
    let simd_applicable = should_use_simd(largest_len)
        && lhs.is_standard_layout()
        && rhs.is_standard_layout()
        && lhs.shape() == rhs.shape()
        && is_accelerated::<T, Out, Op>(PhantomData);
    if !simd_applicable {
        return Err((lhs, rhs));
    }

    // Used to assert traits based on the dynamic `DType`.
    let typed_lhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs) };
    let typed_rhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs) };
    let result = binary_simd_same::<T, Out, Op>(typed_lhs, typed_rhs);

    // Used to assert traits based on the dynamic `DType`.
    Ok(unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(result) })
}
/// Applies `Op` to two same-shaped, standard-layout tensors, reusing an input buffer for the
/// output when one of the inputs is uniquely owned.
///
/// Buffer selection, in order: `lhs` if unique, else `rhs` if unique, else a fresh uninitialized
/// array shaped like `lhs`.
///
/// NOTE(review): the two in-place branches transmute `&mut [T]` to `&mut [Out]` and `ArrayD<T>`
/// to `ArrayD<Out>` without checking that `T` and `Out` have the same size and alignment. Every
/// `SimdBinop` implementation visible in this file uses `Out = T` — confirm before adding an
/// operation with a different output type.
fn binary_simd_same<
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: SimdBinop<T, Out>,
>(
lhs: SharedArray<T>,
rhs: SharedArray<T>,
) -> SharedArray<Out> {
let out = if lhs.is_unique() {
let mut buf = lhs.into_owned();
let lhs = buf.as_slice_mut().unwrap();
let rhs = rhs.as_slice().unwrap();
// `out` aliases `lhs`: the kernel reads each chunk before writing the same positions.
let out =
unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(lhs)) };
binary(lhs, rhs, out, PhantomData::<Op>);
unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buf) }
} else if rhs.is_unique() {
let mut buf = rhs.into_owned();
let lhs = lhs.as_slice().unwrap();
let rhs = buf.as_slice_mut().unwrap();
// `out` aliases `rhs`: the kernel reads each chunk before writing the same positions.
let out =
unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(rhs)) };
binary(lhs, rhs, out, PhantomData::<Op>);
unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buf) }
} else {
// Neither input can be reused: write into a fresh array with the layout of `lhs`.
let mut out = uninit_array_like(&lhs);
let lhs = lhs.as_slice().unwrap();
let rhs = rhs.as_slice().unwrap();
let out_slice = out.as_slice_mut().unwrap();
binary(lhs, rhs, out_slice, PhantomData::<Op>);
out
};
out.into_shared()
}
/// Applies `Op` element-wise to `lhs` and `rhs`, writing the results to `out`.
///
/// The slices are processed in three passes: chunks of `8 * lanes` elements (eight vectors,
/// unrolled), then single vectors of `lanes` elements, then the remaining elements one at a
/// time through the scalar `Op::apply`.
///
/// Iteration stops at the shortest of the three slices because the chunk iterators are zipped.
/// NOTE(review): `out` is chunked with the lane count of `T`; this presumably assumes `Out` has
/// the same lane count as `T` — confirm for operations where `Out != T`.
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn binary<
'a,
S: Simd,
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: SimdBinop<T, Out>,
>(
lhs: &'a [T],
rhs: &'a [T],
out: &'a mut [Out],
_op: PhantomData<Op>,
) where
'a: 'a,
{
let lanes = T::lanes::<S>();
// Pass 1: eight vectors per iteration.
let mut chunks_lhs = lhs.chunks_exact(8 * lanes);
let mut chunks_rhs = rhs.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
while let Some(((lhs, rhs), out)) = chunks_lhs
.next()
.zip(chunks_rhs.next())
.zip(chunks_out.next())
{
seq!(N in 0..8 {
// Load one full vector from `lhs`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };
// Load one full vector from `rhs`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };
let s~N = Op::apply_vec(lhs~N, rhs~N);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
});
}
// Pass 2: one vector per iteration over what pass 1 left behind.
let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);
let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some(((lhs, rhs), out)) = chunks_lhs
.next()
.zip(chunks_rhs.next())
.zip(chunks_out.next())
{
// Load one full vector from `lhs`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };
// Load one full vector from `rhs`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };
let s0 = Op::apply_vec(lhs0, rhs0);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };
}
// Pass 3: scalar tail, fewer than `lanes` elements.
for ((lhs, rhs), out) in chunks_lhs
.remainder()
.iter()
.zip(chunks_rhs.remainder())
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*lhs, *rhs)
}
}
/// Unsafely alias a slice to use as an inline argument
///
/// Returns a second mutable slice over the same memory as `slice`, with a lifetime chosen by the
/// caller rather than tied to the borrow of `slice`.
///
/// NOTE(review): although this function is not marked `unsafe`, the result aliases `slice` and
/// can outlive it; callers are responsible for not using both in a conflicting way and for not
/// keeping the result beyond the life of the underlying buffer.
fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {
let ptr = slice.as_mut_ptr();
let len = slice.len();
// SAFETY: `ptr` and `len` come from a valid slice; aliasing/lifetime are the caller's duty.
unsafe { slice::from_raw_parts_mut(ptr, len) }
}

View File

@@ -0,0 +1,419 @@
use core::marker::PhantomData;
use bytemuck::cast;
use macerator::{
Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload,
vload_unaligned, vstore, vstore_unaligned,
};
use ndarray::ArrayD;
use seq_macro::seq;
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
use super::{MinMax, should_use_simd};
/// A binary operation between every element of a tensor (`T`) and a fixed right-hand side,
/// producing `Out`, available both as a vector operation and as a scalar fallback.
pub trait ScalarSimdBinop<T: Scalar, Out: Scalar> {
/// Scalar form of the right-hand side (a single value, or a tuple for operations such as clamp).
type Rhs: Copy;
/// Vector form of the right-hand side on the SIMD backend `S`.
type RhsVec<S: Simd>: Copy;
/// Broadcasts the scalar right-hand side into its vector form.
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S>;
/// Applies the operation to one full vector of elements.
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, Out>;
/// Applies the operation to a single element.
fn apply(lhs: T, rhs: Self::Rhs) -> Out;
/// Whether the vector form is accelerated on the SIMD backend `S`.
fn is_accelerated<S: Simd>() -> bool;
}
/// Marker type selecting addition.
pub struct VecAdd;
/// Marker type selecting division.
pub struct VecDiv;
/// Marker type selecting multiplication.
pub struct VecMul;
/// Marker type selecting subtraction.
pub struct VecSub;
/// Marker type selecting the minimum.
pub struct VecMin;
/// Marker type selecting the maximum.
pub struct VecMax;
/// Marker type selecting clamping to a `(min, max)` pair.
pub struct VecClamp;
/// Marker type selecting bitwise AND.
pub struct VecBitAnd;
/// Marker type selecting bitwise OR.
pub struct VecBitOr;
/// Marker type selecting bitwise XOR.
pub struct VecBitXor;
/// Adds a scalar to every element: `lhs + rhs`.
impl<T: VAdd> ScalarSimdBinop<T, T> for VecAdd {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs + rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs + rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VAdd>::is_accelerated::<S>()
}
}
/// Divides every element by a scalar: `lhs / rhs`.
impl<T: VDiv> ScalarSimdBinop<T, T> for VecDiv {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs / rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs / rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VDiv>::is_accelerated::<S>()
}
}
/// Multiplies every element by a scalar: `lhs * rhs`.
impl<T: VMul> ScalarSimdBinop<T, T> for VecMul {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs * rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs * rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VMul>::is_accelerated::<S>()
}
}
/// Subtracts a scalar from every element: `lhs - rhs`.
impl<T: VSub> ScalarSimdBinop<T, T> for VecSub {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs - rhs
}
fn apply(lhs: T, rhs: T) -> T {
lhs - rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VSub>::is_accelerated::<S>()
}
}
/// Takes the minimum of every element and a scalar.
impl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecMin {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs.min(rhs)
}
fn apply(lhs: T, rhs: T) -> T {
lhs.min(rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
}
/// Takes the maximum of every element and a scalar.
impl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecMax {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs.max(rhs)
}
fn apply(lhs: T, rhs: T) -> T {
lhs.max(rhs)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
}
/// Clamps every element to a `(min, max)` pair.
///
/// The upper bound is applied first and the lower bound second, so if `min > max` the result
/// is `min`.
impl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecClamp {
type Rhs = (T, T);
type RhsVec<S: Simd> = (Vector<S, T>, Vector<S, T>);
fn splat<S: Simd>((min, max): Self::Rhs) -> Self::RhsVec<S> {
(min.splat(), max.splat())
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, (min, max): Self::RhsVec<S>) -> Vector<S, T> {
lhs.min(max).max(min)
}
fn apply(lhs: T, (min, max): Self::Rhs) -> T {
lhs.min(max).max(min)
}
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
}
/// Bitwise AND of every element with a scalar: `lhs & rhs`.
impl<T: VBitAnd> ScalarSimdBinop<T, T> for VecBitAnd {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs & rhs
}
fn apply(lhs: T, rhs: Self::Rhs) -> T {
lhs & rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VBitAnd>::is_accelerated::<S>()
}
}
/// Bitwise OR of every element with a scalar: `lhs | rhs`.
impl<T: VBitOr> ScalarSimdBinop<T, T> for VecBitOr {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs | rhs
}
fn apply(lhs: T, rhs: Self::Rhs) -> T {
lhs | rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VBitOr>::is_accelerated::<S>()
}
}
/// Bitwise XOR of every element with a scalar: `lhs ^ rhs`.
impl<T: VBitXor> ScalarSimdBinop<T, T> for VecBitXor {
type Rhs = T;
type RhsVec<S: Simd> = Vector<S, T>;
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
rhs.splat()
}
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
lhs ^ rhs
}
fn apply(lhs: T, rhs: Self::Rhs) -> T {
lhs ^ rhs
}
fn is_accelerated<S: Simd>() -> bool {
<T as VBitXor>::is_accelerated::<S>()
}
}
/// Returns whether `Op` reports itself as accelerated on the SIMD backend `S`.
///
/// The `PhantomData` argument only carries the type parameters.
#[macerator::with_simd]
fn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: ScalarSimdBinop<T, Out>>(
_x: PhantomData<(T, Out, Op)>,
) -> bool {
Op::is_accelerated::<S>()
}
/// Attempts to apply the tensor-scalar operation `Op` through the SIMD kernel.
///
/// Returns `Ok` with the result when the fast path applies. Otherwise the untouched input is
/// handed back as `Err` so the caller can fall back to another implementation.
///
/// The fast path requires a length accepted by `should_use_simd`, an input that is contiguous
/// in memory order and an accelerated `Op`. The input buffer is reused for the output when the
/// element sizes match, the alignment of `T` is at least that of `Out` and the input is uniquely
/// owned; otherwise a new output array is allocated.
///
/// `E`/`EOut` are the element types seen by the caller, `T`/`Out` the concrete types the kernel
/// runs on; the arrays are transmuted between them, so the caller must pick `T == E` and
/// `Out == EOut`.
pub fn try_binary_scalar_simd<
    E: NdArrayElement,
    EOut: NdArrayElement,
    T: NdArrayElement + Scalar,
    Out: NdArrayElement + Scalar,
    Op: ScalarSimdBinop<T, Out>,
>(
    input: SharedArray<E>,
    elem: Op::Rhs,
) -> Result<SharedArray<EOut>, SharedArray<E>> {
    let simd_applicable = should_use_simd(input.len())
        && input.as_slice_memory_order().is_some()
        && is_accelerated::<T, Out, Op>(PhantomData);
    if !simd_applicable {
        return Err(input);
    }

    // Used to assert traits based on the dynamic `DType`.
    let typed_input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };

    let can_reuse_buffer = size_of::<T>() == size_of::<Out>()
        && align_of::<T>() >= align_of::<Out>()
        && typed_input.is_unique();
    let result = if can_reuse_buffer {
        // SAFETY: size and alignment requirements were checked just above.
        unsafe { binary_scalar_simd_inplace::<T, Out, Op>(typed_input, elem) }
    } else {
        binary_scalar_simd_owned::<T, Out, Op>(typed_input, elem)
    };

    // Used to assert traits based on the dynamic `DType`.
    Ok(unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(result) })
}
/// Execute operation in place on an owned tensor
///
/// The input is converted to an owned array, the operation overwrites its buffer element by
/// element, and the buffer is then reinterpreted as an array of `Out`.
///
/// # Safety
///
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
unsafe fn binary_scalar_simd_inplace<
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: ScalarSimdBinop<T, Out>,
>(
input: SharedArray<T>,
elem: Op::Rhs,
) -> SharedArray<Out> {
let mut buffer = input.into_owned();
// Panics if the array is not contiguous in memory order.
let slice = buffer.as_slice_memory_order_mut().unwrap();
unsafe { binary_scalar_slice_inplace::<T, Out, Op>(slice, elem, PhantomData) };
// Buffer has the same elem size and is filled with the operation output, so this is safe
let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buffer) };
out.into_shared()
}
/// Applies the operation into a freshly allocated output tensor.
///
/// The output has the same shape and strides as `input`. Panics if `input` is not contiguous
/// in memory order.
fn binary_scalar_simd_owned<
    T: NdArrayElement + Scalar,
    Out: NdArrayElement + Scalar,
    Op: ScalarSimdBinop<T, Out>,
>(
    input: SharedArray<T>,
    elem: Op::Rhs,
) -> SharedArray<Out> {
    let mut result = uninit_array_like(&input);
    binary_scalar_slice::<T, Out, Op>(
        input.as_slice_memory_order().unwrap(),
        result.as_slice_memory_order_mut().unwrap(),
        elem,
        PhantomData,
    );
    result.into_shared()
}
/// Applies `Op` between every element of `input` and the fixed right-hand side `rhs`, writing
/// the results to `out`.
///
/// The slices are processed in three passes: chunks of `8 * lanes` elements (eight vectors,
/// unrolled), then single vectors of `lanes` elements, then the remaining elements one at a
/// time through the scalar `Op::apply`.
///
/// Iteration stops at the shorter of the two slices because the chunk iterators are zipped.
/// NOTE(review): `out` is chunked with the lane count of `T`; this presumably assumes `Out` has
/// the same lane count as `T` — confirm for operations where `Out != T`.
#[inline(always)]
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn binary_scalar_slice<
'a,
S: Simd,
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: ScalarSimdBinop<T, Out>,
>(
input: &'a [T],
out: &'a mut [Out],
rhs: Op::Rhs,
_op: PhantomData<Op>,
) where
'a: 'a,
{
let lanes = T::lanes::<S>();
// Pass 1: eight vectors per iteration.
let mut chunks_input = input.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
// Broadcast the right-hand side once; it is reused by both vector passes.
let rhs_vec = Op::splat::<S>(rhs);
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
seq!(N in 0..8 {
// Load one full vector from `input`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
let s~N = Op::apply_vec(s~N, rhs_vec);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
});
}
// Pass 2: one vector per iteration over what pass 1 left behind.
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
// Load one full vector from `input`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
let s0 = Op::apply_vec(s0, rhs_vec);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };
}
// Pass 3: scalar tail, fewer than `lanes` elements.
for (input, out) in chunks_input
.remainder()
.iter()
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*input, rhs)
}
}
/// Execute the operation in place, overwriting each element of `buf` with its result.
///
/// The buffer is split with `align_to_mut` into an unaligned head, a vector-aligned middle and
/// an unaligned tail. Head and tail are processed with the scalar operation; the middle is
/// processed one vector at a time, in blocks of 8 vectors where possible.
///
/// SAFETY:
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
#[inline(always)]
#[macerator::with_simd]
unsafe fn binary_scalar_slice_inplace<
'a,
S: Simd,
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: ScalarSimdBinop<T, Out>,
>(
buf: &'a mut [T],
rhs: Op::Rhs,
_op: PhantomData<(Out, Op)>,
) where
'a: 'a,
{
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
// Scalar path for the elements that do not fit into an aligned vector. The `Out` result is
// cast back to `T` so it can be stored in the same slot it was read from.
for elem in head.iter_mut().chain(tail) {
*elem = cast(Op::apply(*elem, rhs));
}
let mut chunks = main.chunks_exact_mut(8);
// From here on `rhs` is the broadcast vector, shadowing the scalar.
let rhs = Op::splat::<S>(rhs);
for elem in chunks.by_ref() {
seq!(N in 0..8 {
// Load a full vector from the aligned portion of the buffer.
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
// always a full vector in bounds.
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
let s~N = Op::apply_vec(s~N, rhs);
// Store a full vector at the same position as the input. Cast is safe because `Out` is
// size and align compatible
// NOTE(review): this uses `vstore_unaligned` while the remainder loop below uses `vstore`
// on equally aligned data — presumably either works here; confirm whether this is intended.
unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) };
});
}
// Vectors left over after the blocks of 8 (at most 7).
for elem in chunks.into_remainder() {
// Load a full vector from the aligned portion of the buffer.
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
// always a full vector in bounds.
let s0 = unsafe { vload(elem as *const _ as *const T) };
let s0 = Op::apply_vec(s0, rhs);
// Store a full vector at the same position as the input. Cast is safe because `Out` is
// size and align compatible
unsafe { vstore(elem as *mut _ as *mut Out, s0) };
}
}

View File

@@ -0,0 +1,374 @@
use core::{marker::PhantomData, slice};
use burn_backend::Element;
use macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned};
use ndarray::ArrayD;
use seq_macro::seq;
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
use super::should_use_simd;
/// A comparison that can be evaluated both on whole SIMD vectors and on single scalars.
///
/// The kernels in this file use `apply_vec` for full vectors and `apply` for the elements left
/// over at the end of a slice.
pub trait SimdCmpOp<T: Scalar> {
/// Compare two vectors, producing a mask for the vector's lanes.
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T>;
/// Compare two scalars.
fn apply(lhs: T, rhs: T) -> bool;
/// Whether the vector form of this comparison is reported as accelerated for `S`.
fn is_accelerated<S: Simd>() -> bool;
}
/// Equality comparison (`lhs == rhs`).
pub struct VecEquals;
impl<T: VEq> SimdCmpOp<T> for VecEquals {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
lhs.eq(rhs)
}
fn apply(lhs: T, rhs: T) -> bool {
lhs == rhs
}
// Delegates to the acceleration query of `VEq` for this element type.
fn is_accelerated<S: Simd>() -> bool {
<T as VEq>::is_accelerated::<S>()
}
}
/// Strict greater-than comparison (`lhs > rhs`).
pub struct VecGreater;
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecGreater {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
lhs.gt(rhs)
}
fn apply(lhs: T, rhs: T) -> bool {
lhs > rhs
}
// Delegates to the comparison acceleration query of `VOrd` for this element type.
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_cmp_accelerated::<S>()
}
}
/// Greater-than-or-equal comparison (`lhs >= rhs`).
pub struct VecGreaterEq;
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecGreaterEq {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
lhs.ge(rhs)
}
fn apply(lhs: T, rhs: T) -> bool {
lhs >= rhs
}
// Delegates to the comparison acceleration query of `VOrd` for this element type.
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_cmp_accelerated::<S>()
}
}
/// Less-than-or-equal comparison (`lhs <= rhs`).
pub struct VecLowerEq;
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecLowerEq {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
lhs.le(rhs)
}
fn apply(lhs: T, rhs: T) -> bool {
lhs <= rhs
}
// Delegates to the comparison acceleration query of `VOrd` for this element type.
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_cmp_accelerated::<S>()
}
}
/// Strict less-than comparison (`lhs < rhs`).
pub struct VecLower;
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecLower {
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
lhs.lt(rhs)
}
fn apply(lhs: T, rhs: T) -> bool {
lhs < rhs
}
// Delegates to the comparison acceleration query of `VOrd` for this element type.
fn is_accelerated<S: Simd>() -> bool {
<T as VOrd>::is_cmp_accelerated::<S>()
}
}
/// Ask `Op` whether it is accelerated for the SIMD type `S` supplied by `with_simd`.
///
/// `T` and `Op` are carried through `PhantomData` only; no value of either is needed.
#[macerator::with_simd]
fn is_accelerated<S: Simd, T: Scalar, Op: SimdCmpOp<T>>(_x: PhantomData<(T, Op)>) -> bool {
Op::is_accelerated::<S>()
}
/// Try to compare two tensors element-wise with SIMD.
///
/// Returns `Err((lhs, rhs))`, handing both inputs back untouched, unless all of the following
/// hold: `should_use_simd` accepts the larger element count, both inputs are in standard
/// layout, the shapes are equal, and `Op` is accelerated for `T`.
///
/// `E` is reinterpreted as `T`; the caller is expected to pick `T` to match `E`'s dynamic dtype.
#[allow(clippy::result_large_err)]
pub fn try_cmp_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> Result<SharedArray<bool>, (SharedArray<E>, SharedArray<E>)> {
    let largest = lhs.len().max(rhs.len());
    let eligible = should_use_simd(largest)
        && lhs.is_standard_layout()
        && rhs.is_standard_layout()
        && lhs.shape() == rhs.shape()
        && is_accelerated::<T, Op>(PhantomData);
    if !eligible {
        return Err((lhs, rhs));
    }

    // Used to assert traits based on the dynamic `DType`.
    let typed_lhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs) };
    let typed_rhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs) };
    Ok(cmp_simd_same::<T, Op>(typed_lhs, typed_rhs))
}
/// Compare two tensors element by element, reusing an input buffer for the output when possible.
///
/// When an operand is uniquely owned and `T` has the same size as `bool`, that operand's buffer
/// is overwritten with the results and reinterpreted as the output (`lhs` is tried before
/// `rhs`). Otherwise the results go into a new array from `uninit_array_like`.
fn cmp_simd_same<T, Op>(lhs: SharedArray<T>, rhs: SharedArray<T>) -> SharedArray<bool>
where
    T: NdArrayElement + Scalar,
    Op: SimdCmpOp<T>,
{
    let result: ArrayD<bool> = if lhs.is_unique() && size_of::<T>() == size_of::<bool>() {
        // Write the results over `lhs`.
        let mut storage = lhs.into_owned();
        let lhs_data = storage.as_slice_mut().unwrap();
        let rhs_data = rhs.as_slice().unwrap();
        // The output slice aliases `lhs_data`; only reached when `T` and `bool` have equal size.
        let dst = unsafe {
            core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(lhs_data))
        };
        cmp(lhs_data, rhs_data, dst, PhantomData::<Op>);
        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(storage) }
    } else if rhs.is_unique() && size_of::<T>() == size_of::<bool>() {
        // Write the results over `rhs`.
        let mut storage = rhs.into_owned();
        let lhs_data = lhs.as_slice().unwrap();
        let rhs_data = storage.as_slice_mut().unwrap();
        // The output slice aliases `rhs_data`; only reached when `T` and `bool` have equal size.
        let dst = unsafe {
            core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(rhs_data))
        };
        cmp(lhs_data, rhs_data, dst, PhantomData::<Op>);
        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(storage) }
    } else {
        // No buffer can be reused: write into a separate output array.
        let mut fresh = uninit_array_like(&lhs);
        let lhs_data = lhs.as_slice().unwrap();
        let rhs_data = rhs.as_slice().unwrap();
        let dst = fresh.as_slice_mut().unwrap();
        cmp(lhs_data, rhs_data, dst, PhantomData::<Op>);
        fresh
    };
    result.into_shared()
}
/// Compare `lhs` and `rhs` element by element with `Op`, writing one `bool` per element to `out`.
///
/// The work is done in three phases: blocks of 8 vectors, then single vectors, then a scalar
/// loop over the remaining elements. Because the slices are zipped, processing stops at the
/// shortest of the three.
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
lhs: &'a [T],
rhs: &'a [T],
out: &'a mut [bool],
_op: PhantomData<Op>,
) where
'a: 'a,
{
let lanes = T::lanes::<S>();
// Phase 1: chunks of exactly `8 * lanes` elements, unrolled 8 times by `seq!`.
let mut chunks_lhs = lhs.chunks_exact(8 * lanes);
let mut chunks_rhs = rhs.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
while let Some(((lhs, rhs), out)) = chunks_lhs
.next()
.zip(chunks_rhs.next())
.zip(chunks_out.next())
{
seq!(N in 0..8 {
// Load one full vector from `lhs`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };
// Load one full vector from `rhs`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };
let s~N = Op::apply_vec(lhs~N, rhs~N);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };
});
}
// Phase 2: what is left after the 8-vector blocks, one vector (`lanes` elements) at a time.
let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);
let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some(((lhs, rhs), out)) = chunks_lhs
.next()
.zip(chunks_rhs.next())
.zip(chunks_out.next())
{
// Load one full vector from `lhs`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };
// Load one full vector from `rhs`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };
let s0 = Op::apply_vec(lhs0, rhs0);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };
}
// Phase 3: fewer than `lanes` elements remain, handle them with the scalar comparison.
for ((lhs, rhs), out) in chunks_lhs
.remainder()
.iter()
.zip(chunks_rhs.remainder())
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*lhs, *rhs)
}
}
/// Produce a second mutable slice over the same memory as `slice`, with a lifetime that is not
/// tied to the borrow of `slice`.
///
/// This deliberately sidesteps the borrow checker so one buffer can be passed as both an input
/// and the output of an in-place operation. The returned slice must not be used after the
/// underlying storage is gone.
fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {
    let (data, count) = (slice.as_mut_ptr(), slice.len());
    // The pointer and length come straight from a live slice, so they describe valid memory
    // at the time of the call.
    unsafe { slice::from_raw_parts_mut(data, count) }
}
pub use elemwise::try_cmp_scalar_simd;
mod elemwise {
use bytemuck::cast;
use macerator::vload;
use super::*;
/// Try to compare every element of `input` against the scalar `elem` with SIMD.
///
/// Returns `Err(input)`, handing the tensor back untouched, unless `should_use_simd` accepts its
/// length, it can be viewed as a slice in memory order, and `Op` is accelerated for `T`.
///
/// `E` is reinterpreted as `T`; the caller is expected to pick `T` to match `E`'s dynamic dtype.
/// The input buffer is reused for the output when it is uniquely owned and `T` is size- and
/// alignment-compatible with `bool`.
pub fn try_cmp_scalar_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    input: SharedArray<E>,
    elem: T,
) -> Result<SharedArray<bool>, SharedArray<E>> {
    let eligible = should_use_simd(input.len())
        && input.as_slice_memory_order().is_some()
        && is_accelerated::<T, Op>(PhantomData);
    if !eligible {
        return Err(input);
    }

    // Used to assert traits based on the dynamic `DType`.
    let typed = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };

    let reuse_buffer = size_of::<T>() == size_of::<bool>()
        && align_of::<T>() >= align_of::<bool>()
        && typed.is_unique();
    let result = if reuse_buffer {
        // The layout requirements of the in-place kernel were checked just above.
        unsafe { cmp_scalar_simd_inplace::<T, Op>(typed, elem) }
    } else {
        cmp_scalar_simd_owned::<T, Op>(typed, elem)
    };
    Ok(result)
}
/// Compare every element against `elem`, writing the `bool` results over the input buffer.
///
/// The input is turned into an owned array, overwritten with the comparison results, and then
/// reinterpreted as an array of `bool`.
///
/// # Safety
/// The caller must ensure `size_of::<T>() == size_of::<bool>()` and
/// `align_of::<T>() >= align_of::<bool>()`.
unsafe fn cmp_scalar_simd_inplace<T, Op>(input: SharedArray<T>, elem: T) -> SharedArray<bool>
where
    T: NdArrayElement + Scalar,
    Op: SimdCmpOp<T>,
{
    let mut owned = input.into_owned();
    {
        let data = owned.as_slice_memory_order_mut().unwrap();
        // SAFETY: the size/alignment contract is forwarded unchanged from our own caller.
        unsafe { cmp_scalar_slice_inplace::<T, Op>(data, elem, PhantomData) };
    }
    // The array keeps its element size and now holds the comparison output, so it is
    // reinterpreted rather than copied.
    let reinterpreted = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(owned) };
    reinterpreted.into_shared()
}
/// Compare every element against `elem`, writing the results into a newly created output array.
///
/// The input is only read; results are written to an array obtained from `uninit_array_like`.
fn cmp_scalar_simd_owned<T, Op>(input: SharedArray<T>, elem: T) -> SharedArray<bool>
where
    T: NdArrayElement + Scalar,
    Op: SimdCmpOp<T>,
{
    let mut result = uninit_array_like(&input);
    cmp_scalar_slice::<T, Op>(
        input.as_slice_memory_order().unwrap(),
        result.as_slice_memory_order_mut().unwrap(),
        elem,
        PhantomData,
    );
    result.into_shared()
}
/// Compare every element of `input` against the scalar `rhs`, writing one `bool` per element
/// to `out`.
///
/// The work is done in three phases: blocks of 8 vectors, then single vectors, then a scalar
/// loop over the remaining elements. Because the slices are zipped, processing stops at the
/// shorter of `input` and `out`.
#[inline(always)]
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
input: &'a [T],
out: &'a mut [bool],
rhs: T,
_op: PhantomData<Op>,
) where
'a: 'a,
{
let lanes = T::lanes::<S>();
// Phase 1: chunks of exactly `8 * lanes` elements, unrolled 8 times by `seq!`.
let mut chunks_input = input.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
// Broadcast the scalar once so it can be reused by every vector comparison below.
let rhs_vec = rhs.splat::<S>();
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
seq!(N in 0..8 {
// Load one full vector from `input`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
let s~N = Op::apply_vec(s~N, rhs_vec);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };
});
}
// Phase 2: what is left after the 8-vector blocks, one vector (`lanes` elements) at a time.
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
// Load one full vector from `input`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
let s0 = Op::apply_vec(s0, rhs_vec);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };
}
// Phase 3: fewer than `lanes` elements remain, handle them with the scalar comparison.
for (input, out) in chunks_input
.remainder()
.iter()
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*input, rhs)
}
}
/// Execute the comparison in place, overwriting each element of `buf` with its `bool` result.
///
/// The buffer is split with `align_to_mut` into an unaligned head, a vector-aligned middle and
/// an unaligned tail. Head and tail are processed with the scalar comparison; the middle is
/// processed one vector at a time, in blocks of 8 vectors where possible.
///
/// SAFETY:
/// Must ensure `size_of::<T> == size_of::<bool>` and `align_of::<T> >= align_of::<bool>`.
#[inline(always)]
#[macerator::with_simd]
unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
buf: &'a mut [T],
rhs: T,
_op: PhantomData<Op>,
) where
'a: 'a,
{
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
// Scalar path for the elements that do not fit into an aligned vector. The `bool` result is
// cast back to `T` so it can be stored in the same slot it was read from.
for elem in head.iter_mut().chain(tail) {
*elem = cast(Op::apply(*elem, rhs));
}
let mut chunks = main.chunks_exact_mut(8);
// From here on `rhs` is the broadcast vector, shadowing the scalar.
let rhs = rhs.splat::<S>();
for elem in chunks.by_ref() {
seq!(N in 0..8 {
// Load a full vector from the aligned portion of the buffer.
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
// always a full vector in bounds.
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
let s~N = Op::apply_vec(s~N, rhs);
// Store the mask as `bool`s at the same position as the input. The cast relies on `bool`
// being size and align compatible with `T` (see the safety contract).
unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) };
});
}
// Vectors left over after the blocks of 8 (at most 7).
for elem in chunks.into_remainder() {
// Load a full vector from the aligned portion of the buffer.
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
// always a full vector in bounds.
let s0 = unsafe { vload(elem as *const _ as *const T) };
let s0 = Op::apply_vec(s0, rhs);
// Store the mask as `bool`s at the same position as the input. The cast relies on `bool`
// being size and align compatible with `T` (see the safety contract).
unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) };
}
}
}

View File

@@ -0,0 +1,494 @@
use core::{marker::PhantomData, mem::transmute};
use burn_backend::{
DType, Element,
ops::{ConvOptions, conv::calculate_conv_output_size},
};
use bytemuck::Zeroable;
use macerator::{Simd, VMulAdd, Vector, vload_unaligned, vstore_unaligned};
use ndarray::{
ArcArray1, Array4, ArrayView3, ArrayView4, ArrayViewMut2, ArrayViewMut3, Dim, Ix1, Ix4, s,
};
use seq_macro::seq;
use crate::{FloatNdArrayElement, SharedArray, UnsafeSharedRef, iter_range_par, run_par};
/// The `(x, weight, bias)` inputs, handed back to the caller when the SIMD path is not taken.
type Args<E> = (SharedArray<E>, SharedArray<E>, Option<SharedArray<E>>);
/// Try to run a 2D convolution with the SIMD kernel.
///
/// Dispatches on the dynamic dtype of `E` to a concretely typed `conv2d`. Returns
/// `Err((x, weight, bias))`, handing the inputs back, when the dtype is not one of those listed
/// below or when `conv2d` itself declines.
#[allow(clippy::result_large_err)]
pub fn try_conv2d_simd<E: FloatNdArrayElement>(
x: SharedArray<E>,
weight: SharedArray<E>,
bias: Option<SharedArray<E>>,
options: ConvOptions<2>,
) -> Result<SharedArray<E>, Args<E>> {
match E::dtype() {
DType::F64 => conv2d::<f64, _>(x, weight, bias, options, PhantomData),
DType::F32 => conv2d::<f32, _>(x, weight, bias, options, PhantomData),
DType::I64 => conv2d::<i64, _>(x, weight, bias, options, PhantomData),
DType::I32 => conv2d::<i32, _>(x, weight, bias, options, PhantomData),
DType::I16 => conv2d::<i16, _>(x, weight, bias, options, PhantomData),
DType::U64 => conv2d::<u64, _>(x, weight, bias, options, PhantomData),
DType::U32 => conv2d::<u32, _>(x, weight, bias, options, PhantomData),
DType::U16 => conv2d::<u16, _>(x, weight, bias, options, PhantomData),
// Any other dtype: give the inputs back unchanged.
_ => Err((x, weight, bias)),
}
}
/// Reinterpret a tensor of `T` as a tensor of `E` without touching its data.
///
/// Nothing here checks that `T` and `E` are compatible; callers in this file use it to switch
/// between a generic element type and the concrete type selected from its dtype.
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
}
/// Out-channel last SIMD accelerated direct convolution. Loop order and register blocking based on
/// E. Georganas, S. Avancha, K. Banerjee, D. Kalamkar, G. Henry, H. Pabst, A. Heinecke (2018).
/// Anatomy Of High-Performance Deep Learning Convolutions On SIMD Architectures.
/// SC '18, Article 6, pp. 1-12. arXiv:1808.05567. <https://arxiv.org/abs/1808.05567>.
///
/// `T` is the caller's element type and `E` the concrete type it is reinterpreted as. Returns
/// `Err((x, weight, bias))` when `E` is not accelerated or when the number of output channels
/// per group is not a multiple of the vector lane count.
#[allow(clippy::result_large_err)]
fn conv2d<E: VMulAdd + Element, T: Element>(
x: SharedArray<T>,
weight: SharedArray<T>,
bias: Option<SharedArray<T>>,
options: ConvOptions<2>,
_ty: PhantomData<E>,
) -> Result<SharedArray<T>, Args<T>> {
let [out_channels, _, k_height, k_width] = weight.shape().try_into().unwrap();
let channels_per_group = out_channels / options.groups;
// Query lane count and acceleration support for the SIMD type chosen by `with_simd`.
#[macerator::with_simd]
fn precheck<S: Simd, E: VMulAdd>(_ty: PhantomData<E>) -> (usize, bool) {
(E::lanes::<S>(), E::is_accelerated::<S>())
}
let (lanes, accelerated) = precheck::<E>(PhantomData);
if !accelerated || !channels_per_group.is_multiple_of(lanes) {
return Err((x, weight, bias));
}
let x = cast::<_, E>(x);
let weight = cast::<_, E>(weight);
let bias = bias.map(|bias| cast::<_, E>(bias));
let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();
let [dilate_h, dilate_w] = options.dilation;
let [stride_h, stride_w] = options.stride;
let [pad_h, pad_w] = options.padding;
// These three flags pick the const-generic specialization of `conv2d_launch` below.
let padded = options.padding != [0, 0];
let strided = options.stride != [1, 1] || options.dilation != [1, 1];
let grouped = options.groups != 1;
let out_height = calculate_conv_output_size(k_height, stride_h, pad_h, dilate_h, in_height);
let out_width = calculate_conv_output_size(k_width, stride_w, pad_w, dilate_w, in_width);
let x = x.into_dimensionality::<Ix4>().unwrap();
let weights = weight.into_dimensionality::<Ix4>().unwrap();
// Move the out-channel axis last and make the result contiguous, so that a run of out
// channels can be loaded as one vector.
let weights = weights.permuted_axes([1, 2, 3, 0]);
let weights = weights.as_standard_layout();
let bias = bias.map(|bias| bias.into_dimensionality::<Ix1>().unwrap());
// floor division means `(oc_blocks - 1) * lanes` can never be greater than `out_channels - lanes`.
let oc_blocks = out_channels / lanes;
// The output is allocated with out channels last and permuted back before returning.
// NOTE(review): the buffer is marked initialized before anything is written to it; this
// presumably relies on the launch below overwriting every element — confirm.
let mut out = unsafe {
Array4::<E>::uninit(Dim([batch_size, out_height, out_width, out_channels])).assume_init()
};
let unsafe_shared_out = UnsafeSharedRef::new(&mut out);
run_par!(|| {
// SAFETY: Slices are guaranteed to be non-overlapping, so having an unsafe shared reference
// is safe. `oc_blocks * lanes` must be `<= out_channels` to satisfy safety of inner function.
iter_range_par!(0, batch_size * oc_blocks).for_each(|k| unsafe {
// `k` enumerates (batch item, out-channel block) pairs.
let b = k / oc_blocks;
let ob = k % oc_blocks;
let x = x.slice(s![b, .., .., ..]);
let out = unsafe_shared_out.get();
let mut out = out.slice_mut(s![b, .., .., ..]);
let w = weights.view();
match (padded, strided, grouped) {
(true, true, true) => {
conv2d_launch::<E, true, true, true>(x, w, &bias, &mut out, &options, ob)
}
(true, false, true) => {
conv2d_launch::<E, true, false, true>(x, w, &bias, &mut out, &options, ob)
}
(false, true, true) => {
conv2d_launch::<E, false, true, true>(x, w, &bias, &mut out, &options, ob)
}
(false, false, true) => {
conv2d_launch::<E, false, false, true>(x, w, &bias, &mut out, &options, ob)
}
(true, true, false) => {
conv2d_launch::<E, true, true, false>(x, w, &bias, &mut out, &options, ob)
}
(true, false, false) => {
conv2d_launch::<E, true, false, false>(x, w, &bias, &mut out, &options, ob)
}
(false, true, false) => {
conv2d_launch::<E, false, true, false>(x, w, &bias, &mut out, &options, ob)
}
(false, false, false) => {
conv2d_launch::<E, false, false, false>(x, w, &bias, &mut out, &options, ob)
}
}
});
});
// Back from `[batch, height, width, channels]` to `[batch, channels, height, width]`.
let output = out.permuted_axes([0, 3, 1, 2]);
Ok(cast(output.into_dyn().into_shared()))
}
/// Size of register blocks, we need to hardcode this because Rust and the `seq` macro don't support
/// using associated constants as constant parameters. 8 works for all semi-modern CPUs but might
/// not be perfectly optimized for AVX-512 capable CPUs (which probably should use 16).
/// This should always be conservative, since oversizing it will cause register spills and that's
/// **much** worse than the performance lost with lower values.
const REGISTER_BLOCK: usize = 8;
// Generates `conv2d_inner_nopad` and `conv2d_inner_nopad_nostride` unrolled over 8 output
// columns. The literal must be kept equal to `REGISTER_BLOCK`, which `conv2d_launch` uses as
// the column step between calls to those kernels.
inner_with_register_blocking_size!(8);
/// Run the convolution for one block of `simd_lanes` output channels of a single batch item.
///
/// Output pixels in rows `pad_h..out_height - pad_h` and columns starting at `pad_w` are handled
/// by the unrolled kernels, `REGISTER_BLOCK` columns at a time. Everything else (border rows,
/// border columns and the columns left over after blocking) is delegated to `conv2d_remainder`.
///
/// `PAD`, `STRIDE` and `GROUPS` are compile-time switches: with `PAD == false` the padding is
/// forced to zero, with `STRIDE == false` the stride/dilation-free kernel is used, and with
/// `GROUPS == false` the input channel offset is forced to zero.
///
/// # SAFETY
/// See `conv2d_inner_nopad`, `conv2d_inner_nopad_nostride`, `conv2d_remainder`.
/// Required preconditions: `ob * simd_lanes` must be `<= out_channels - simd_lanes`, `weights` and
/// `out` must have unit stride for the out channels.
#[inline(always)]
#[macerator::with_simd]
unsafe fn conv2d_launch<
    'a,
    S: Simd,
    E: VMulAdd,
    const PAD: bool,
    const STRIDE: bool,
    const GROUPS: bool,
>(
    x: ArrayView3<'a, E>,
    weights: ArrayView4<'a, E>,
    bias: &'a Option<ArcArray1<E>>,
    out: &'a mut ArrayViewMut3<'a, E>,
    options: &'a ConvOptions<2>,
    ob: usize,
) where
    'a: 'a,
{
    let (in_channels, k_height, k_width, out_channels) = weights.dim();
    let (out_height, out_width, _) = out.dim();
    let channels_per_group = out_channels / options.groups;
    let lanes = E::lanes::<S>();

    let [mut pad_h, mut pad_w] = options.padding;
    let [stride_h, stride_w] = options.stride;
    let [dilate_h, dilate_w] = options.dilation;

    // Trick compiler into inlining 0 to padding
    if !PAD {
        pad_h = 0;
        pad_w = 0;
    }

    let oc_b = channels_per_group.min(lanes);
    let ow_b = REGISTER_BLOCK;

    // Interior region: rows `oh_start..oh_end`, columns starting at `ow_start`.
    let ow_start = pad_w;
    let ow_width = out_width.saturating_sub(2 * pad_w);
    let oh_start = pad_h;
    let oh_end = out_height.saturating_sub(pad_h);

    let ow_blocks = ow_width / ow_b;
    // First column after the blocked region. The blocks start at `ow_start`, so the offset has
    // to be included; otherwise the remainder pass recomputes columns the unrolled kernels
    // have already written.
    let owb_end = ow_start + ow_blocks * ow_b;

    let oc = ob * oc_b;
    let group = oc / channels_per_group;
    let mut ic_off = group * in_channels;
    if !GROUPS {
        ic_off = 0;
    }

    unsafe {
        // Every accumulator starts from the bias vector (or zero when there is no bias).
        let bias = if let Some(bias) = &bias {
            vload_unaligned::<S, _>(&bias[oc])
        } else {
            Zeroable::zeroed()
        };

        for oh in oh_start..oh_end {
            let mut out = out.slice_mut(s![oh, .., ..]);
            for ow_block in 0..ow_blocks {
                let ow = ow_block * ow_b + ow_start;

                #[allow(clippy::if_same_then_else)]
                if STRIDE {
                    conv2d_inner_nopad(
                        &x, &weights, &mut out, bias, oh, ow, oc, ic_off, stride_h, stride_w,
                        dilate_h, dilate_w, k_height, k_width, pad_h, pad_w,
                    );
                } else {
                    conv2d_inner_nopad_nostride(
                        &x, &weights, &mut out, bias, oh, ow, oc, ic_off, k_height, k_width, pad_h,
                        pad_w,
                    );
                }
            }
        }
        conv2d_remainder(
            x, weights, out, bias, oc, ic_off, owb_end, stride_h, stride_w, dilate_h, dilate_w,
            pad_h, pad_w, k_height, k_width,
        );
    }
}
/// Execute the non-unrolled and/or padded portion of the convolution. This has more checks and is
/// much slower, so we want to minimize the amount of pixels that need to be processed by this
///
/// Two passes are made: first every column of the border rows (`0..pad_h` and
/// `out_height - pad_h..out_height`), then every row of the border and leftover columns
/// (`0..pad_w` and `owb_end..out_width`). Kernel taps that fall into the padding are skipped.
///
/// `ic_off` is the index of the first input channel belonging to the current group; it is added
/// to the weight-relative channel index whenever `x` is read.
///
/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector
/// is in bounds. Weights and `out` must be channels last (with `stride == 1`).
#[allow(clippy::too_many_arguments)]
#[inline(always)]
unsafe fn conv2d_remainder<S: Simd, E: VMulAdd>(
    x: ArrayView3<E>,
    weights: ArrayView4<E>,
    out: &mut ArrayViewMut3<E>,
    bias: Vector<S, E>,
    oc: usize,
    ic_off: usize,
    owb_end: usize,
    stride_h: usize,
    stride_w: usize,
    dilate_h: usize,
    dilate_w: usize,
    pad_h: usize,
    pad_w: usize,
    k_height: usize,
    k_width: usize,
) {
    let in_channels = weights.shape()[0];
    let (_, in_height, in_width) = x.dim();
    let (out_height, out_width, _) = out.dim();
    let oh_start = pad_h;
    let oh_end = out_height.saturating_sub(pad_h);
    let ow_start = pad_w;

    // Exclusive upper bounds for input coordinates expressed in padded space.
    let height1 = in_height + pad_h;
    let width1 = in_width + pad_w;

    // Pass 1: border rows, all columns.
    for oh in (0..oh_start).chain(oh_end..out_height) {
        for ow in 0..out_width {
            let mut acc = bias;

            for ic in 0..in_channels {
                for kh in 0..k_height {
                    let ih = oh * stride_h + kh * dilate_h;
                    if (ih < pad_h) | (ih >= height1) {
                        continue;
                    }
                    let ih = ih - pad_h;

                    for kw in 0..k_width {
                        let iw = ow * stride_w + kw * dilate_w;
                        if (iw < pad_w) | (iw >= width1) {
                            continue;
                        }
                        let iw = iw - pad_w;

                        // Load a full vector from the weights. This is guaranteed to be in bounds
                        // as long as `oc <= out_channels - simd_lanes` and out channels are last.
                        // We need to ensure the weights are reshaped appropriately.
                        let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };

                        // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the
                        // compiler can't prove this. We can't use `as_slice` with fixed bounds
                        // because we want to support arbitrary input layouts. So an unchecked load
                        // is used. The group offset must be applied here exactly as in the column
                        // pass below, otherwise grouped convolutions read another group's input.
                        let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::<S>();
                        acc = i0.mul_add(f0, acc);
                    }
                }
            }

            // Store a full vector from the output. This is guaranteed to be in bounds
            // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with
            // channels last, so this always holds.
            unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };
        }
    }

    // Pass 2: border columns and the columns left over after blocking, all rows.
    for ow in (0..ow_start).chain(owb_end..out_width) {
        for oh in 0..out_height {
            let mut acc = bias;

            for ic in 0..in_channels {
                for kh in 0..k_height {
                    let ih = oh * stride_h + kh * dilate_h;
                    if (ih < pad_h) | (ih >= height1) {
                        continue;
                    }
                    let ih = ih - pad_h;

                    for kw in 0..k_width {
                        let iw = ow * stride_w + kw * dilate_w;
                        if (iw < pad_w) | (iw >= width1) {
                            continue;
                        }
                        let iw = iw - pad_w;

                        // Load a full vector from the weights. This is guaranteed to be in bounds
                        // as long as `oc <= out_channels - simd_lanes` and out channels are last.
                        // We need to ensure the weights are reshaped appropriately.
                        let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };

                        // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the
                        // compiler can't prove this. We can't use `as_slice` with fixed bounds
                        // because we want to support arbitrary input layouts. So an unchecked load
                        // is used.
                        let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::<S>();
                        acc = i0.mul_add(f0, acc);
                    }
                }
            }

            // Store a full vector from the output. This is guaranteed to be in bounds
            // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with
            // channels last, so this always holds.
            unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };
        }
    }
}
// Generates the two unrolled inner kernels, `conv2d_inner_nopad` and
// `conv2d_inner_nopad_nostride`, for a register blocking factor of `$rb`. Each call to a
// generated kernel computes `$rb` consecutive output columns of one output row, keeping one
// accumulator vector per column. The factor has to be a literal because it drives `seq!`.
macro_rules! inner_with_register_blocking_size {
($rb: literal) => {
/// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than
/// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is
/// guaranteed to always be in bounds (because of the way out size is calculated).
///
/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector
/// is in bounds. Weights and `out` must be channels last (with `stride == 1`).
#[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)]
#[inline(always)]
unsafe fn conv2d_inner_nopad<S: Simd, E: VMulAdd>(
x: &ArrayView3<E>,
weights: &ArrayView4<E>,
out: &mut ArrayViewMut2<E>,
bias: Vector<S, E>,
oh: usize,
ow: usize,
oc: usize,
ic_off: usize,
stride_h: usize,
stride_w: usize,
dilate_h: usize,
dilate_w: usize,
k_height: usize,
k_width: usize,
pad_h: usize,
pad_w: usize,
) {
let in_channels = weights.shape()[0];
// One accumulator per output column in the block, each starting from the bias.
seq!(N in 0..$rb {
let mut acc~N = bias;
});
for ic in 0..in_channels {
for kh in 0..k_height {
let ih = oh * stride_h + kh * dilate_h - pad_h;
for kw in 0..k_width {
// Load a full vector from the weights. This is guaranteed to be in bounds
// as long as `oc <= out_channels - simd_lanes` and out channels are last.
// We need to ensure the weights are reshaped appropriately.
let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };
let iw = ow * stride_w + kw * dilate_w - pad_w;
seq!(N in 0..$rb {
// The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the
// compiler can't prove this. We can't use `as_slice` with fixed bounds
// because we want to support arbitrary input layouts. So an unchecked load
// is used.
let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N * stride_w]) }.splat::<S>();
});
seq!(N in 0..$rb {
acc~N = i~N.mul_add(f0, acc~N);
});
}
}
}
seq!(N in 0..$rb {
// Store a full vector from the output. This is guaranteed to be in bounds
// as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with
// channels last, so this always holds.
unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) };
});
}
/// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than
/// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is
/// guaranteed to always be in bounds (because of the way out size is calculated).
///
/// This variant has no stride or dilation parameters: input coordinates advance by one per
/// output pixel and per kernel tap.
///
/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector
/// is in bounds. Weights and `out` must be channels last (with `stride == 1`).
#[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)]
#[inline(always)]
unsafe fn conv2d_inner_nopad_nostride<S: Simd, E: VMulAdd>(
x: &ArrayView3<E>,
weights: &ArrayView4<E>,
out: &mut ArrayViewMut2<E>,
bias: Vector<S, E>,
oh: usize,
ow: usize,
oc: usize,
ic_off: usize,
k_height: usize,
k_width: usize,
pad_h: usize,
pad_w: usize,
) {
let in_channels = weights.shape()[0];
// One accumulator per output column in the block, each starting from the bias.
seq!(N in 0..$rb {
let mut acc~N = bias;
});
for ic in 0..in_channels {
for kh in 0..k_height {
let ih = oh + kh - pad_h;
for kw in 0..k_width {
// Load a full vector from the weights. This is guaranteed to be in bounds
// as long as `oc <= out_channels - simd_lanes` and out channels are last.
// We need to ensure the weights are reshaped appropriately.
let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };
let iw = ow + kw - pad_w;
seq!(N in 0..$rb {
// The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the
// compiler can't prove this. We can't use `as_slice` with fixed bounds
// because we want to support arbitrary input layouts. So an unchecked load
// is used.
let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N]) }.splat::<S>();
});
seq!(N in 0..$rb {
acc~N = i~N.mul_add(f0, acc~N);
});
}
}
}
seq!(N in 0..$rb {
// Store a full vector from the output. This is guaranteed to be in bounds
// as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with
// channels last, so this always holds.
unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) };
});
}
};
}
pub(crate) use inner_with_register_blocking_size;

View File

@@ -0,0 +1,394 @@
use core::{marker::PhantomData, mem::transmute};
use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};
use burn_backend::{DType, Element, quantization::QuantValue};
use macerator::{Simd, VOrd};
use ndarray::{Array4, s};
use nhwc::max_pool2d_nhwc;
use super::{MinMax, should_use_simd};
/// Ask `VOrd` whether min/max is accelerated for `T` on the SIMD type `S` supplied by
/// `with_simd`. `T` is carried through `PhantomData` only.
#[macerator::with_simd]
fn is_accelerated_impl<S: Simd, T: VOrd>(_x: PhantomData<T>) -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
/// Convenience wrapper around `is_accelerated_impl` that hides the `PhantomData` argument.
fn is_accelerated<T: VOrd>() -> bool {
is_accelerated_impl::<T>(PhantomData)
}
// Dispatch `$func` on the dynamic dtype of `$ty`.
//
// For each listed dtype whose concrete type reports min/max acceleration, `$x` is reinterpreted
// as that concrete type, `$func` is called with it and the remaining arguments, and the result
// is reinterpreted back and wrapped in `Ok`. `Bool` is processed as `u8`, and quantized tensors
// with `Q8F`/`Q8S` values as `i8`. In every other case the macro evaluates to `Err($x)`.
macro_rules! launch_kernel {
($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => {
match <$ty as Element>::dtype() {
DType::F64 if is_accelerated::<f64>() => Ok(cast($func::<f64>(cast($x), $($arg),*))),
DType::F32 if is_accelerated::<f32>() => Ok(cast($func::<f32>(cast($x), $($arg),*))),
DType::I64 if is_accelerated::<i64>() => Ok(cast($func::<i64>(cast($x), $($arg),*))),
DType::I32 if is_accelerated::<i32>() => Ok(cast($func::<i32>(cast($x), $($arg),*))),
DType::I16 if is_accelerated::<i16>() => Ok(cast($func::<i16>(cast($x), $($arg),*))),
DType::I8 if is_accelerated::<i8>() => Ok(cast($func::<i8>(cast($x), $($arg),*))),
DType::U64 if is_accelerated::<u64>() => Ok(cast($func::<u64>(cast($x), $($arg),*))),
DType::U32 if is_accelerated::<u32>() => Ok(cast($func::<u32>(cast($x), $($arg),*))),
DType::U16 if is_accelerated::<u16>() => Ok(cast($func::<u16>(cast($x), $($arg),*))),
DType::U8 if is_accelerated::<u8>() => Ok(cast($func::<u8>(cast($x), $($arg),*))),
DType::Bool if is_accelerated::<u8>() => Ok(cast($func::<u8>(cast($x), $($arg),*))),
DType::QFloat(scheme) => match scheme.value {
QuantValue::Q8F | QuantValue::Q8S if is_accelerated::<i8>() => Ok(cast($func::<i8>(cast($x), $($arg),*))),
_ => Err($x)
},
// Unlisted dtype, or a listed one whose acceleration guard failed.
_ => Err($x),
}
};
}
/// Try to run 2D max pooling with the SIMD kernel.
///
/// The kernel is only attempted when `should_use_simd` accepts the channel count and the
/// channel axis (axis 1) has a stride of 1. In every other case, and when `launch_kernel!`
/// finds no accelerated implementation for the dtype, the input is handed back as `Err(x)`.
///
/// Panics if `x` does not have exactly four dimensions.
pub(crate) fn try_max_pool2d_simd<E: Element>(
    x: SharedArray<E>,
    ksize: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> Result<SharedArray<E>, SharedArray<E>> {
    let dims: [usize; 4] = x.shape().try_into().unwrap();
    let channels = dims[1];

    if should_use_simd(channels) && x.strides()[1] == 1 {
        launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation)
    } else {
        Err(x)
    }
}
/// Reinterprets a `SharedArray<T>` as a `SharedArray<E>` without touching the data.
///
/// NOTE(review): nothing here checks that `T` and `E` are the same type; the callers in
/// `launch_kernel!` only use it inside an arm selected by the matching `DType`.
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
// SAFETY: relies on the caller having established `T == E` through the dtype match.
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
}
mod nhwc {
use itertools::Itertools;
use macerator::{Simd, vload_unaligned, vstore_unaligned};
use ndarray::{ArrayView3, ArrayViewMut3, Ix4};
use seq_macro::seq;
use crate::ops::simd::lanes;
use super::*;
// Until you can use associated constants as array size, we need to hardcode this.
// The most common config (x86-v3) has 16 registers, so use half of them for accumulators.
const BLOCK_REGISTERS: usize = 8;
/// Max-pools a 4-D tensor whose shape is read as `[batch, channels, height, width]`.
///
/// The input is viewed with axes permuted to `[batch, height, width, channels]` and the output
/// is allocated in that same axis order, then permuted back before returning. The channel axis
/// is split into three consecutive ranges, each handled by its own loop:
///
/// * `[0, blocks_end)`: whole blocks of `lanes * BLOCK_REGISTERS` channels (`loop_blocked`)
/// * `[blocks_end, simd_end)`: single SIMD vectors of `lanes` channels (`loop_unblocked`)
/// * `[simd_end, channels)`: leftover channels, one at a time (`loop_scalar`)
///
/// # Panics
///
/// Panics if `x` is not 4-dimensional.
pub(crate) fn max_pool2d_nhwc<E: Element + VOrd + MinMax>(
x: SharedArray<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> SharedArray<E> {
let [kernel_height, kernel_width] = kernel_size;
let [pad_h, pad_w] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
let lanes = lanes::<E>();
// Number of channels processed per unrolled block.
let ch_block = lanes * BLOCK_REGISTERS;
// Output size: floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1.
let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1)
/ stride_height)
+ 1;
let out_width =
((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1;
// The buffer starts uninitialized; the three loops below write every channel of every
// output position between them.
let mut output = unsafe {
Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()
};
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let x = x.into_dimensionality::<Ix4>().unwrap();
let x = x.view();
// View the input as [batch, height, width, channels] to match the output axis order.
let x = x.permuted_axes([0, 2, 3, 1]);
// Floor division ensures `blocks * lanes * blocking factor` is always `<= channels`.
// An exclusive loop will always have `lanes * blocking factor` elements in bounds.
let blocks = channels / ch_block;
let blocks_end = blocks * ch_block;
// Floor division means simd_end is always divisible by `lanes` and `<= channels`. An
// exclusive loop will always have `lanes` elements in bounds.
let simd_end = channels / lanes * lanes;
// Number of whole vectors left after the blocked range.
let simd_unblocked = (simd_end - blocks_end) / lanes;
// Number of channels left after the last whole vector.
let remainder = channels - simd_end;
run_par!(|| {
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {
// `k` enumerates (batch, block) pairs.
let block = k % blocks;
let b = k / blocks;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_blocked(x, out, kernel_size, stride, padding, dilation, block);
});
// SAFETY: See `loop_unblocked`
iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe {
// `k` enumerates (batch, vector) pairs; `ch` is the first channel of the vector.
let ch = (k % simd_unblocked) * lanes + blocks_end;
let b = k / simd_unblocked;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch);
});
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {
// `k` enumerates (batch, leftover channel) pairs.
let ch = (k % remainder) + simd_end;
let b = k / remainder;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_scalar(x, out, kernel_size, stride, padding, dilation, ch);
});
});
// Restore the [batch, channels, height, width] axis order of the input.
output = output.permuted_axes([0, 3, 1, 2]);
output.into_dyn().into_shared()
}
/// Execute the blocked (unrolled) portion of the pool.
///
/// Processes the `lanes * BLOCK_REGISTERS` channels starting at `block * lanes * BLOCK_REGISTERS`
/// for every output position of one batch item. `x` and `out` are indexed as
/// `[height, width, channels]`.
///
/// Output positions at least `pad` away from every edge are handled without per-tap bounds
/// checks; the remaining border positions skip taps that land in the padding. A border position
/// whose window has no in-bounds tap is written as `E::MIN`.
///
/// The border ranges are clamped to the output size, so configurations where the padding is
/// larger than the output dimension (e.g. input height 1, kernel 5, padding 2) no longer index
/// past the end of `out`.
#[allow(
    clippy::too_many_arguments,
    clippy::erasing_op,
    clippy::identity_op,
    unused_mut
)]
#[inline(always)]
#[macerator::with_simd]
fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>(
    x: ArrayView3<'a, E>,
    mut out: ArrayViewMut3<'a, E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    block: usize,
) where
    'a: 'a,
{
    let [kernel_height, kernel_width] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_height, stride_width] = stride;
    let [dilation_height, dilation_width] = dilation;

    let (x_height, x_width, _) = x.dim();
    let (out_height, out_width, _) = out.dim();
    let lanes = E::lanes::<S>();
    let ch_block = lanes * BLOCK_REGISTERS;

    let min = E::MIN.splat::<S>();

    // If outside padding area, kernels are guaranteed to be in bounds
    for oh in pad_h..out_height.saturating_sub(pad_h) {
        for ow in pad_w..out_width.saturating_sub(pad_w) {
            seq!(N in 0..8 {
                let mut acc~N = min;
            });
            let ch = block * ch_block;
            let ch_end = ch + ch_block;
            let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh * dilation_height - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw * dilation_width - pad_w;
                    let x = x.slice(s![ih, iw, ch..ch_end]);

                    seq!(N in 0..8 {
                        // SAFETY:
                        // Load a full vector from x[N * lanes]. This is bounds checked by the
                        // slice above.
                        acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });
                    });
                }
            }

            seq!(N in 0..8 {
                // SAFETY:
                // Store a full vector to out[N * lanes]. This is bounds checked by the
                // slice above.
                unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };
            });
        }
    }

    // Border pixels need bounds checks
    if (pad_h, pad_w) != (0, 0) {
        // Clamp the leading border to the output size: with `pad > out` the unclamped range
        // `0..pad` would yield rows/columns that do not exist in `out`.
        let top_rows = pad_h.min(out_height);
        let left_cols = pad_w.min(out_width);

        let v_borders = (0..top_rows)
            .chain(out_height.saturating_sub(pad_h)..out_height)
            .cartesian_product(0..out_width);
        let h_borders = (0..out_height)
            .cartesian_product((0..left_cols).chain(out_width.saturating_sub(pad_w)..out_width));

        for (oh, ow) in v_borders.chain(h_borders) {
            seq!(N in 0..8 {
                let mut acc~N = min;
            });
            let ch = block * ch_block;
            let ch_end = ch + ch_block;
            let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);

            for kh in 0..kernel_height {
                // Coordinates are computed in the padded frame first so the subtraction below
                // cannot underflow.
                let ih = oh * stride_height + kh * dilation_height;
                if ih < pad_h || ih >= x_height + pad_h {
                    continue;
                }
                let ih = ih - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw * dilation_width;
                    if iw < pad_w || iw >= x_width + pad_w {
                        continue;
                    }
                    let iw = iw - pad_w;

                    let x = x.slice(s![ih, iw, ch..ch_end]);

                    seq!(N in 0..8 {
                        // SAFETY:
                        // Load a full vector from x[N * lanes]. This is bounds checked by the
                        // slice above.
                        acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });
                    });
                }
            }

            seq!(N in 0..8 {
                // SAFETY:
                // Store a full vector to out[N * lanes]. This is bounds checked by the
                // slice above.
                unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };
            });
        }
    }
}
/// Execute the unblocked (not unrolled) portion of the pool.
///
/// Processes one SIMD vector of channels starting at `ch` for every output position of one
/// batch item. `x` and `out` are indexed as `[height, width, channels]`.
///
/// Output positions at least `pad` away from every edge are handled without per-tap bounds
/// checks; the remaining border positions skip taps that land in the padding. The border ranges
/// are clamped to the output size, so a padding larger than the output dimension no longer
/// indexes past the end of `out`.
///
/// # Safety
///
/// Safe as long as `ch + simd_lanes <= channels`, because a full vector is loaded from and
/// stored to the channel axis starting at `ch`.
#[allow(clippy::too_many_arguments, unused_mut)]
#[inline(always)]
#[macerator::with_simd]
unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>(
    x: ArrayView3<'a, E>,
    mut out: ArrayViewMut3<'a, E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ch: usize,
) where
    'a: 'a,
{
    let [kernel_height, kernel_width] = kernel_size;
    let [pad_h, pad_w] = padding;
    let [stride_height, stride_width] = stride;
    let [dilation_height, dilation_width] = dilation;

    let (x_height, x_width, _) = x.dim();
    let (out_height, out_width, _) = out.dim();

    // If outside padding area, kernels are guaranteed to be in bounds
    for oh in pad_h..out_height.saturating_sub(pad_h) {
        for ow in pad_w..out_width.saturating_sub(pad_w) {
            let mut acc = E::MIN.splat::<S>();
            let out = &mut out[[oh, ow, ch]];

            for kh in 0..kernel_height {
                let ih = oh * stride_height + kh * dilation_height - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw * dilation_width - pad_w;
                    // Load a full vector from `x`. In bounds as long as `channels >= ch + lanes`
                    acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });
                }
            }
            // Store a full vector to `out`. In bounds as long as `channels >= ch + lanes`.
            unsafe { vstore_unaligned(out, acc) };
        }
    }

    // Border pixels need bounds checks
    if (pad_h, pad_w) != (0, 0) {
        // Clamp the leading border to the output size: with `pad > out` the unclamped range
        // `0..pad` would yield rows/columns that do not exist in `out`.
        let top_rows = pad_h.min(out_height);
        let left_cols = pad_w.min(out_width);

        let v_borders = (0..top_rows)
            .chain(out_height.saturating_sub(pad_h)..out_height)
            .cartesian_product(0..out_width);
        let h_borders = (0..out_height)
            .cartesian_product((0..left_cols).chain(out_width.saturating_sub(pad_w)..out_width));

        for (oh, ow) in v_borders.chain(h_borders) {
            let mut acc = E::MIN.splat::<S>();
            let out = &mut out[[oh, ow, ch]];

            for kh in 0..kernel_height {
                // Coordinates are computed in the padded frame first so the subtraction below
                // cannot underflow.
                let ih = oh * stride_height + kh * dilation_height;
                if ih < pad_h || ih >= x_height + pad_h {
                    continue;
                }
                let ih = ih - pad_h;

                for kw in 0..kernel_width {
                    let iw = ow * stride_width + kw * dilation_width;
                    if iw < pad_w || iw >= x_width + pad_w {
                        continue;
                    }
                    let iw = iw - pad_w;
                    // Load a full vector from `x`. In bounds as long as `channels >= ch + lanes`
                    acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });
                }
            }
            // Store a full vector to `out`. In bounds as long as `channels >= ch + lanes`.
            unsafe { vstore_unaligned(out, acc) };
        }
    }
}
/// Scalar fallback that pools a single channel `ch` of one batch item.
///
/// `x` and `out` are indexed as `[height, width, channels]`. For every output position the
/// maximum over all kernel taps is written to `out[[oh, ow, ch]]`; taps that fall into the
/// padding are skipped, and a window without any in-bounds tap yields `E::MIN`.
fn loop_scalar<E: Element + MinMax>(
    x: ArrayView3<'_, E>,
    mut out: ArrayViewMut3<'_, E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ch: usize,
) {
    let (in_rows, in_cols, _) = x.dim();
    let (out_rows, out_cols, _) = out.dim();
    let [k_rows, k_cols] = kernel_size;
    let [row_step, col_step] = stride;
    let [row_dilation, col_dilation] = dilation;
    let [row_pad, col_pad] = padding;

    // Coordinates are computed in the padded frame; these are the ones that map onto real input.
    let valid_rows = row_pad..in_rows + row_pad;
    let valid_cols = col_pad..in_cols + col_pad;

    for out_row in 0..out_rows {
        for out_col in 0..out_cols {
            let mut best = E::MIN;

            for k_row in 0..k_rows {
                let padded_row = out_row * row_step + k_row * row_dilation;
                if !valid_rows.contains(&padded_row) {
                    continue;
                }
                let in_row = padded_row - row_pad;

                for k_col in 0..k_cols {
                    let padded_col = out_col * col_step + k_col * col_dilation;
                    if !valid_cols.contains(&padded_col) {
                        continue;
                    }
                    let in_col = padded_col - col_pad;

                    best = best.max(x[[in_row, in_col, ch]]);
                }
            }

            out[[out_row, out_col, ch]] = best;
        }
    }
}
}

View File

@@ -0,0 +1,10 @@
pub(crate) mod avgpool;
mod base;
pub(crate) mod binary;
pub(crate) mod binary_elemwise;
pub(crate) mod cmp;
pub(crate) mod conv;
pub(crate) mod maxpool;
pub(crate) mod unary;
pub use base::*;

View File

@@ -0,0 +1,234 @@
use core::marker::PhantomData;
use bytemuck::cast;
use macerator::{
Scalar, Simd, VAbs, VBitNot, VRecip, Vector, vload, vload_unaligned, vstore, vstore_unaligned,
};
use ndarray::ArrayD;
use num_traits::Signed;
use seq_macro::seq;
use crate::{NdArrayElement, SharedArray};
use super::should_use_simd;
/// A unary operation from `T` to `Out` that has both a SIMD vector form and a scalar form.
pub trait SimdUnop<T: Scalar, Out: Scalar> {
/// Applies the operation to a whole vector of `T`, producing a vector of `Out`.
fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, Out>;
/// Applies the operation to a single element.
fn apply(input: T) -> Out;
/// Reports whether the operation is accelerated for the SIMD backend `S`.
fn is_accelerated<S: Simd>() -> bool;
}
/// Reciprocal operation for `f32`.
pub struct RecipVec;
impl SimdUnop<f32, f32> for RecipVec {
// Vector form: forwards to the vector's `recip`.
fn apply_vec<S: Simd>(input: Vector<S, f32>) -> Vector<S, f32> {
input.recip()
}
// Scalar form: forwards to `f32::recip`.
fn apply(input: f32) -> f32 {
input.recip()
}
// Acceleration is whatever `VRecip` reports for `f32` on backend `S`.
fn is_accelerated<S: Simd>() -> bool {
<f32 as VRecip>::is_accelerated::<S>()
}
}
/// Absolute-value operation for any signed element type that implements `VAbs`.
pub struct VecAbs;
impl<T: VAbs + Signed> SimdUnop<T, T> for VecAbs {
// Vector form: forwards to the vector's `abs`.
fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, T> {
input.abs()
}
// Scalar form: forwards to the element's `abs`.
fn apply(input: T) -> T {
input.abs()
}
// Acceleration is whatever `VAbs` reports for `T` on backend `S`.
fn is_accelerated<S: Simd>() -> bool {
<T as VAbs>::is_accelerated::<S>()
}
}
/// Bitwise-not operation for any element type that implements `VBitNot`.
pub struct VecBitNot;
impl<T: VBitNot> SimdUnop<T, T> for VecBitNot {
// Vector form: uses the `!` operator on the vector.
fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, T> {
!input
}
// Scalar form: calls `not` on the element.
fn apply(input: T) -> T {
input.not()
}
// Acceleration is whatever `VBitNot` reports for `T` on backend `S`.
fn is_accelerated<S: Simd>() -> bool {
<T as VBitNot>::is_accelerated::<S>()
}
}
/// Reports whether `Op` is accelerated for the SIMD backend `S`.
///
/// The `PhantomData` argument only carries the `T`, `Out` and `Op` type parameters; the caller
/// in this file does not name `S`, which is presumably filled in by `#[macerator::with_simd]`.
#[macerator::with_simd]
fn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: SimdUnop<T, Out>>(
_x: PhantomData<(T, Out, Op)>,
) -> bool {
Op::is_accelerated::<S>()
}
/// Attempts to apply `Op` with SIMD, handing the input back as `Err` when it does not apply.
///
/// SIMD is used only if `should_use_simd` accepts the element count, the array is contiguous in
/// memory, and `Op` is accelerated. `E`/`EOut` are the caller's element types and `T`/`Out` the
/// concrete types `Op` works on; the arrays are reinterpreted between them, so the caller must
/// have established `E == T` and `EOut == Out` (e.g. from the dynamic `DType`).
///
/// When `T` and `Out` are size- and alignment-compatible and the buffer is not shared, the
/// operation runs in place; otherwise a fresh output buffer is allocated.
pub fn try_unary_simd<
    E: NdArrayElement,
    EOut: NdArrayElement,
    T: NdArrayElement + Scalar,
    Out: NdArrayElement + Scalar,
    Op: SimdUnop<T, Out>,
>(
    input: SharedArray<E>,
) -> Result<SharedArray<EOut>, SharedArray<E>> {
    let simd_applies = should_use_simd(input.len())
        && input.as_slice_memory_order().is_some()
        && is_accelerated::<T, Out, Op>(PhantomData);
    if !simd_applies {
        return Err(input);
    }

    // Used to assert traits based on the dynamic `DType`.
    let typed_input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };

    let can_reuse_buffer = size_of::<T>() == size_of::<Out>()
        && align_of::<T>() >= align_of::<Out>()
        && typed_input.is_unique();

    let typed_output = if can_reuse_buffer {
        // SAFETY: size and alignment compatibility were checked just above.
        unsafe { unary_scalar_simd_inplace::<T, Out, Op>(typed_input) }
    } else {
        unary_scalar_simd_owned::<T, Out, Op>(typed_input)
    };

    // Used to assert traits based on the dynamic `DType`.
    Ok(unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(typed_output) })
}
/// Executes the operation in place, reusing the input buffer for the output.
///
/// The input is turned into an owned array, every element is overwritten with the result of
/// `Op`, and the buffer is then reinterpreted as an array of `Out`.
///
/// # Safety
///
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
unsafe fn unary_scalar_simd_inplace<
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: SimdUnop<T, Out>,
>(
input: SharedArray<T>,
) -> SharedArray<Out> {
let mut buffer = input.into_owned();
// Panics if the buffer is not contiguous; `try_unary_simd` checks contiguity before calling.
let slice = buffer.as_slice_memory_order_mut().unwrap();
// This is only called when in and out have the same size, so it's safe
unsafe { unary_slice_inplace::<T, Out, Op>(slice, PhantomData) };
// Buffer has the same elem size and is filled with the operation output, so this is safe
let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buffer) };
out.into_shared()
}
/// Executes the operation into a freshly allocated output array.
///
/// The output is allocated in standard (row-major) layout. The input is therefore read in
/// standard layout as well: reading it in raw memory order would pair element `i` of the
/// input's memory with element `i` of the output's memory, which places results at the wrong
/// logical index whenever the input is contiguous but not row-major (for example a transposed
/// array). Inputs that are already in standard layout are borrowed without copying; other
/// layouts are copied once by `as_standard_layout`.
fn unary_scalar_simd_owned<
    T: NdArrayElement + Scalar,
    Out: NdArrayElement + Scalar,
    Op: SimdUnop<T, Out>,
>(
    input: SharedArray<T>,
) -> SharedArray<Out> {
    // Borrowed when already row-major, otherwise an owned row-major copy.
    let input = input.as_standard_layout();
    // Every element is written by `unary_slice` below before the array is returned.
    let mut out = unsafe { ArrayD::uninit(input.raw_dim()).assume_init() };
    let in_slice = input
        .as_slice()
        .expect("standard-layout array is contiguous");
    let out_slice = out
        .as_slice_mut()
        .expect("freshly allocated array is contiguous");
    unary_slice::<T, Out, Op>(in_slice, out_slice, PhantomData);
    out.into_shared()
}
/// Applies `Op` to every element of `input`, writing the results to `out`.
///
/// Work is done in three passes: chunks of `8 * lanes` elements with eight unrolled vector
/// operations, then single vectors of `lanes` elements, then the remaining elements one at a
/// time with the scalar form. The chunked passes stop at the shorter of the two slices.
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn unary_slice<
'a,
S: Simd,
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: SimdUnop<T, Out>,
>(
input: &'a [T],
out: &'a mut [Out],
_op: PhantomData<Op>,
) where
'a: 'a,
{
let lanes = T::lanes::<S>();
// Pass 1: eight vectors per iteration.
let mut chunks_input = input.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
seq!(N in 0..8 {
// Load one full vector from `input`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
let s~N = Op::apply_vec::<S>(s~N);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
});
}
// Pass 2: one vector per iteration over what pass 1 left behind.
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
// Load one full vector from `input`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
let s0 = Op::apply_vec::<S>(s0);
// Store one full vector to `out`.
// SAFETY: Guaranteed to be in bounds because `len == lanes`
unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };
}
// Pass 3: scalar tail, fewer than `lanes` elements.
for (input, out) in chunks_input
.remainder()
.iter()
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*input)
}
}
/// Executes the operation in place, overwriting each `T` in `buf` with the `Out` result.
///
/// The buffer is split with `align_to_mut` into an unaligned head, a vector-aligned middle and
/// an unaligned tail. Head and tail are processed with the scalar form; the middle is processed
/// eight vectors at a time, then one vector at a time.
///
/// # Safety
///
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
#[macerator::with_simd]
unsafe fn unary_slice_inplace<
'a,
S: Simd,
T: NdArrayElement + Scalar,
Out: NdArrayElement + Scalar,
Op: SimdUnop<T, Out>,
>(
buf: &'a mut [T],
_op: PhantomData<(Out, Op)>,
) where
'a: 'a,
{
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
// Scalar path for the unaligned ends; `cast` stores the `Out` result back into a `T` slot.
for elem in head.iter_mut().chain(tail) {
*elem = cast(Op::apply(*elem));
}
let mut chunks = main.chunks_exact_mut(8);
for elem in chunks.by_ref() {
seq!(N in 0..8 {
// Load a full vector from the aligned portion of the buffer.
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
// always a full vector in bounds.
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
let s~N = Op::apply_vec::<S>(s~N);
// Store a full vector at the same position as the input. Cast is safe because `Out` is
// size and align compatible
unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) };
});
}
// Fewer than eight aligned vectors remain; handle them one at a time.
for elem in chunks.into_remainder() {
// Load a full vector from the aligned portion of the buffer.
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
// always a full vector in bounds.
let s0 = unsafe { vload(elem as *const _ as *const T) };
let s0 = Op::apply_vec::<S>(s0);
// Store a full vector at the same position as the input. Cast is safe because `Out` is
// size and align compatible
unsafe { vstore(elem as *mut _ as *mut Out, s0) };
}
}

View File

@@ -0,0 +1,688 @@
// Language
use alloc::vec::Vec;
use burn_backend::backend::ExecutionError;
use burn_backend::ops::GridSampleOptions;
use burn_backend::tensor::FloatTensor;
use burn_backend::{TensorMetadata, element::cast::ToElement};
// Current crate
use super::{
NdArrayMathOps, NdArrayOps,
matmul::{cross, matmul},
};
use crate::{
NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
};
use crate::{NdArrayDevice, SEED, slice};
use crate::{
SharedArray,
element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
};
use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
// Workspace crates
use crate::rand::get_seeded_rng;
use burn_backend::{Distribution, FloatDType, Scalar};
use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use libm::erf;
/// Rounds to the nearest integer, resolving exact halves towards the even neighbour.
///
/// With `std` this is the standard library's `f64::round_ties_even`.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds to the nearest integer, resolving exact halves towards the even neighbour.
///
/// Without `std` the tie case is emulated: an exact half is halved, rounded and doubled, which
/// always lands on an even integer. Everything else uses ordinary rounding.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        return x.round();
    }
    (x * 0.5).round() * 2.0
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
/// Builds a tensor from `data` via `NdArrayTensor::from_data`; the device argument is unused.
fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
NdArrayTensor::from_data(data)
}
fn float_random(
shape: Shape,
distribution: Distribution,
device: &NdArrayDevice,
) -> FloatTensor<Self> {
let mut seed = SEED.lock().unwrap();
let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
let tensor = Self::float_from_data(
TensorData::random::<E, _, _>(shape, distribution, &mut rng),
device,
);
*seed = Some(rng);
tensor
}
/// Converts the tensor into `TensorData`; this implementation always returns `Ok`.
async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
Ok(tensor.into_data())
}
/// Always returns `NdArrayDevice::Cpu`, regardless of the tensor.
fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
NdArrayDevice::Cpu
}
/// Returns the tensor unchanged; the device argument is ignored.
fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
tensor
}
/// Creates a tensor of the given shape by delegating to `Self::float_zeros` with the same
/// arguments.
fn float_empty(
shape: Shape,
device: &<NdArray<E> as Backend>::Device,
dtype: FloatDType,
) -> FloatTensor<Self> {
Self::float_zeros(shape, device, dtype)
}
/// Adds two tensors by running `NdArrayMathOps::add` through `execute_with_float_dtype!`.
fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
}
/// Adds the scalar `rhs`, converted with `elem()`, via `NdArrayMathOps::add_scalar`.
fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::add_scalar(array, rhs.elem())
})
}
/// Subtracts `rhs` from `lhs` by running `NdArrayMathOps::sub` through
/// `execute_with_float_dtype!`.
fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
}
/// Subtracts the scalar `rhs`, converted with `elem()`, via `NdArrayMathOps::sub_scalar`.
fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::sub_scalar(array, rhs.elem())
})
}
/// Multiplies two tensors by running `NdArrayMathOps::mul` through `execute_with_float_dtype!`.
fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
}
/// Multiplies by the scalar `rhs`, converted with `elem()`, via `NdArrayMathOps::mul_scalar`.
fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::mul_scalar(array, rhs.elem())
})
}
/// Divides `lhs` by `rhs` by running `NdArrayMathOps::div` through `execute_with_float_dtype!`.
fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
}
/// Divides by the scalar `rhs`, converted with `elem()`, via `NdArrayMathOps::div_scalar`.
fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::div_scalar(array, rhs.elem())
})
}
/// Computes the remainder of `lhs` by `rhs` via `NdArrayMathOps::remainder`.
fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
}
/// Computes the remainder by the scalar `rhs`, converted with `elem()`, via
/// `NdArrayMathOps::remainder_scalar`.
fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::remainder_scalar(array, rhs.elem())
})
}
/// Multiplies two tensors with the `matmul` helper imported from the sibling `matmul` module.
fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), matmul)
}
/// Computes the cross product along `dim` with the `cross` helper from the `matmul` module.
fn float_cross(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
dim: usize,
) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
}
/// Computes the reciprocal via `NdArrayMathOps::recip`.
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::recip(array)
})
}
/// Swaps dimensions `dim1` and `dim2` via `NdArrayOps::swap_dims`.
fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::swap_dims(array, dim1, dim2)
})
}
/// Reshapes the tensor to `shape` via `NdArrayOps::reshape`.
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::reshape(array, shape)
})
}
/// Gathers values along `dim` at `indices` via `NdArrayOps::gather`.
///
/// The index tensor is unpacked first with `execute_with_int_dtype!`, then the float tensor
/// with `execute_with_float_dtype!`.
fn float_gather(
dim: usize,
tensor: FloatTensor<Self>,
indices: NdArrayTensor,
) -> FloatTensor<Self> {
execute_with_int_dtype!(
indices,
IntElem,
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::gather(dim, array, idx_array)
})
}
)
}
/// Scatters `value` into `tensor` along `dim` at `indices` via `NdArrayOps::scatter`.
///
/// The index tensor is unpacked first with `execute_with_int_dtype!`, then both float tensors
/// with `execute_with_float_dtype!`.
fn float_scatter_add(
dim: usize,
tensor: FloatTensor<Self>,
indices: NdArrayTensor,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
execute_with_int_dtype!(
indices,
IntElem,
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
dim, tensor, idx_array, value
))
}
)
}
/// Selects entries along `dim` at `indices` via `NdArrayMathOps::select`.
///
/// The index tensor is unpacked first with `execute_with_int_dtype!`, then the float tensor
/// with `execute_with_float_dtype!`.
fn float_select(
tensor: FloatTensor<Self>,
dim: usize,
indices: NdArrayTensor,
) -> FloatTensor<Self> {
execute_with_int_dtype!(
indices,
IntElem,
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::select(array, dim, idx_array)
})
}
)
}
/// Combines `value` into `tensor` along `dim` at `indices` via `NdArrayMathOps::select_assign`.
///
/// The index tensor is unpacked first with `execute_with_int_dtype!`, then both float tensors
/// with `execute_with_float_dtype!`.
fn float_select_add(
tensor: FloatTensor<Self>,
dim: usize,
indices: NdArrayTensor,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
execute_with_int_dtype!(
indices,
IntElem,
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
execute_with_float_dtype!((tensor, value), |tensor, value| {
NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
})
}
)
}
/// Slices the tensor with the crate's `slice!` macro.
fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
slice!(tensor, slices)
}
/// Assigns `value` to the region of `tensor` described by `slices` via
/// `NdArrayOps::slice_assign`.
fn float_slice_assign(
tensor: FloatTensor<Self>,
slices: &[burn_backend::Slice],
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
execute_with_float_dtype!((tensor, value), |tensor, value| {
NdArrayOps::slice_assign(tensor, slices, value)
})
}
/// Applies `NdArrayOps::mask_where` with `mask` read as a boolean array through `mask.bool()`.
fn float_mask_where(
tensor: FloatTensor<Self>,
mask: NdArrayTensor,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
execute_with_float_dtype!((tensor, value), |tensor, value| {
NdArrayOps::mask_where(tensor, mask.bool(), value)
})
}
/// Applies `NdArrayOps::mask_fill` with `mask` read as a boolean array through `mask.bool()`
/// and the fill `value` converted with `elem()`.
fn float_mask_fill(
tensor: FloatTensor<Self>,
mask: NdArrayTensor,
value: Scalar,
) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::mask_fill(array, mask.bool(), value.elem())
})
}
/// Compares two tensors for equality via `NdArrayMathOps::equal`.
fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
}
/// Compares the tensor for equality against the scalar `rhs`, converted with `elem()`, via
/// `NdArrayMathOps::equal_elem`.
fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::equal_elem(array, rhs.elem())
})
}
/// Compares `lhs > rhs` via `NdArrayMathOps::greater`.
fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
}
/// Compares the tensor against the scalar `rhs`, converted with `elem()`, via
/// `NdArrayMathOps::greater_elem`.
fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::greater_elem(array, rhs.elem())
})
}
/// Compares `lhs >= rhs` via `NdArrayMathOps::greater_equal`.
fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
NdArrayMathOps::greater_equal(lhs, rhs)
})
}
/// Compares the tensor against the scalar `rhs`, converted with `elem()`, via
/// `NdArrayMathOps::greater_equal_elem`.
fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::greater_equal_elem(array, rhs.elem())
})
}
/// Compares `lhs < rhs` via `NdArrayMathOps::lower`.
fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
}
/// Compares the tensor against the scalar `rhs`, converted with `elem()`, via
/// `NdArrayMathOps::lower_elem`.
fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::lower_elem(array, rhs.elem())
})
}
/// Compares `lhs <= rhs` via `NdArrayMathOps::lower_equal`.
fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
NdArrayMathOps::lower_equal(lhs, rhs)
})
}
/// Compares the tensor against the scalar `rhs`, converted with `elem()`, via
/// `NdArrayMathOps::lower_equal_elem`.
fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::lower_equal_elem(array, rhs.elem())
})
}
/// Returns the tensor unchanged.
fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor
}
/// Computes the mean of all elements via `NdArrayMathOps::mean_view`.
fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::mean_view(array.view())
})
}
/// Computes the sum of all elements via `NdArrayMathOps::sum_view`.
fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::sum_view(array.view())
})
}
/// Computes the mean along `dim` via `NdArrayMathOps::mean_dim`.
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::mean_dim(array, dim)
})
}
/// Computes the cumulative sum along `dim` via `NdArrayMathOps::cumsum`.
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::cumsum(array, dim)
})
}
/// Computes the cumulative product along `dim` via `NdArrayMathOps::cumprod`.
fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::cumprod(array, dim)
})
}
/// Computes the cumulative minimum along `dim` via `NdArrayMathOps::cummin`.
fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::cummin(array, dim)
})
}
/// Computes the cumulative maximum along `dim` via `NdArrayMathOps::cummax`.
fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::cummax(array, dim)
})
}
/// Computes the sum along `dim` via `NdArrayMathOps::sum_dim`.
fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::sum_dim(array, dim)
})
}
/// Computes the argmax along `dim` via `NdArrayMathOps::argmax_view`, instantiated with the
/// backend's int element type `I`.
fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::argmax_view::<I>(array.view(), dim)
})
}
/// Computes the argmin along `dim` via `NdArrayMathOps::argmin_view`, instantiated with the
/// backend's int element type `I`.
fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::argmin_view::<I>(array.view(), dim)
})
}
/// Applies `exp_elem` to every element.
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
})
}
/// Applies `log_elem` to every element.
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
})
}
/// Computes the product of all elements via `NdArrayMathOps::prod_view`.
fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::prod_view(array.view())
})
}
/// Computes the product along `dim` via `NdArrayMathOps::prod_dim`.
fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::prod_dim(array, dim)
})
}
/// Computes the maximum of all elements via `NdArrayMathOps::max_view`.
fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::max_view(array.view())
})
}
/// Computes the minimum of all elements via `NdArrayMathOps::min_view`.
fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
// Use view() for zero-copy on borrowed storage
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::min_view(array.view())
})
}
/// Applies `log1p_elem` to every element.
fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
})
}
/// Applies `powf_elem` to every element, with the exponent `value` converted by `elem()`.
fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
.into_shared()
})
}
/// Applies `sqrt_elem` to every element.
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
})
}
/// Computes the absolute value via `NdArrayMathOps::abs`.
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::abs(array)
})
}
/// Applies cosine to every element; each value is converted to `f64`, transformed, and
/// converted back with `elem()`.
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
.into_shared()
})
}
/// Applies hyperbolic cosine to every element; each value is converted to `f64`, transformed,
/// and converted back with `elem()`.
fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
.into_shared()
})
}
/// Applies sine to every element; each value is converted to `f64`, transformed, and converted
/// back with `elem()`.
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
.into_shared()
})
}
/// Applies hyperbolic sine to every element; each value is converted to `f64`, transformed,
/// and converted back with `elem()`.
fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
.into_shared()
})
}
/// Applies tangent to every element; each value is converted to `f64`, transformed, and
/// converted back with `elem()`.
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
.into_shared()
})
}
/// Applies hyperbolic tangent to every element; each value is converted to `f64`, transformed,
/// and converted back with `elem()`.
fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
.into_shared()
})
}
/// Applies arccosine to every element; each value is converted to `f64`, transformed, and
/// converted back with `elem()`.
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
.into_shared()
})
}
/// Applies inverse hyperbolic cosine to every element; each value is converted to `f64`,
/// transformed, and converted back with `elem()`.
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
.into_shared()
})
}
/// Applies arcsine to every element; each value is converted to `f64`, transformed, and
/// converted back with `elem()`.
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
.into_shared()
})
}
/// Applies inverse hyperbolic sine to every element; each value is converted to `f64`,
/// transformed, and converted back with `elem()`.
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
.into_shared()
})
}
/// Applies the inverse tangent to every element of `tensor`.
///
/// Each value is converted with `to_f64`, passed through `atan`, and converted back with
/// `elem`. Only `F64`/`F32` tensors are dispatched; other variants hit `unimplemented!`.
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
.into_shared()
})
}
/// Applies the inverse hyperbolic tangent to every element of `tensor`.
///
/// Each value is converted with `to_f64`, passed through `atanh`, and converted back with
/// `elem`. Only `F64`/`F32` tensors are dispatched; other variants hit `unimplemented!`.
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
.into_shared()
})
}
/// Computes `lhs.atan2(rhs)` for each pair of elements.
///
/// The work is delegated to `NdArrayMathOps::elementwise_op`. Both tensors must hold the
/// same float dtype (`F64` or `F32`); a mismatch panics inside the dispatch macro.
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
})
}
/// Rounds every element of `tensor` via `round_ties_even_wrapper`.
///
/// Each value is converted with `to_f64` before rounding and converted back with `elem`.
/// NOTE(review): presumably half-way cases round to the nearest even value, as the helper's
/// name suggests — confirm against `round_ties_even_wrapper`.
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
.into_shared()
})
}
/// Applies `floor` to every element of `tensor`.
///
/// Each value is converted with `to_f64`, floored, and converted back with `elem`.
/// Only `F64`/`F32` tensors are dispatched; other variants hit `unimplemented!`.
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
.into_shared()
})
}
/// Applies `ceil` to every element of `tensor`.
///
/// Each value is converted with `to_f64`, rounded up, and converted back with `elem`.
/// Only `F64`/`F32` tensors are dispatched; other variants hit `unimplemented!`.
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
.into_shared()
})
}
/// Applies `trunc` to every element of `tensor`.
///
/// Each value is converted with `to_f64`, truncated, and converted back with `elem`.
/// Only `F64`/`F32` tensors are dispatched; other variants hit `unimplemented!`.
fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
.into_shared()
})
}
/// Applies the `erf` helper to every element of `tensor`.
///
/// Each value is converted with `to_f64`, passed to `erf`, and converted back with `elem`.
/// Only `F64`/`F32` tensors are dispatched; other variants hit `unimplemented!`.
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array
.mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
.into_shared()
})
}
/// Concatenates `tensors` along `dim` through the `cat_with_dtype!` macro.
///
/// The dtype of the first tensor (`F64` or `F32`) selects the code path; the macro panics
/// if a later tensor has a different dtype. The macro indexes `tensors[0]`, so an empty
/// `tensors` vector panics as well.
fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
cat_with_dtype!(tensors, dim, [F64, F32])
}
/// Clamps `tensor` from below by delegating to `NdArrayMathOps::clamp_min`.
///
/// `min` is converted to the tensor's float element type with `elem` before the call.
fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::clamp_min(array, min.elem())
})
}
/// Clamps `tensor` from above by delegating to `NdArrayMathOps::clamp_max`.
///
/// `max` is converted to the tensor's float element type with `elem` before the call.
fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::clamp_max(array, max.elem())
})
}
/// Clamps `tensor` between `min` and `max` by delegating to `NdArrayMathOps::clamp`.
///
/// Both bounds are converted to the tensor's float element type with `elem` before the call.
fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::clamp(array, min.elem(), max.elem())
})
}
/// Converts a float tensor into a tensor of the backend's integer element type `I`.
///
/// A new array is built with `mapv`, converting every value through `elem::<I>()`.
/// NOTE(review): how fractional or out-of-range values are converted is decided by `elem`
/// and is not visible here.
fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
})
}
/// Computes `lhs.powf(rhs)` for each pair of elements.
///
/// The work is delegated to `NdArrayMathOps::elementwise_op`. Both tensors must hold the
/// same float dtype (`F64` or `F32`); a mismatch panics inside the dispatch macro.
fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
})
}
/// Reorders the dimensions of `tensor` according to `axes` via `NdArrayOps::permute`.
fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::permute(array, axes)
})
}
/// Flips `tensor` along the given `axes` via `NdArrayOps::flip`.
fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::flip(array, axes)
})
}
/// Computes the element-wise sign of `tensor` via `NdArrayMathOps::sign_op`.
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayMathOps::sign_op(array)
})
}
/// Expands `tensor` to `shape` via `NdArrayOps::expand`.
fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::expand(array, shape)
})
}
/// Casts `tensor` to the float dtype `dtype`.
///
/// `dtype` is converted with `into()` and handed to `cast_to_dtype`, which returns the array
/// unchanged when the element type already matches.
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
cast_to_dtype(array, dtype.into())
})
}
/// Samples `tensor` at the locations described by `grid` via the `grid_sample_2d` helper.
///
/// `tensor` and `grid` must hold the same float dtype (`F64` or `F32`); a mismatch panics
/// inside the dispatch macro. `options` is forwarded to the helper unchanged.
fn float_grid_sample_2d(
tensor: FloatTensor<Self>,
grid: FloatTensor<Self>,
options: GridSampleOptions,
) -> FloatTensor<Self> {
execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
tensor, grid, options
))
}
/// Extracts windows from `tensor` by delegating to `NdArrayOps::unfold`.
///
/// `dim`, `size` and `step` are forwarded unchanged.
/// NOTE(review): presumably `size` is the window length and `step` the stride along `dim` —
/// confirm against `NdArrayOps::unfold`.
fn float_unfold(
tensor: FloatTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
NdArrayOps::unfold(array, dim, size, step)
})
}
}

View File

@@ -0,0 +1,13 @@
use crate::{
FloatNdArrayElement, NdArray, NdArrayTensor, SharedArray,
element::{IntNdArrayElement, QuantElement},
};
use burn_backend::ops::TransactionOps;
/// Transaction support for the ndarray backend.
///
/// The impl body is empty, so every `TransactionOps` method uses the trait's default
/// implementation. The `where` bounds require that arrays of the float (`E`) and int (`I`)
/// element types can be converted into `NdArrayTensor`.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> TransactionOps<Self>
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
}

View File

@@ -0,0 +1,76 @@
/// Macro for running a function in parallel.
///
/// Variant compiled when the `multi-threads` feature is enabled: `$func` is called with no
/// arguments inside a `rayon::scope`, and the macro evaluates to the value it returns.
/// `rayon::prelude::*` is imported into the expanded block.
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! run_par {
(
$func:expr
) => {{
use rayon::prelude::*;
#[allow(clippy::redundant_closure_call)]
rayon::scope(|_| $func())
}};
}
/// Sequential fallback of `run_par!`.
///
/// Variant compiled when the `multi-threads` feature is disabled: `$func` is simply called
/// with no arguments on the current thread.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! run_par {
(
$func:expr
) => {{ $func() }};
}
/// Sequential fallback of `iter_par!`.
///
/// Variant compiled when the `multi-threads` feature is disabled: expands to `$iter`
/// unchanged, so the caller iterates sequentially.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
(
$iter:expr
) => {{ $iter }};
}
/// Macro for iterating in parallel.
///
/// Variant compiled when the `multi-threads` feature is enabled: expands to
/// `$iter.into_par_iter()`. The call site must have rayon's parallel-iterator traits in
/// scope (for example by expanding inside `run_par!`).
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
(
$iter:expr
) => {{ $iter.into_par_iter() }};
}
/// Macro for iterating over a slice in parallel.
///
/// Variant compiled when the `multi-threads` feature is enabled: expands to
/// `$slice.into_par_iter()`. The call site must have rayon's parallel-iterator traits in
/// scope (for example by expanding inside `run_par!`).
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! iter_slice_par {
(
$slice:expr
) => {{ $slice.into_par_iter() }};
}
/// Sequential fallback of `iter_slice_par!`.
///
/// Variant compiled when the `multi-threads` feature is disabled: expands to
/// `$slice.iter()`.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! iter_slice_par {
(
$slice:expr
) => {{ $slice.iter() }};
}
/// Macro for iterating over a range in parallel.
///
/// Variant compiled when the `multi-threads` feature is enabled: expands to
/// `($start..$end).into_par_iter()`, i.e. the half-open range with `$end` excluded.
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! iter_range_par {
(
$start:expr, $end:expr
) => {{ ($start..$end).into_par_iter() }};
}
/// Sequential fallback of `iter_range_par!`.
///
/// Variant compiled when the `multi-threads` feature is disabled: expands to the plain
/// half-open range `($start..$end)`.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! iter_range_par {
(
$start:expr, $end:expr
) => {{ ($start..$end) }};
}

View File

@@ -0,0 +1,36 @@
//! Random number generation utilities for burn-ndarray
#[cfg(not(feature = "std"))]
use rand::rngs::SmallRng;
#[cfg(feature = "std")]
use rand::rngs::StdRng;
/// Type alias for the RNG used by burn-ndarray (`StdRng` when the `std` feature is enabled).
#[cfg(feature = "std")]
pub type NdArrayRng = StdRng;
/// Type alias for the RNG used by burn-ndarray (`SmallRng` when the `std` feature is disabled).
#[cfg(not(feature = "std"))]
pub type NdArrayRng = SmallRng;
#[cfg(not(feature = "std"))]
use rand::SeedableRng;
/// Get a seeded random number generator
///
/// For std builds, uses OS entropy.
/// For no_std builds, uses a compile-time random seed.
///
/// This is the `std` variant: it only forwards to `burn_std::rand::get_seeded_rng()`, so the
/// actual seeding strategy is defined there.
#[cfg(feature = "std")]
pub fn get_seeded_rng() -> NdArrayRng {
// Use the standard implementation from burn-std
burn_std::rand::get_seeded_rng()
}
/// Get a seeded random number generator
///
/// For std builds, uses OS entropy.
/// For no_std builds, uses a compile-time random seed.
#[cfg(not(feature = "std"))]
pub fn get_seeded_rng() -> NdArrayRng {
// Use compile-time random seed for no_std
const SEED: u64 = const_random::const_random!(u64);
SmallRng::seed_from_u64(SEED)
}

View File

@@ -0,0 +1,19 @@
use core::cell::UnsafeCell;
/// Wrapper that lets a single `&mut T` be reached through a shared reference.
///
/// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439).
pub(crate) struct UnsafeSharedRef<'a, T> {
    cell: UnsafeCell<&'a mut T>,
}

// SAFETY: this impl only makes the wrapper shareable; it performs no synchronization.
// Every access goes through the `unsafe` method `get`, whose caller takes on the
// responsibility of avoiding conflicting accesses.
unsafe impl<T> Sync for UnsafeSharedRef<'_, T> {}

impl<'a, T> UnsafeSharedRef<'a, T> {
    /// Wraps `data` so that it can later be retrieved with [`Self::get`].
    pub fn new(data: &'a mut T) -> Self {
        let cell = UnsafeCell::new(data);
        Self { cell }
    }

    /// Returns a copy of the wrapped mutable reference.
    ///
    /// # Safety
    ///
    /// Each call hands out another `&'a mut T` to the same value. The caller must make sure
    /// the references obtained this way are never used for conflicting accesses to the same
    /// data (for example by having each user touch a disjoint part of it).
    pub unsafe fn get(&self) -> &'a mut T {
        let slot: *mut &'a mut T = self.cell.get();
        // SAFETY: `slot` points into `self.cell`, which is alive and initialized for as long
        // as `self` is borrowed. Duplicating the `&mut T` is what the caller signed up for
        // under this function's safety contract.
        unsafe { slot.read() }
    }
}

View File

@@ -0,0 +1,514 @@
//! Copy-on-write storage for zero-copy tensor loading.
//!
//! This module provides `NdArrayStorage<E>`, which enables true zero-copy loading
//! from burnpack files. When data is borrowed from external memory (like mmap'd files
//! or static data), it remains zero-copy until a mutating operation is performed,
//! at which point it's copied (copy-on-write semantics).
//!
//! This integrates with ndarray's existing COW patterns - operations that check
//! `is_unique()` will see borrowed data as non-unique, triggering the allocation path.
use alloc::vec::Vec;
use burn_backend::Element;
use burn_std::Bytes;
use core::mem;
use ndarray::{ArcArray, ArrayView, IxDyn};
/// Storage that supports both owned data and borrowed (zero-copy) data.
///
/// # Copy-on-Write Semantics
///
/// - **Borrowed**: Data from external source (burnpack, mmap, static).
/// Reports `is_unique() == false` to trigger copy on mutation.
/// - **Owned**: Standard `ArcArray` with built-in COW via Arc refcount.
///
/// # Invariants
///
/// `from_borrowed` checks alignment and size before building the `Borrowed` variant, and
/// `view()` relies on those checks.
/// NOTE(review): the variant's fields are public like any enum variant, so a value built by
/// hand bypasses that validation — construct borrowed storage through `from_borrowed` only.
///
/// # Example
///
/// ```ignore
/// // Zero-copy load
/// let storage = NdArrayStorage::from_borrowed(bytes, shape);
/// storage.is_unique(); // false - will copy on mutation
///
/// // Read operations use view() - zero-copy
/// let view = storage.view();
///
/// // Mutation converts to owned
/// let owned = storage.into_owned(); // Copies here
/// ```
#[derive(Debug)]
pub enum NdArrayStorage<E: Element> {
/// Borrowed from external source (e.g., burnpack zero-copy load).
/// Keeps `Bytes` alive to ensure the referenced memory is valid.
Borrowed {
/// Source bytes - keeps external memory alive via reference counting
bytes: Bytes,
/// Shape of the tensor
shape: Vec<usize>,
},
/// Standard owned storage with ArcArray COW semantics.
Owned(ArcArray<E, IxDyn>),
}
/// Cloning keeps the storage kind: borrowed stays borrowed, owned stays owned.
impl<E: Element> Clone for NdArrayStorage<E> {
    fn clone(&self) -> Self {
        match self {
            // Cloning an `ArcArray` only creates another handle to the same buffer.
            Self::Owned(arr) => Self::Owned(arr.clone()),
            // The cost of cloning `Bytes` depends on how it was allocated (the tests in
            // this file note that native allocations copy on clone).
            Self::Borrowed { bytes, shape } => {
                let bytes = bytes.clone();
                let shape = shape.clone();
                Self::Borrowed { bytes, shape }
            }
        }
    }
}
impl<E: Element> NdArrayStorage<E> {
    /// Create borrowed storage from external bytes.
    ///
    /// Returns the bytes and shape back on failure (misaligned or too small),
    /// enabling zero-copy even for native allocations by avoiding defensive cloning.
    ///
    /// # Requirements
    ///
    /// The caller must ensure that:
    /// - The `Bytes` contain valid data for the element type `E`
    /// - The data is contiguous in row-major (C) order matching the provided shape
    ///
    /// These requirements are upheld when loading from `TensorData` (burnpack, etc.)
    /// which always stores data contiguously in row-major order.
    ///
    /// # Errors
    ///
    /// Gives `(bytes, shape)` back unchanged when:
    /// - the start of `bytes` is not aligned for `E`,
    /// - the element count or the byte count implied by `shape` overflows `usize`, or
    /// - `bytes` is shorter than the byte count implied by `shape`.
    pub fn from_borrowed(bytes: Bytes, shape: Vec<usize>) -> Result<Self, (Bytes, Vec<usize>)> {
        // Validate alignment: `view()` reinterprets this pointer as `*const E`.
        let ptr = bytes.as_ptr();
        if !(ptr as usize).is_multiple_of(mem::align_of::<E>()) {
            return Err((bytes, shape));
        }

        // Validate size (checked arithmetic so an absurd shape cannot wrap around).
        let required_bytes = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .and_then(|elements| elements.checked_mul(mem::size_of::<E>()));

        match required_bytes {
            Some(required) if bytes.len() >= required => Ok(Self::Borrowed { bytes, shape }),
            _ => Err((bytes, shape)),
        }
    }

    /// Create owned storage from an ArcArray.
    #[inline]
    pub fn from_owned(array: ArcArray<E, IxDyn>) -> Self {
        Self::Owned(array)
    }

    /// Returns whether this storage is uniquely owned and can be mutated in-place.
    ///
    /// - **Borrowed**: Always returns `false` to trigger copy-on-write.
    /// - **Owned**: Delegates to `ArcArray::is_unique()`.
    ///
    /// This integrates with existing SIMD code patterns like:
    /// ```ignore
    /// if tensor.is_unique() {
    ///     // mutate in place
    /// } else {
    ///     // allocate new
    /// }
    /// ```
    #[inline]
    pub fn is_unique(&self) -> bool {
        match self {
            Self::Borrowed { .. } => false, // Force copy path
            Self::Owned(arr) => arr.is_unique(),
        }
    }

    /// Get a read-only view of the data.
    ///
    /// This is zero-copy for both borrowed and owned variants. It is also the only place
    /// where borrowed bytes are reinterpreted as `E`; `into_owned` and `ensure_owned` go
    /// through it.
    #[inline]
    pub fn view(&self) -> ArrayView<'_, E, IxDyn> {
        match self {
            Self::Borrowed { bytes, shape } => {
                let ptr = bytes.as_ptr() as *const E;
                let dim = IxDyn(shape);
                // SAFETY:
                // - `bytes` is kept alive for the lifetime of `self`
                // - Alignment was validated in `from_borrowed`
                // - Size was validated in `from_borrowed`
                // NOTE(review): these guarantees only hold for values built through
                // `from_borrowed`; the variant's fields are public and can be set by hand.
                unsafe { ArrayView::from_shape_ptr(dim, ptr) }
            }
            Self::Owned(arr) => arr.view(),
        }
    }

    /// Convert to owned ArcArray.
    ///
    /// - **Borrowed**: Copies the data into a new ArcArray.
    /// - **Owned**: Returns the wrapped `ArcArray` as-is. No element data is copied here,
    ///   even if other handles still share the buffer.
    pub fn into_owned(self) -> ArcArray<E, IxDyn> {
        match self {
            Self::Owned(arr) => arr,
            // The copy happens in `to_owned()`; `borrowed` (and its bytes) is dropped after.
            borrowed @ Self::Borrowed { .. } => borrowed.view().to_owned().into_shared(),
        }
    }

    /// Convert to shared ArcArray, suitable for returning from operations.
    ///
    /// This is equivalent to `into_owned()` but named for clarity.
    #[inline]
    pub fn into_shared(self) -> ArcArray<E, IxDyn> {
        self.into_owned()
    }

    /// Get the shape of the tensor.
    pub fn shape(&self) -> &[usize] {
        match self {
            Self::Borrowed { shape, .. } => shape,
            Self::Owned(arr) => arr.shape(),
        }
    }

    /// Get the number of dimensions.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape().len()
    }

    /// Get the total number of elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if this is borrowed (zero-copy) storage.
    #[inline]
    pub fn is_borrowed(&self) -> bool {
        matches!(self, Self::Borrowed { .. })
    }

    /// Returns `true` if this is owned storage.
    #[inline]
    pub fn is_owned(&self) -> bool {
        matches!(self, Self::Owned(_))
    }

    /// Ensure owned and return mutable reference to the ArcArray.
    ///
    /// Converts borrowed to owned if necessary (this copies the borrowed data once);
    /// storage that is already owned is left untouched.
    pub fn ensure_owned(&mut self) -> &mut ArcArray<E, IxDyn> {
        if self.is_borrowed() {
            // Copy out of the borrowed bytes first, then replace `self` with the copy.
            let owned = self.view().to_owned().into_shared();
            *self = Self::Owned(owned);
        }
        match self {
            Self::Owned(arr) => arr,
            // The branch above has just replaced any borrowed storage.
            Self::Borrowed { .. } => unreachable!(),
        }
    }
}
/// Convert from ArcArray to NdArrayStorage.
impl<E: Element> From<ArcArray<E, IxDyn>> for NdArrayStorage<E> {
fn from(array: ArcArray<E, IxDyn>) -> Self {
Self::Owned(array)
}
}
// Unit tests for `NdArrayStorage`: construction/validation, uniqueness reporting, and the
// zero-copy guarantees of borrowed storage.
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
use burn_std::Bytes;
// Borrowed storage is never reported as unique.
#[test]
fn test_borrowed_is_not_unique() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
assert!(!storage.is_unique());
assert!(storage.is_borrowed());
}
// Owned storage with a single handle is unique.
#[test]
fn test_owned_unique_when_single_ref() {
let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
let storage = NdArrayStorage::from_owned(array);
assert!(storage.is_unique());
assert!(storage.is_owned());
}
// A live clone of owned storage makes the original non-unique.
#[test]
fn test_owned_not_unique_when_cloned() {
let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
let storage = NdArrayStorage::from_owned(array);
let _clone = storage.clone();
assert!(!storage.is_unique());
}
// `view()` on borrowed storage exposes the elements in row-major order.
#[test]
fn test_view_zero_copy() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
let view = storage.view();
assert_eq!(view[[0, 0]], 1.0);
assert_eq!(view[[1, 1]], 4.0);
}
// `into_owned()` on borrowed storage preserves the element values.
#[test]
fn test_into_owned_copies_borrowed() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
let owned = storage.into_owned();
assert_eq!(owned[[0, 0]], 1.0);
assert_eq!(owned[[1, 1]], 4.0);
}
// `from_borrowed` accepts aligned buffers and rejects misaligned ones.
#[test]
fn test_from_borrowed_validates_alignment() {
use burn_std::AllocationProperty;
// Test 1: Properly aligned data should succeed
let aligned_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let aligned_bytes = Bytes::from_elems(aligned_data);
// Verify test setup - should be 4-byte aligned for f32
assert_eq!(
(aligned_bytes.as_ptr() as usize) % core::mem::align_of::<f32>(),
0,
"Test setup: f32 data should be properly aligned"
);
let result = NdArrayStorage::<f32>::from_borrowed(aligned_bytes, vec![2, 2]);
assert!(
result.is_ok(),
"from_borrowed should succeed for properly aligned data"
);
// Test 2: Misaligned data should fail
// Create a buffer large enough to find a misaligned offset
// (static data placement varies by platform, so we find an offset dynamically)
let buffer: &[u8] = &[0u8; 32];
let shared = bytes::Bytes::from_static(buffer);
let base = shared.as_ptr() as usize;
let align = core::mem::align_of::<f32>();
// Find an offset in 1..align that produces misalignment (at least one must exist)
let misalign_offset = (1..align)
.find(|&off| !(base + off).is_multiple_of(align))
.expect("Should find a misaligned offset");
// 16 bytes = room for four f32 values, so only the alignment check can fail below.
let sliced = shared.slice(misalign_offset..(misalign_offset + 16));
let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other);
// Verify test setup - should NOT be 4-byte aligned
assert_ne!(
(misaligned_bytes.as_ptr() as usize) % align,
0,
"Test setup: sliced data should be misaligned for f32"
);
let result = NdArrayStorage::<f32>::from_borrowed(misaligned_bytes, vec![4]);
assert!(
result.is_err(),
"from_borrowed should return Err for misaligned data"
);
}
// `from_borrowed` rejects buffers shorter than the shape requires.
#[test]
fn test_insufficient_size_returns_err() {
// Create bytes that are too small for the requested shape
let data: Vec<f32> = vec![1.0, 2.0]; // 8 bytes
let bytes = Bytes::from_elems(data);
// Try to create storage for 4 elements (needs 16 bytes)
let result = NdArrayStorage::<f32>::from_borrowed(bytes, vec![4]);
assert!(
result.is_err(),
"from_borrowed should return Err when bytes are too small"
);
}
// ==========================================================================
// Zero-copy hardening tests
// These tests verify the zero-copy guarantee is maintained. If any of these
// fail, it indicates a regression in zero-copy functionality.
// ==========================================================================
#[test]
fn test_zero_copy_native_allocation() {
// CRITICAL: Verify that native allocations (Bytes::from_elems) are zero-copy
// on initial load. The view() must return a pointer to the SAME memory.
//
// Note: Native allocations copy on clone (this is expected), but the initial
// load is still zero-copy, avoiding an extra copy in the common case where
// the tensor is used without cloning.
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let original_ptr = bytes.as_ptr();
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
// Initial load must be zero-copy
let view = storage.view();
let view_ptr = view.as_ptr() as *const u8;
assert_eq!(
original_ptr, view_ptr,
"ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes"
);
// Verify data integrity
assert_eq!(view[[0, 0]], 1.0);
assert_eq!(view[[0, 1]], 2.0);
assert_eq!(view[[1, 0]], 3.0);
assert_eq!(view[[1, 1]], 4.0);
}
#[test]
fn test_zero_copy_shared_bytes_pointer_identity() {
// CRITICAL: Test with SharedBytesAllocationController for true zero-copy.
// This simulates the actual burnpack/mmap loading path.
use burn_std::AllocationProperty;
// Create static-like data using bytes::Bytes
let data: &[u8] = &[
0, 0, 128, 63, // 1.0f32 in little-endian
0, 0, 0, 64, // 2.0f32
0, 0, 64, 64, // 3.0f32
0, 0, 128, 64, // 4.0f32
];
let shared = bytes::Bytes::from_static(data);
let original_ptr = shared.as_ptr();
// Create Bytes with SharedBytesAllocationController
let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
// Verify pointer identity
let view_ptr = storage.view().as_ptr() as *const u8;
assert_eq!(
original_ptr, view_ptr,
"ZERO-COPY REGRESSION: SharedBytes view must point to original static data"
);
// Clone should also share the same memory
let cloned = storage.clone();
let cloned_ptr = cloned.view().as_ptr() as *const u8;
assert_eq!(
original_ptr, cloned_ptr,
"ZERO-COPY REGRESSION: SharedBytes clone must share memory"
);
}
#[test]
fn test_clone_borrowed_stays_borrowed() {
// Verify that cloning borrowed storage produces another borrowed storage.
// Note: The underlying Bytes may or may not share memory depending on
// the allocation controller (native allocations copy, file-backed may share).
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
let cloned = storage.clone();
// Both should still be borrowed (the storage type is preserved)
assert!(
storage.is_borrowed(),
"ZERO-COPY REGRESSION: original should remain borrowed after clone"
);
assert!(
cloned.is_borrowed(),
"ZERO-COPY REGRESSION: clone should be borrowed type"
);
// Both should report not unique (important for COW behavior)
assert!(
!storage.is_unique(),
"ZERO-COPY REGRESSION: original should not be unique after clone"
);
assert!(
!cloned.is_unique(),
"ZERO-COPY REGRESSION: clone should not be unique"
);
// Data should be identical
assert_eq!(storage.view(), cloned.view(), "Clone should have same data");
}
#[test]
fn test_zero_copy_triggers_copy_on_mutation() {
// Verify that into_owned() on borrowed data creates a NEW allocation
// (this is the "copy" in copy-on-write)
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let original_ptr = bytes.as_ptr();
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
assert!(storage.is_borrowed(), "should start as borrowed");
let owned = storage.into_owned();
let owned_ptr = owned.as_ptr() as *const u8;
assert_ne!(
original_ptr, owned_ptr,
"into_owned() on borrowed data MUST allocate new memory (copy-on-write)"
);
}
#[test]
fn test_borrowed_reports_not_unique() {
// CRITICAL: Borrowed storage must report is_unique() == false
// This is what triggers copy-on-write in mutation operations
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let bytes = Bytes::from_elems(data);
let storage =
NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
assert!(
!storage.is_unique(),
"ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \
to trigger copy-on-write. If this is true, mutations will corrupt shared data!"
);
}
}

View File

@@ -0,0 +1,864 @@
use core::mem;
use burn_backend::{
DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
};
use crate::NdArrayStorage;
use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
use alloc::vec::Vec;
use ndarray::{ArcArray, ArrayD, IxDyn};
/// Concrete storage type for ndarray (owned with COW semantics via Arc).
///
/// This is ndarray's `ArcArray` with a dynamic number of dimensions (`IxDyn`).
pub type SharedArray<E> = ArcArray<E, IxDyn>;
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
///
/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.
/// When data is borrowed from external sources (like burnpack files),
/// it remains zero-copy until a mutating operation is performed.
///
/// There is one variant per supported element type; the variant name matches the `DType`
/// reported by `TensorMetadata::dtype`.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
F64(NdArrayStorage<f64>),
F32(NdArrayStorage<f32>),
I64(NdArrayStorage<i64>),
I32(NdArrayStorage<i32>),
I16(NdArrayStorage<i16>),
I8(NdArrayStorage<i8>),
U64(NdArrayStorage<u64>),
U32(NdArrayStorage<u32>),
U16(NdArrayStorage<u16>),
U8(NdArrayStorage<u8>),
Bool(NdArrayStorage<bool>),
}
impl NdArrayTensor {
    /// Unwraps the tensor as a shared `bool` array.
    ///
    /// Borrowed storage is converted to owned storage on the way out.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) when the tensor is not the `Bool` variant.
    pub(crate) fn bool(self) -> SharedArray<bool> {
        match self {
            NdArrayTensor::Bool(storage) => storage.into_shared(),
            other => unimplemented!("Expected bool tensor, got {:?}", other.dtype()),
        }
    }

    /// Tells whether the underlying storage is borrowed (zero-copy) rather than owned.
    #[inline]
    pub fn is_borrowed(&self) -> bool {
        // Every variant wraps a differently-typed storage, hence one arm per variant.
        match self {
            NdArrayTensor::F64(storage) => storage.is_borrowed(),
            NdArrayTensor::F32(storage) => storage.is_borrowed(),
            NdArrayTensor::I64(storage) => storage.is_borrowed(),
            NdArrayTensor::I32(storage) => storage.is_borrowed(),
            NdArrayTensor::I16(storage) => storage.is_borrowed(),
            NdArrayTensor::I8(storage) => storage.is_borrowed(),
            NdArrayTensor::U64(storage) => storage.is_borrowed(),
            NdArrayTensor::U32(storage) => storage.is_borrowed(),
            NdArrayTensor::U16(storage) => storage.is_borrowed(),
            NdArrayTensor::U8(storage) => storage.is_borrowed(),
            NdArrayTensor::Bool(storage) => storage.is_borrowed(),
        }
    }
}
/// Converts `array` to the element type named by `dtype` and wraps it in an `NdArrayTensor`.
///
/// When `dtype` already equals `E1::dtype()` the array is wrapped as-is. Otherwise a new
/// array is built by converting every element with `elem()`. `DType::Flex32` is stored as
/// `f32`.
///
/// # Panics
///
/// Panics with `"Unsupported dtype: …"` for any dtype without a matching arm below.
pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
where
    NdArrayTensor: From<SharedArray<E1>>,
{
    /// Element-by-element conversion into a freshly allocated array.
    fn convert<Src: Element, Dst: Element>(source: SharedArray<Src>) -> SharedArray<Dst> {
        source.mapv(|value| value.elem()).into_shared()
    }

    match dtype {
        // Same element type: nothing to convert.
        same if same == E1::dtype() => array.into(),
        DType::F64 => convert::<E1, f64>(array).into(),
        DType::F32 => convert::<E1, f32>(array).into(),
        DType::Flex32 => convert::<E1, f32>(array).into(),
        DType::I64 => convert::<E1, i64>(array).into(),
        DType::I32 => convert::<E1, i32>(array).into(),
        DType::I16 => convert::<E1, i16>(array).into(),
        DType::I8 => convert::<E1, i8>(array).into(),
        DType::U64 => convert::<E1, u64>(array).into(),
        DType::U32 => convert::<E1, u32>(array).into(),
        DType::U16 => convert::<E1, u16>(array).into(),
        DType::U8 => convert::<E1, u8>(array).into(),
        DType::Bool => convert::<E1, bool>(array).into(),
        dtype => panic!("Unsupported dtype: {dtype:?}"),
    }
}
// Generates, for every `element type => variant` pair, the two conversions into
// `NdArrayTensor`: from an owned `SharedArray` (wrapped with `NdArrayStorage::from_owned`)
// and from an existing `NdArrayStorage` (wrapped directly in the variant).
macro_rules! impl_from {
($($ty: ty => $dtype: ident),*) => {
// From SharedArray (owned) -> NdArrayTensor
$(impl From<SharedArray<$ty>> for NdArrayTensor {
fn from(value: SharedArray<$ty>) -> NdArrayTensor {
NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
}
})*
// From NdArrayStorage -> NdArrayTensor
$(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
NdArrayTensor::$dtype(value)
}
})*
};
}
// One pair per `NdArrayTensor` variant.
impl_from!(
f64 => F64, f32 => F32,
i64 => I64, i32 => I32, i16 => I16, i8 => I8,
u64 => U64, u32 => U32, u16 => U16, u8 => U8,
bool => Bool
);
/// Macro to execute an operation on a given element type.
///
/// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation.
/// The result of the operation is converted back with `.into()`.
///
/// Accepted forms (unary takes a tensor, binary takes a `(lhs, rhs)` tuple):
/// - `(input, op)`: the element type alias is named `E`.
/// - `(input, Alias, op)`: the element type alias is named `Alias`, so `op` can refer to it.
/// - `(input, Alias, op, [Variant => type, ...])`: restricts the dispatch to the listed
///   variants; this is the form the other `execute_with_*_dtype!` macros forward to.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
/// More generally, the binary form panics whenever the two tensors are not the same listed
/// variant, and the unary form hits `unimplemented!` for a variant that is not listed.
#[macro_export]
macro_rules! execute_with_dtype {
(($lhs:expr, $rhs:expr),$element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
// Read both dtypes up front: the tensors are moved into the `match` below.
let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);
let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);
match ($lhs, $rhs) {
$(
($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
#[allow(unused)]
type $element = $ty;
// Convert storage to SharedArray for compatibility with existing operations
$op(lhs.into_shared(), rhs.into_shared()).into()
}
)*
_ => panic!(
"Data type mismatch (lhs: {:?}, rhs: {:?})",
lhs_dtype, rhs_dtype
),
}
}};
// Binary op: type automatically inferred by the compiler
(($lhs:expr, $rhs:expr), $op:expr) => {{
$crate::execute_with_dtype!(($lhs, $rhs), E, $op)
}};
// Binary op: generic type cannot be inferred for an operation
(($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
$crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
F64 => f64, F32 => f32,
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8,
Bool => bool
])
}};
($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
match $tensor {
$(
$crate::NdArrayTensor::$dtype(storage) => {
#[allow(unused)]
type $element = $ty;
// Convert to SharedArray for compatibility with most operations
$op(storage.into_shared()).into()
}
)*
// NOTE(review): `.dtype()` is called unqualified here, so the call site needs
// `TensorMetadata` in scope (the binary arm uses the fully qualified path).
#[allow(unreachable_patterns)]
other => unimplemented!("unsupported dtype: {:?}", other.dtype())
}
}};
// Unary op: type automatically inferred by the compiler
($tensor:expr, $op:expr) => {{
$crate::execute_with_dtype!($tensor, E, $op)
}};
// Unary op: generic type cannot be inferred for an operation
($tensor:expr, $element:ident, $op:expr) => {{
$crate::execute_with_dtype!($tensor, $element, $op, [
F64 => f64, F32 => f32,
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8,
Bool => bool
])
}};
}
/// Macro to execute an operation on a given element type.
/// Only handles float types (`F64` and `F32`).
///
/// Forwards to `execute_with_dtype!` with the variant list restricted to the float variants.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
/// Unary operations on a non-float tensor hit `unimplemented!`.
#[macro_export]
macro_rules! execute_with_float_dtype {
// Binary op: type automatically inferred by the compiler
(($lhs:expr, $rhs:expr), $op:expr) => {{
$crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
}};
// Binary op: generic type cannot be inferred for an operation
(($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
$crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
F64 => f64, F32 => f32
])
}};
// Unary op: type automatically inferred by the compiler
($tensor:expr, $op:expr) => {{
$crate::execute_with_float_dtype!($tensor, E, $op)
}};
// Unary op: generic type cannot be inferred for an operation
($tensor:expr, $element:ident, $op:expr) => {{
$crate::execute_with_dtype!($tensor, $element, $op, [
F64 => f64, F32 => f32
])
}};
}
/// Macro to execute an operation on a given element type.
/// Only handles int types (`I64`, `I32`, `I16`, `I8`, `U64`, `U32`, `U16`, `U8`).
///
/// Forwards to `execute_with_dtype!` with the variant list restricted to the int variants.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on tensors of
/// different integer data types will panic with a data type mismatch.
/// Unary operations on a non-int tensor hit `unimplemented!`.
#[macro_export]
macro_rules! execute_with_int_dtype {
// Binary op: type automatically inferred by the compiler
(($lhs:expr, $rhs:expr), $op:expr) => {{
$crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
}};
// Binary op: generic type cannot be inferred for an operation
(($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
$crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8
])
}};
// Unary op: type automatically inferred by the compiler
($tensor:expr, $op:expr) => {{
$crate::execute_with_int_dtype!($tensor, E, $op)
}};
// Unary op: generic type cannot be inferred for an operation
($tensor:expr, $element:ident, $op:expr) => {{
$crate::execute_with_dtype!($tensor, $element, $op, [
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8
])
}};
}
/// Macro to execute an operation on a given element type.
/// Only handles numeric types (every float and int variant, but not `Bool`).
///
/// Forwards to `execute_with_dtype!` with the variant list restricted to the numeric variants.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on tensors of
/// different data types will panic with a data type mismatch.
/// Unary operations on a `Bool` tensor hit `unimplemented!`.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
// Binary op: type automatically inferred by the compiler
(($lhs:expr, $rhs:expr), $op:expr) => {{
$crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
}};
// Binary op: generic type cannot be inferred for an operation
(($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
$crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
F64 => f64, F32 => f32,
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8
])
}};
// Unary op: type automatically inferred by the compiler
($tensor:expr, $op:expr) => {{
$crate::execute_with_numeric_dtype!($tensor, E, $op)
}};
// Unary op: generic type cannot be inferred for an operation
($tensor:expr, $element:ident, $op:expr) => {{
$crate::execute_with_dtype!($tensor, $element, $op, [
F64 => f64, F32 => f32,
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8
])
}};
}
/// Macro to execute a cat operation on a given set of element types.
///
/// Uses zero-copy views from storage for concatenation.
///
/// The variant of the first tensor selects the element type; a view of every tensor is
/// collected and passed to `NdArrayOps::concatenate` together with `$dim`.
///
/// NOTE(review): unlike `execute_with_dtype!`, this macro names `NdArrayTensor` and
/// `NdArrayOps` without `$crate::` and calls `.dtype()` unqualified, so the call site must
/// have `NdArrayTensor`, `NdArrayOps` and `TensorMetadata` in scope.
///
/// # Panics
/// - When `$tensors` is empty (the macro indexes `$tensors[0]`).
/// - When the first tensor's variant is not in the provided list.
/// - When a later tensor has a different variant than the first one: since there is no
///   automatic type cast at this time, mixing data types panics with a mismatch message.
#[macro_export]
macro_rules! cat_with_dtype {
($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
match &$tensors[0] {
$(NdArrayTensor::$dtype(_) => {
let tensors = $tensors
.iter()
.map(|t| {
if let NdArrayTensor::$dtype(storage) = t {
// Use storage.view() for zero-copy access
storage.view()
} else {
panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
}
})
.collect::<Vec<_>>();
NdArrayOps::concatenate(&tensors, $dim).into()
})*
_ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
}
};
}
impl TensorMetadata for NdArrayTensor {
fn dtype(&self) -> DType {
match self {
NdArrayTensor::F64(_) => DType::F64,
NdArrayTensor::F32(_) => DType::F32,
NdArrayTensor::I64(_) => DType::I64,
NdArrayTensor::I32(_) => DType::I32,
NdArrayTensor::I16(_) => DType::I16,
NdArrayTensor::I8(_) => DType::I8,
NdArrayTensor::U64(_) => DType::U64,
NdArrayTensor::U32(_) => DType::U32,
NdArrayTensor::U16(_) => DType::U16,
NdArrayTensor::U8(_) => DType::U8,
NdArrayTensor::Bool(_) => DType::Bool,
}
}
fn shape(&self) -> Shape {
// Use storage's shape method (works for both borrowed and owned)
macro_rules! get_shape {
($($variant:ident),*) => {
match self {
$(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
}
};
}
get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
}
fn rank(&self) -> usize {
self.shape().num_dims()
}
}
/// Crate-internal helpers for querying a shape-like value.
///
/// Every method takes `self` by value. The implementation for `&[usize]` in this file
/// shows the intended use: the receiver is a cheap, copyable view of the dimensions.
pub(crate) trait ShapeOps {
/// Returns the number of dimensions (the length of the `&[usize]` in that implementation).
fn num_dims(self) -> usize;
/// Returns the total element count (the product of all dimensions in the `&[usize]`
/// implementation, which is `1` for an empty slice).
fn num_elements(self) -> usize;
/// Returns the dimensions as a fixed-size array of length `N`.
///
/// The `&[usize]` implementation unwraps the slice-to-array conversion, so it panics
/// when the number of dimensions differs from `N`.
fn dims<const N: usize>(self) -> [usize; N];
/// Converts the value into a `Shape`.
fn into_shape(self) -> Shape;
}
impl ShapeOps for &[usize] {
    /// The rank is simply the number of entries in the slice.
    fn num_dims(self) -> usize {
        let rank = self.len();
        rank
    }

    /// Multiplies all dimensions together; an empty slice yields `1`.
    fn num_elements(self) -> usize {
        self.iter().copied().product()
    }

    /// Copies the dimensions into a `[usize; N]`.
    ///
    /// # Panics
    /// Panics if the slice does not contain exactly `N` entries.
    fn dims<const N: usize>(self) -> [usize; N] {
        let fixed: Result<[usize; N], _> = self.try_into();
        fixed.unwrap()
    }

    /// Builds a `Shape` from the slice.
    fn into_shape(self) -> Shape {
        self.into()
    }
}
mod utils {
use burn_std::tensor::is_contiguous;
use super::*;
impl NdArrayTensor {
pub(crate) fn into_data(self) -> TensorData {
let shape = self.shape();
let contiguous = self.is_contiguous();
fn inner<E: Element>(
shape: Shape,
is_contiguous: bool,
array: ArcArray<E, IxDyn>,
) -> TensorData {
let vec = if is_contiguous {
match array.try_into_owned_nocopy() {
Ok(owned) => {
let (mut vec, offset) = owned.into_raw_vec_and_offset();
if let Some(offset) = offset {
vec.drain(..offset);
}
if vec.len() > shape.num_elements() {
vec.drain(shape.num_elements()..vec.len());
}
vec
}
Err(array) => array.into_iter().collect(),
}
} else {
array.into_iter().collect()
};
TensorData::new(vec, shape)
}
// Convert storage to owned array before extracting data
execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
}
pub(crate) fn is_contiguous(&self) -> bool {
// For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)
// For owned data, we check the strides
macro_rules! check_contiguous {
($($variant:ident),*) => {
match self {
$(NdArrayTensor::$variant(storage) => {
match storage {
NdArrayStorage::Borrowed { .. } => {
// Borrowed storage requires contiguous row-major data
// (see NdArrayStorage::from_borrowed documentation)
true
}
NdArrayStorage::Owned(array) => {
let shape = array.shape();
let mut strides = Vec::with_capacity(array.strides().len());
for &stride in array.strides() {
if stride <= 0 {
return false;
}
strides.push(stride as usize);
}
is_contiguous(shape, &strides)
}
}
})*
}
};
}
check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
}
}
}
/// Converts a slice of usize to a typed dimension.
///
/// Expands to a block that copies `$dims[0]` through `$dims[$n - 1]` into a fixed-size
/// array and wraps it in `Dim`, producing a value of type `Dim<[usize; $n]>`.
///
/// The third argument must be the literal token `justdim`; this is the only form the macro
/// accepts. `Dim` is referenced unqualified, so it must be in scope where the macro is
/// invoked.
///
/// # Panics
/// `$dims` is indexed with every `i` in `0..$n`. For slice-like inputs the expansion
/// therefore panics when `$dims` holds fewer than `$n` entries.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
(
$n:expr,
$dims:expr,
justdim
) => {{
// Zero-initialised scratch array; its element type is pinned to `usize` by the
// `Dim<[usize; $n]>` annotation below.
let mut dims = [0; $n];
for i in 0..$n {
dims[i] = $dims[i];
}
let dim: Dim<[usize; $n]> = Dim(dims);
dim
}};
}
/// Reshapes an array into a tensor.
///
/// Two forms are accepted:
/// - `reshape!(ty T, n N, shape S, array A)`: builds a `Dim<[usize; N]>` from `S` via
///   `to_typed_dims!`, reshapes `A` with `to_shape`, converts the result into a shared
///   array and finally erases the static rank with `into_dyn()`.
/// - `reshape!(ty T, shape S, array A, d D)`: matches the runtime value `D` against
///   `1` through `6` and forwards to the first form with the matching literal rank.
///
/// The `ty` argument is accepted and forwarded, but the first form does not use it in its
/// expansion.
///
/// # Panics
/// - If `to_shape` reports an error (with an explicit message for standard-layout input,
///   through `unwrap()` otherwise).
/// - If `D` is not one of `1..=6` in the second form.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
(
ty $ty:ty,
n $n:expr,
shape $shape:expr,
array $array:expr
) => {{
let dim = $crate::to_typed_dims!($n, $shape, justdim);
let array = match $array.is_standard_layout() {
// Standard layout: the `to_shape` result is shared as-is.
true => {
match $array.to_shape(dim) {
Ok(val) => val.into_shared(),
Err(err) => {
core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
}
}
},
// Any other layout: the reshaped array is additionally passed through
// `as_standard_layout()` before it is shared.
false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
};
array.into_dyn()
}};
(
ty $ty:ty,
shape $shape:expr,
array $array:expr,
d $D:expr
) => {{
// The rank has to be a literal for `to_typed_dims!`, hence one arm per supported rank.
match $D {
1 => reshape!(ty $ty, n 1, shape $shape, array $array),
2 => reshape!(ty $ty, n 2, shape $shape, array $array),
3 => reshape!(ty $ty, n 3, shape $shape, array $array),
4 => reshape!(ty $ty, n 4, shape $shape, array $array),
5 => reshape!(ty $ty, n 5, shape $shape, array $array),
6 => reshape!(ty $ty, n 6, shape $shape, array $array),
_ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
}
}};
}
/// Slice a tensor.
///
/// Dispatches on the `NdArrayTensor` variant of `$tensor`, passes the storage's `view()`
/// together with `$slices` to `NdArrayOps::slice` and converts the result with `.into()`.
///
/// Two forms are accepted:
/// - `slice!(tensor, slices)`: handles every variant (`F64`, `F32`, `I64`, `I32`, `I16`,
///   `I8`, `U64`, `U32`, `U16`, `U8`, `Bool`).
/// - `slice!(tensor, slices, V1, V2, ...)`: handles only the listed variants; the generated
///   `match` has no fallback arm, so the list must cover every variant of `$tensor`'s type.
///
/// `NdArrayTensor` and `NdArrayOps` are referenced unqualified and therefore have to be in
/// scope where the macro is invoked.
#[macro_export]
macro_rules! slice {
    ($tensor:expr, $slices:expr) => {
        // Recurse through `$crate::` so the shorthand form also resolves when the macro is
        // reached by path (e.g. from another crate) and `slice!` itself is not in scope.
        $crate::slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
    };
    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
        match $tensor {
            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
        }
    };
}
impl NdArrayTensor {
/// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
///
/// This method attempts zero-copy loading when possible. If the data has properly
/// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,
/// it falls back to copying the data.
///
/// Zero-copy loading works when:
/// - The data's bytes are properly aligned for the element type
/// - The bytes can be borrowed (e.g., from mmap'd file or static data)
pub fn from_data(data: TensorData) -> NdArrayTensor {
// Try borrowed storage first, fall back to owned if not possible
match Self::try_from_data_borrowed(data) {
Ok(tensor) => tensor,
Err(data) => Self::from_data_owned(data),
}
}
/// Try to create a tensor with borrowed storage (zero-copy).
///
/// Takes ownership of TensorData and returns it back on failure.
/// No cloning occurs - bytes are moved into storage or returned on failure.
///
/// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).
fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
let TensorData {
bytes,
shape,
dtype,
} = data;
macro_rules! try_borrow {
($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
Err((bytes, shape)) => (bytes, shape),
}
};
}
// Try to create borrowed storage; get bytes back on failure
let (bytes, shape) = match dtype {
DType::F64 => try_borrow!(f64, F64, bytes, shape),
DType::F32 => try_borrow!(f32, F32, bytes, shape),
DType::I64 => try_borrow!(i64, I64, bytes, shape),
DType::I32 => try_borrow!(i32, I32, bytes, shape),
DType::I16 => try_borrow!(i16, I16, bytes, shape),
DType::I8 => try_borrow!(i8, I8, bytes, shape),
DType::U64 => try_borrow!(u64, U64, bytes, shape),
DType::U32 => try_borrow!(u32, U32, bytes, shape),
DType::U16 => try_borrow!(u16, U16, bytes, shape),
DType::U8 => try_borrow!(u8, U8, bytes, shape),
DType::Bool => try_borrow!(bool, Bool, bytes, shape),
_ => (bytes, shape), // QFloat not supported for zero-copy
};
Err(TensorData {
bytes,
shape,
dtype,
})
}
/// Create a tensor with owned storage.
///
/// This may or may not copy data depending on whether the underlying bytes
/// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,
/// no copy occurs; otherwise data is copied to a new allocation.
fn from_data_owned(mut data: TensorData) -> NdArrayTensor {
let shape = mem::take(&mut data.shape);
macro_rules! execute {
($data: expr, [$($dtype: ident => $ty: ty),*]) => {
match $data.dtype {
$(DType::$dtype => {
match data.into_vec::<$ty>() {
// Safety: TensorData checks shape validity on creation
Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
}.into()
},)*
other => unimplemented!("Unsupported dtype {other:?}"),
}
};
}
execute!(data, [
F64 => f64, F32 => f32,
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
U64 => u64, U32 => u32, U16 => u16, U8 => u8,
Bool => bool
])
}
}
/// A quantized tensor for the ndarray backend.
///
/// Bundles the stored tensor with the scheme and the parameters that `strategy()` uses to
/// build the matching `QuantizationStrategy`.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
/// The quantized tensor.
pub qtensor: NdArrayTensor,
/// The quantization scheme.
pub scheme: QuantScheme,
/// The quantization parameters.
///
/// `strategy()` reads only the first entry for per-tensor schemes (so the vector must not
/// be empty there) and maps over every entry for per-block schemes.
pub qparams: Vec<QParams<f32>>,
}
impl NdArrayQTensor {
/// Returns the quantization strategy, including quantization parameters, for the given tensor.
pub fn strategy(&self) -> QuantizationStrategy {
match self.scheme {
QuantScheme {
level: QuantLevel::Tensor,
mode: QuantMode::Symmetric,
value:
QuantValue::Q8F
| QuantValue::Q8S
| QuantValue::E4M3
| QuantValue::E5M2
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::E2M1
| QuantValue::Q2F
| QuantValue::Q2S,
..
} => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
self.qparams[0].scales,
self.scheme.value,
)),
QuantScheme {
level: QuantLevel::Block(block_size),
mode: QuantMode::Symmetric,
value:
QuantValue::Q8F
| QuantValue::Q8S
| QuantValue::E4M3
| QuantValue::E5M2
| QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::E2M1
| QuantValue::Q2F
| QuantValue::Q2S,
..
} => QuantizationStrategy::PerBlockSymmetric(
self.qparams
.iter()
.map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
.collect(),
block_size,
),
}
}
}
impl QTensorPrimitive for NdArrayQTensor {
    /// Borrows the scheme stored alongside the quantized values.
    fn scheme(&self) -> &QuantScheme {
        let Self { scheme, .. } = self;
        scheme
    }

    /// The default scheme with its store switched to `QuantStore::Native`.
    fn default_scheme() -> QuantScheme {
        let native_store = burn_backend::quantization::QuantStore::Native;
        QuantScheme::default().with_store(native_store)
    }
}
impl TensorMetadata for NdArrayQTensor {
    /// Always `DType::QFloat`, carrying this tensor's scheme.
    fn dtype(&self) -> DType {
        let scheme = self.scheme;
        DType::QFloat(scheme)
    }

    /// The shape of the wrapped quantized tensor.
    fn shape(&self) -> Shape {
        let inner = &self.qtensor;
        inner.shape()
    }

    /// The number of dimensions of [`Self::shape`].
    fn rank(&self) -> usize {
        let shape = self.shape();
        shape.num_dims()
    }
}
#[cfg(test)]
mod tests {
    use crate::NdArray;
    use alloc::vec;

    use super::*;
    use burn_backend::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };
    use burn_std::rand::get_seeded_rng;

    /// Generates random `f32` data of the given shape and checks that it survives a
    /// `from_data` / `into_data` round trip unchanged.
    fn assert_f32_round_trip(shape: Shape) {
        let data_expected =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut get_seeded_rng());

        let data_actual = NdArrayTensor::from_data(data_expected.clone()).into_data();

        assert_eq!(data_expected, data_actual);
    }

    /// Builds a 2x2 `f32` tensor (`1.0..=4.0`, row-major) from `Bytes`-backed tensor data.
    fn tensor_from_f32_bytes() -> NdArrayTensor {
        use burn_std::Bytes;

        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let tensor_data =
            TensorData::from_bytes(Bytes::from_elems(values), Shape::new([2, 2]), DType::F32);

        NdArrayTensor::from_data(tensor_data)
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_f32_round_trip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_f32_round_trip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_f32_round_trip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_f32_round_trip(Shape::new([2, 3, 4, 2]));
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;

        let device = Default::default();
        let scale: f32 = 0.009_019_608;
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scales = B::float_from_data(TensorData::from([scale]), &device);
        let qtensor: NdArrayQTensor =
            B::quantize(tensor, &scheme, QuantizationParametersPrimitive { scales });

        let expected = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            scale,
            QuantValue::Q8S,
        ));

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(qtensor.strategy(), expected);
    }

    // ==========================================================================
    // Zero-copy integration tests
    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.
    // ==========================================================================

    #[test]
    fn zero_copy_creates_borrowed_storage() {
        // Verify that from_data creates borrowed storage when possible.
        // Note: For native allocations, Bytes::clone() copies data internally,
        // but the storage type (Borrowed) is preserved, which is important for
        // the is_unique() behavior that triggers copy-on-write.
        let tensor = tensor_from_f32_bytes();

        match &tensor {
            NdArrayTensor::F32(storage) => {
                assert!(
                    storage.is_borrowed(),
                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
                     for properly aligned TensorData with Bytes"
                );
                assert!(
                    !storage.is_unique(),
                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
                );
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_data_integrity() {
        // Verify data is correctly accessible through borrowed storage
        let tensor = tensor_from_f32_bytes();

        match &tensor {
            NdArrayTensor::F32(storage) => {
                let view = storage.view();
                let expected: [([usize; 2], f32); 4] = [
                    ([0, 0], 1.0),
                    ([0, 1], 2.0),
                    ([1, 0], 3.0),
                    ([1, 1], 4.0),
                ];
                for (index, value) in expected {
                    assert_eq!(view[index], value);
                }
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_fallback_when_bytes_owned() {
        // When TensorData owns bytes exclusively, it may use the copy path
        // This is expected behavior - verify it still works correctly
        let original = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);

        let round_tripped = NdArrayTensor::from_data(original.clone()).into_data();

        assert_eq!(original, round_tripped, "Data should round-trip correctly");
    }
}