feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,221 @@
|
||||
use crate::rand::NdArrayRng;
|
||||
use crate::{NdArrayQTensor, NdArrayTensor};
|
||||
use crate::{
|
||||
SharedArray,
|
||||
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
};
|
||||
use alloc::string::String;
|
||||
use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
|
||||
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
|
||||
use burn_backend::{Backend, DType, DeviceId, DeviceOps};
|
||||
use burn_ir::{BackendIr, HandleKind, TensorHandle};
|
||||
use burn_std::stub::Mutex;
|
||||
use core::marker::PhantomData;
|
||||
use rand::SeedableRng;
|
||||
|
||||
/// Shared RNG slot for the ndarray backend.
///
/// Starts out as `None`; `Backend::seed` stores a freshly seeded [`NdArrayRng`] here.
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
|
||||
|
||||
/// Device handle used by the ndarray backend.
///
/// Only one variant exists, so every value of this type denotes the CPU.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum NdArrayDevice {
    /// The CPU device; this is also the [`Default`] value.
    #[default]
    Cpu,
}
|
||||
|
||||
// Empty impl: no items of `DeviceOps` are defined or overridden here.
impl DeviceOps for NdArrayDevice {}
|
||||
|
||||
impl burn_backend::Device for NdArrayDevice {
|
||||
fn from_id(_device_id: DeviceId) -> Self {
|
||||
Self::Cpu
|
||||
}
|
||||
|
||||
fn to_id(&self) -> DeviceId {
|
||||
DeviceId {
|
||||
type_id: 0,
|
||||
index_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn device_count(_type_id: u16) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
|
||||
///
|
||||
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
|
||||
/// `wasm`, `arm`, and `x86`.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
pub struct NdArray<E = f32, I = i64, Q = i8>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
_e: PhantomData<E>,
|
||||
_i: PhantomData<I>,
|
||||
_q: PhantomData<Q>,
|
||||
}
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
type Device = NdArrayDevice;
|
||||
|
||||
type FloatTensorPrimitive = NdArrayTensor;
|
||||
type FloatElem = E;
|
||||
|
||||
type IntTensorPrimitive = NdArrayTensor;
|
||||
type IntElem = I;
|
||||
|
||||
type BoolTensorPrimitive = NdArrayTensor;
|
||||
type BoolElem = bool;
|
||||
|
||||
type QuantizedTensorPrimitive = NdArrayQTensor;
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn name(_device: &Self::Device) -> String {
|
||||
String::from("ndarray")
|
||||
}
|
||||
|
||||
fn seed(_device: &Self::Device, seed: u64) {
|
||||
let rng = NdArrayRng::seed_from_u64(seed);
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
*seed = Some(rng);
|
||||
}
|
||||
|
||||
fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
|
||||
match dtype {
|
||||
DType::F64
|
||||
| DType::F32
|
||||
| DType::Flex32
|
||||
| DType::I64
|
||||
| DType::I32
|
||||
| DType::I16
|
||||
| DType::I8
|
||||
| DType::U64
|
||||
| DType::U32
|
||||
| DType::U16
|
||||
| DType::U8
|
||||
| DType::Bool => burn_backend::DTypeUsage::general(),
|
||||
DType::F16 | DType::BF16 => burn_backend::DTypeUsageSet::empty(),
|
||||
DType::QFloat(scheme) => {
|
||||
match scheme {
|
||||
QuantScheme {
|
||||
level: QuantLevel::Tensor | QuantLevel::Block(_),
|
||||
mode: QuantMode::Symmetric,
|
||||
#[cfg(not(feature = "export_tests"))]
|
||||
value: QuantValue::Q8F | QuantValue::Q8S,
|
||||
// For tests, "native" sub-byte quant serves as a reference for value equality.
|
||||
// Values are stored as i8 regardless.
|
||||
#[cfg(feature = "export_tests")]
|
||||
value:
|
||||
QuantValue::Q8F
|
||||
| QuantValue::Q8S
|
||||
| QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S,
|
||||
store: QuantStore::Native,
|
||||
..
|
||||
} => burn_backend::DTypeUsage::general(),
|
||||
_scheme => burn_backend::DTypeUsageSet::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
type Handle = HandleKind<Self>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Float(handle) => handle,
|
||||
_ => panic!("Expected float handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Int(handle) => handle,
|
||||
_ => panic!("Expected int handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Bool(handle) => handle,
|
||||
_ => panic!("Expected bool handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Quantized(handle) => handle,
|
||||
_ => panic!("Expected quantized handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Float(tensor)
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Int(tensor)
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Bool(tensor)
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Quantized(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::QTensorPrimitive;

    /// Checks which dtypes the backend reports as supported and which it rejects.
    #[test]
    fn should_support_dtypes() {
        type B = NdArray<f32>;
        let device = Default::default();

        let supported = [
            DType::F64,
            DType::F32,
            DType::Flex32,
            DType::I64,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::U64,
            DType::U32,
            DType::U16,
            DType::U8,
            DType::Bool,
            DType::QFloat(NdArrayQTensor::default_scheme()),
        ];
        for dtype in supported {
            assert!(B::supports_dtype(&device, dtype));
        }

        let unsupported = [
            DType::F16,
            DType::BF16,
            // QuantStore::U32 not supported
            DType::QFloat(QuantScheme::default()),
        ];
        for dtype in unsupported {
            assert!(!B::supports_dtype(&device, dtype));
        }
    }
}
|
||||
@@ -0,0 +1,207 @@
|
||||
use burn_backend::Element;
|
||||
use num_traits::Signed;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float;
|
||||
|
||||
use num_traits::Pow;
|
||||
|
||||
use libm::{log1p, log1pf};
|
||||
|
||||
/// A float element for ndarray backend.
|
||||
pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
}
|
||||
|
||||
/// An int element for ndarray backend.
|
||||
pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd<Self> {}
|
||||
|
||||
/// A general element for ndarray backend.
|
||||
pub trait NdArrayElement:
|
||||
Element
|
||||
+ ndarray::LinalgScalar
|
||||
+ ndarray::ScalarOperand
|
||||
+ ExpElement
|
||||
+ AddAssignElement
|
||||
+ num_traits::FromPrimitive
|
||||
+ core::ops::AddAssign
|
||||
+ core::cmp::PartialEq
|
||||
+ core::ops::Rem<Output = Self>
|
||||
{
|
||||
}
|
||||
|
||||
/// An element for the ndarray backend that supports exponential-style operations.
pub trait ExpElement {
    /// Exponential of the value.
    fn exp_elem(self) -> Self;

    /// Natural logarithm of the value.
    fn log_elem(self) -> Self;

    /// `log1p` of the value.
    fn log1p_elem(self) -> Self;

    /// Raises the value to the float power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Raises the value to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Square root of the value.
    fn sqrt_elem(self) -> Self;

    /// Absolute value.
    fn abs_elem(self) -> Self;
}
|
||||
|
||||
/// Addition-assignment operator for ndarray elements.
pub trait AddAssignElement<Rhs = Self> {
    /// Adds `rhs` into `self` in place.
    ///
    /// For `bool`, this corresponds to logical OR assignment.
    fn add_assign(&mut self, rhs: Rhs);
}
|
||||
|
||||
/// Every [`NdArrayElement`] adds in place via its `core::ops::AddAssign` implementation.
impl<E> AddAssignElement for E
where
    E: NdArrayElement,
{
    fn add_assign(&mut self, rhs: Self) {
        // Explicit path so this dispatches to the operator trait, not to this method.
        core::ops::AddAssign::add_assign(self, rhs);
    }
}
|
||||
|
||||
impl AddAssignElement for bool {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self || rhs; // logical OR for bool
|
||||
}
|
||||
}
|
||||
|
||||
/// A quantized element for the ndarray backend.
|
||||
pub trait QuantElement: NdArrayElement {}
|
||||
|
||||
impl QuantElement for i8 {}
|
||||
|
||||
impl FloatNdArrayElement for f64 {}
|
||||
impl FloatNdArrayElement for f32 {}
|
||||
|
||||
impl IntNdArrayElement for i64 {}
|
||||
impl IntNdArrayElement for i32 {}
|
||||
impl IntNdArrayElement for i16 {}
|
||||
impl IntNdArrayElement for i8 {}
|
||||
|
||||
impl IntNdArrayElement for u64 {}
|
||||
impl IntNdArrayElement for u32 {}
|
||||
impl IntNdArrayElement for u16 {}
|
||||
impl IntNdArrayElement for u8 {}
|
||||
|
||||
/// Implements [`NdArrayElement`] and [`ExpElement`] for a float type.
///
/// `$log1p_fn` is the function used for `log1p_elem` on that type.
macro_rules! make_float {
    (
        $float:ty,
        $log1p_fn:expr
    ) => {
        impl NdArrayElement for $float {}

        #[allow(clippy::cast_abs_to_unsigned)]
        impl ExpElement for $float {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                self.exp()
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                self.ln()
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                $log1p_fn(self)
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                self.pow(value)
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std` the inherent `powi` is used; otherwise fall back to `powf_elem`.
                #[cfg(feature = "std")]
                let raised = self.powi(value);

                #[cfg(not(feature = "std"))]
                let raised = Self::powf_elem(self, value as f32);

                raised
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                self.sqrt()
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                self.abs()
            }
        }
    };
}
|
||||
/// Implements [`NdArrayElement`] and [`ExpElement`] for an integer type.
///
/// Transcendental operations go through `f32` and are cast back to the integer type.
/// `$abs_fn` is the function used for `abs_elem`.
macro_rules! make_int {
    (
        $int:ty,
        $abs_fn:expr
    ) => {
        impl NdArrayElement for $int {}

        #[allow(clippy::cast_abs_to_unsigned)]
        impl ExpElement for $int {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                let as_float = self as f32;
                as_float.exp() as $int
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                let as_float = self as f32;
                as_float.ln() as $int
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                let as_float = self as f32;
                log1pf(as_float) as $int
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                let as_float = self as f32;
                as_float.pow(value) as $int
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std` use `f32::powi`; otherwise fall back to `powf_elem`.
                #[cfg(feature = "std")]
                let raised = f32::powi(self as f32, value) as $int;

                #[cfg(not(feature = "std"))]
                let raised = Self::powf_elem(self, value as f32);

                raised
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                let as_float = self as f32;
                as_float.sqrt() as $int
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                $abs_fn(self)
            }
        }
    };
}
|
||||
|
||||
make_float!(f64, log1p);
|
||||
make_float!(f32, log1pf);
|
||||
|
||||
make_int!(i64, i64::wrapping_abs);
|
||||
make_int!(i32, i32::wrapping_abs);
|
||||
make_int!(i16, i16::wrapping_abs);
|
||||
make_int!(i8, i8::wrapping_abs);
|
||||
make_int!(u64, |x| x);
|
||||
make_int!(u32, |x| x);
|
||||
make_int!(u16, |x| x);
|
||||
make_int!(u8, |x| x);
|
||||
@@ -0,0 +1,29 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! Burn ndarray backend.
|
||||
|
||||
#[cfg(any(
|
||||
feature = "blas-netlib",
|
||||
feature = "blas-openblas",
|
||||
feature = "blas-openblas-system",
|
||||
))]
|
||||
extern crate blas_src;
|
||||
|
||||
mod backend;
|
||||
mod element;
|
||||
mod ops;
|
||||
mod parallel;
|
||||
mod rand;
|
||||
mod sharing;
|
||||
mod storage;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use element::*;
|
||||
pub(crate) use sharing::*;
|
||||
pub(crate) use storage::*;
|
||||
pub use tensor::*;
|
||||
|
||||
extern crate alloc;
|
||||
@@ -0,0 +1,18 @@
|
||||
use crate::{
|
||||
NdArray, NdArrayTensor, SharedArray,
|
||||
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
execute_with_numeric_dtype,
|
||||
ops::NdArrayMathOps,
|
||||
};
|
||||
use burn_backend::{ElementConversion, TensorMetadata, ops::ActivationOps, tensor::FloatTensor};
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem()))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use crate::{
|
||||
SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
|
||||
};
|
||||
use burn_backend::ElementConversion;
|
||||
use ndarray::Array4;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float;
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
output_size: [usize; 2],
|
||||
) -> SharedArray<E> {
|
||||
let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap();
|
||||
|
||||
let mut output = Array4::from_elem(
|
||||
(batch_size, channels, output_size[0], output_size[1]),
|
||||
0.elem(),
|
||||
);
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
for h in 0..output_size[0] {
|
||||
for w in 0..output_size[1] {
|
||||
let ih_start = start_index(h, output_size[0], input_height);
|
||||
let ih_end = end_index(h, output_size[0], input_height);
|
||||
let iw_start = start_index(w, output_size[1], input_width);
|
||||
let iw_end = end_index(w, output_size[1], input_width);
|
||||
|
||||
let mut sum_val: E = 0.elem();
|
||||
|
||||
for ih in ih_start..ih_end {
|
||||
for iw in iw_start..iw_end {
|
||||
sum_val += x[[b, c, ih, iw]];
|
||||
}
|
||||
}
|
||||
|
||||
let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
|
||||
output[[b, c, h, w]] = sum_val / count.elem();
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
grad: SharedArray<E>,
|
||||
) -> SharedArray<E> {
|
||||
let [_, _, input_height, input_width] = x.shape().try_into().unwrap();
|
||||
let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap();
|
||||
|
||||
let mut output_grad =
|
||||
Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output_grad = unsafe_shared_out.get();
|
||||
for oh in 0..output_height {
|
||||
for ow in 0..output_width {
|
||||
let ih_start = start_index(oh, output_height, input_height);
|
||||
let ih_end = end_index(oh, output_height, input_height);
|
||||
|
||||
let iw_start = start_index(ow, output_width, input_width);
|
||||
let iw_end = end_index(ow, output_width, input_width);
|
||||
|
||||
let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
|
||||
|
||||
for ih in ih_start..ih_end {
|
||||
for iw in iw_start..iw_end {
|
||||
output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / count.elem();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
output_grad.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Returns the first input index (inclusive) of the window feeding output cell
/// `output_size_index`, i.e. `floor(output_size_index * input_size / output_size)`.
///
/// The computation uses exact integer arithmetic. A round-trip through `f32` only
/// represents integers exactly up to 2^24, so large sizes (or large products) could
/// round to the wrong index.
///
/// # Panics
///
/// Panics if `output_size` is zero.
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    // Widen to u128 so the product cannot overflow, even on 32-bit targets.
    let numerator = output_size_index as u128 * input_size as u128;
    // Unsigned integer division truncates, which is the floor for non-negative values.
    let index = numerator / output_size as u128;

    // For `output_size_index < output_size` the result is below `input_size` and always
    // fits; saturate otherwise, like a float-to-int `as` cast would.
    usize::try_from(index).unwrap_or(usize::MAX)
}
|
||||
|
||||
/// Returns the end input index (exclusive) of the window feeding output cell
/// `output_size_index`, i.e. `ceil((output_size_index + 1) * input_size / output_size)`
/// clamped to `input_size`.
///
/// The computation uses exact integer arithmetic. A round-trip through `f32` only
/// represents integers exactly up to 2^24, so large sizes (or large products) could
/// round to the wrong index.
///
/// # Panics
///
/// Panics if `output_size` is zero.
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    // Widen to u128 so neither the `+ 1` nor the product can overflow.
    let numerator = (output_size_index as u128 + 1) * input_size as u128;
    let index = numerator.div_ceil(output_size as u128);

    // Clamp to the input extent; a value too large for `usize` clamps as well.
    usize::try_from(index).map_or(input_size, |index| usize::min(index, input_size))
}
|
||||
@@ -0,0 +1,172 @@
|
||||
use crate::{
|
||||
SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
|
||||
};
|
||||
|
||||
use burn_backend::ElementConversion;
|
||||
use burn_backend::ops::conv::calculate_pool_output_size;
|
||||
use ndarray::Array4;
|
||||
|
||||
pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> SharedArray<E> {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
|
||||
|
||||
let out_height = calculate_pool_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
1,
|
||||
x_height,
|
||||
ceil_mode,
|
||||
);
|
||||
let out_width = calculate_pool_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
1,
|
||||
x_width,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
// Padded input bounds (for count_include_pad calculation)
|
||||
let padded_height = x_height + 2 * padding_height;
|
||||
let padded_width = x_width + 2 * padding_width;
|
||||
|
||||
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem());
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
let mut sum_val: E = 0.elem();
|
||||
let mut valid_count = 0usize;
|
||||
let mut padded_count = 0usize;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw;
|
||||
|
||||
// Check if within padded bounds (excludes ceil_mode extensions)
|
||||
if ih < padded_height && iw < padded_width {
|
||||
padded_count += 1;
|
||||
|
||||
// Check if within valid (non-padding) input bounds
|
||||
if ih >= padding_height
|
||||
&& ih < x_height + padding_height
|
||||
&& iw >= padding_width
|
||||
&& iw < x_width + padding_width
|
||||
{
|
||||
let ih_valid = ih - padding_height;
|
||||
let iw_valid = iw - padding_width;
|
||||
sum_val += x[[b, c, ih_valid, iw_valid]];
|
||||
valid_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// count_include_pad: count positions within padded bounds (not ceil_mode extensions)
|
||||
// !count_include_pad: count only valid (non-padding) positions
|
||||
let count: E = if count_include_pad {
|
||||
(padded_count as i32).elem()
|
||||
} else {
|
||||
(valid_count as i32).elem()
|
||||
};
|
||||
|
||||
output[[b, c, oh, ow]] = sum_val / count;
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
grad: SharedArray<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
_ceil_mode: bool,
|
||||
) -> SharedArray<E> {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
|
||||
let [_batch_size, _channels, out_height, out_width] = grad.shape().try_into().unwrap();
|
||||
|
||||
// Padded input bounds (for count_include_pad calculation)
|
||||
let padded_height = x_height + 2 * padding_height;
|
||||
let padded_width = x_width + 2 * padding_width;
|
||||
|
||||
let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem());
|
||||
let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output_grad = unsafe_shared_grad.get();
|
||||
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
let ih_start_kernel = oh * stride_height;
|
||||
let iw_start_kernel = ow * stride_width;
|
||||
|
||||
let ih_end_kernel = ih_start_kernel + kernel_height;
|
||||
let iw_end_kernel = iw_start_kernel + kernel_width;
|
||||
|
||||
// Clip to valid input bounds (for gradient distribution)
|
||||
let ih_start = usize::max(ih_start_kernel, padding_height);
|
||||
let iw_start = usize::max(iw_start_kernel, padding_width);
|
||||
let ih_end = usize::min(ih_end_kernel, x_height + padding_height);
|
||||
let iw_end = usize::min(iw_end_kernel, x_width + padding_width);
|
||||
|
||||
// Calculate count based on count_include_pad
|
||||
let count = if count_include_pad {
|
||||
// Count positions within padded bounds (not ceil_mode extensions)
|
||||
let ih_start_padded = ih_start_kernel;
|
||||
let iw_start_padded = iw_start_kernel;
|
||||
let ih_end_padded = usize::min(ih_end_kernel, padded_height);
|
||||
let iw_end_padded = usize::min(iw_end_kernel, padded_width);
|
||||
(ih_end_padded - ih_start_padded) * (iw_end_padded - iw_start_padded)
|
||||
} else {
|
||||
// Count only valid (non-padding) positions
|
||||
(ih_end - ih_start) * (iw_end - iw_start)
|
||||
};
|
||||
|
||||
for ih in ih_start..ih_end {
|
||||
for iw in iw_start..iw_end {
|
||||
let ih = ih - padding_height;
|
||||
let iw = iw - padding_width;
|
||||
|
||||
output_grad[[b, c, ih, iw]] +=
|
||||
grad[[b, c, oh, ow]] / (count as i32).elem();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
output_grad.into_dyn().into_shared()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,220 @@
|
||||
// Language
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::Scalar;
|
||||
use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};
|
||||
use burn_backend::{
|
||||
backend::ExecutionError,
|
||||
ops::BoolTensorOps,
|
||||
tensor::{BoolTensor, IntTensor},
|
||||
};
|
||||
use ndarray::IntoDimension;
|
||||
|
||||
// Current crate
|
||||
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
|
||||
use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
|
||||
use crate::{NdArrayDevice, SharedArray, slice};
|
||||
|
||||
// Workspace crates
|
||||
use burn_backend::{Shape, TensorData, backend::Backend};
|
||||
|
||||
use super::{NdArrayBoolOps, NdArrayOps};
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
|
||||
if !data.dtype.is_bool() {
|
||||
unimplemented!("Unsupported dtype for `bool_from_data`")
|
||||
}
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
|
||||
Ok(tensor.into_data())
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
|
||||
NdArrayOps::reshape(tensor.bool(), shape).into()
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
|
||||
slice!(tensor, slices)
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
// Use mapv directly instead of collecting to Vec and going through TensorData
|
||||
let int_array: SharedArray<I> = tensor.bool().mapv(|b| b.elem()).into_shared();
|
||||
int_array.into()
|
||||
}
|
||||
|
||||
fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
||||
fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
|
||||
Self::bool_zeros(shape, _device)
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
|
||||
let values = vec![false; shape.num_elements()];
|
||||
NdArrayTensor::from_data(TensorData::new(values, shape))
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {
|
||||
let values = vec![true; shape.num_elements()];
|
||||
NdArrayTensor::from_data(TensorData::new(values, shape))
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: NdArrayTensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: NdArrayTensor,
|
||||
) -> NdArrayTensor {
|
||||
NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
|
||||
NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()
|
||||
}
|
||||
|
||||
fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
tensor.bool().mapv(|a| !a).into_shared().into()
|
||||
}
|
||||
|
||||
fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()
|
||||
}
|
||||
|
||||
fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
|
||||
let arr: SharedArray<E> = tensor.bool().mapv(|b| b.elem()).into_shared();
|
||||
arr.into()
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
|
||||
NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
|
||||
tensor.bool().permuted_axes(axes.into_dimension()).into()
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
|
||||
NdArrayOps::expand(tensor.bool(), shape).into()
|
||||
}
|
||||
|
||||
fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
|
||||
let tensor_bool = tensor.bool();
|
||||
let indices_vec: Vec<usize> = indices
|
||||
.into_iter()
|
||||
.map(|i| i.elem::<i64>() as usize)
|
||||
.collect();
|
||||
|
||||
let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
|
||||
selected.into_shared().into()
|
||||
})
|
||||
}
|
||||
|
||||
fn bool_select_or(
|
||||
tensor: NdArrayTensor,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor,
|
||||
value: NdArrayTensor,
|
||||
) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
|
||||
let mut output_array = tensor.bool().into_owned();
|
||||
let value_bool = value.bool();
|
||||
|
||||
for (index_value, index) in indices.into_iter().enumerate() {
|
||||
let index_usize = index.elem::<i64>() as usize;
|
||||
let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
|
||||
let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
|
||||
// For boolean tensors, select_assign should use logical OR operation
|
||||
view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
|
||||
}
|
||||
output_array.into_shared().into()
|
||||
})
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
|
||||
NdArrayOps::flip(tensor.bool(), axes).into()
|
||||
}
|
||||
|
||||
fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {
|
||||
NdArrayOps::unfold(tensor.bool(), dim, size, step).into()
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(
|
||||
dim,
|
||||
tensor.bool(),
|
||||
indices
|
||||
))
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(
|
||||
dim,
|
||||
tensor.bool(),
|
||||
indices,
|
||||
value.bool()
|
||||
))
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into()
|
||||
}
|
||||
|
||||
fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage with short-circuit evaluation
|
||||
let result = NdArrayBoolOps::any_view(tensor.bool().view());
|
||||
NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
|
||||
}
|
||||
|
||||
fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage with short-circuit evaluation
|
||||
let result = NdArrayBoolOps::all_view(tensor.bool().view());
|
||||
NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
use burn_backend::{
|
||||
ElementConversion,
|
||||
ops::{
|
||||
ConvOptions, ConvTransposeOptions,
|
||||
conv::{calculate_conv_output_size, calculate_conv_transpose_output_size},
|
||||
},
|
||||
};
|
||||
use ndarray::{
|
||||
Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, s,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
NdArrayElement, SharedArray, iter_par, iter_range_par,
|
||||
ops::padding::{apply_padding_4d, apply_padding_5d},
|
||||
run_par,
|
||||
sharing::UnsafeSharedRef,
|
||||
tensor::NdArrayTensor,
|
||||
};
|
||||
|
||||
#[inline(always)]
|
||||
fn conv2d_mad_inner<E: NdArrayElement>(
|
||||
mut output: ArrayViewMut2<E>,
|
||||
x: ArrayView2<E>,
|
||||
k: E,
|
||||
k_xy: (usize, usize),
|
||||
out_xy: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
dilation: (usize, usize),
|
||||
) {
|
||||
let (kh, kw) = k_xy;
|
||||
let (out_width, out_height) = out_xy;
|
||||
let (stride_width, stride_height) = stride;
|
||||
let (dilation_width, dilation_height) = dilation;
|
||||
|
||||
for oh in 0..out_height {
|
||||
// Construct a sub-slice view of the input row.
|
||||
// This is done upfront so that rustc does not have to emit bounds checks
|
||||
// in the hot loop below.
|
||||
let ir = x
|
||||
.row(oh * stride_height + kh * dilation_height)
|
||||
.to_slice()
|
||||
.unwrap();
|
||||
|
||||
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
|
||||
// the bounds upfront as 0..out_width so that rustc can make the assumption
|
||||
// that all accesses are in-bounds in the below loop.
|
||||
let mut or = output.row_mut(oh);
|
||||
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for ow in 0..out_width {
|
||||
let iw = ow * stride_width + kw * dilation_width;
|
||||
or[ow] += ir[iw] * k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn conv3d_mad_inner<E: NdArrayElement>(
|
||||
mut output: ArrayViewMut3<E>,
|
||||
x: ArrayView3<E>,
|
||||
k: E,
|
||||
k_xyz: (usize, usize, usize),
|
||||
out_xyz: (usize, usize, usize),
|
||||
stride: (usize, usize, usize),
|
||||
dilation: (usize, usize, usize),
|
||||
) {
|
||||
let (kd, kh, kw) = k_xyz;
|
||||
let (out_width, out_height, out_depth) = out_xyz;
|
||||
let (stride_width, stride_height, stride_depth) = stride;
|
||||
let (dilation_width, dilation_height, dilation_depth) = dilation;
|
||||
|
||||
for od in 0..out_depth {
|
||||
let id = od * stride_depth + kd * dilation_depth;
|
||||
|
||||
for oh in 0..out_height {
|
||||
let ih = oh * stride_height + kh * dilation_height;
|
||||
|
||||
// Construct a sub-slice view of the input row.
|
||||
// This is done upfront so that rustc does not have to emit bounds checks
|
||||
// in the hot loop below.
|
||||
let ir = x.slice(s![id, ih, ..]).to_slice().unwrap();
|
||||
|
||||
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
|
||||
// the bounds upfront as 0..out_width so that rustc can make the assumption
|
||||
// that all accesses are in-bounds in the below loop.
|
||||
let or = &mut output
|
||||
.slice_mut(s![od, oh, 0..out_width])
|
||||
.into_slice()
|
||||
.unwrap()[0..out_width];
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for ow in 0..out_width {
|
||||
let iw = ow * stride_width + kw * dilation_width;
|
||||
or[ow] += ir[iw] * k;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d<E: NdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
weight: SharedArray<E>,
|
||||
bias: Option<SharedArray<E>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> SharedArray<E>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
{
|
||||
let [dilation_height, dilation_width] = options.dilation;
|
||||
let [padding_height, padding_width] = options.padding;
|
||||
let [stride_height, stride_width] = options.stride;
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();
|
||||
let [out_channels, in_channels, kernel_height, kernel_width] =
|
||||
weight.shape().try_into().unwrap();
|
||||
let channels_per_group = out_channels / options.groups;
|
||||
|
||||
let out_height = calculate_conv_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
dilation_height,
|
||||
in_height,
|
||||
);
|
||||
let out_width = calculate_conv_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
dilation_width,
|
||||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_4d::<E>(x, options.padding, 0i32.elem());
|
||||
|
||||
// Convert inputs from dynamic indexes to static to improve perf.
|
||||
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
let weights = weight.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
||||
let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(output.axis_iter_mut(Axis(0)))
|
||||
.enumerate()
|
||||
.for_each(
|
||||
#[inline(never)]
|
||||
|(k, mut output)| {
|
||||
let b = k / out_channels;
|
||||
let oc = k % out_channels;
|
||||
let g = oc / channels_per_group;
|
||||
|
||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||
let weight_ic = ic - (g * in_channels);
|
||||
|
||||
let x = x.slice(s![b, ic, .., ..]);
|
||||
let k = weights.slice(s![oc, weight_ic, .., ..]);
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
let k = k[[kh, kw]];
|
||||
|
||||
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
|
||||
// in the case that the stride/dilation is 1.
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
if (1, 1, 1, 1)
|
||||
== (
|
||||
stride_width,
|
||||
stride_height,
|
||||
dilation_width,
|
||||
dilation_height,
|
||||
)
|
||||
{
|
||||
conv2d_mad_inner(
|
||||
output.view_mut(),
|
||||
x.view(),
|
||||
k,
|
||||
(kh, kw),
|
||||
(out_width, out_height),
|
||||
(stride_width, stride_height),
|
||||
(dilation_width, dilation_height),
|
||||
);
|
||||
} else {
|
||||
conv2d_mad_inner(
|
||||
output.view_mut(),
|
||||
x.view(),
|
||||
k,
|
||||
(kh, kw),
|
||||
(out_width, out_height),
|
||||
(stride_width, stride_height),
|
||||
(dilation_width, dilation_height),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(bias) = &bias {
|
||||
let bias = bias[oc];
|
||||
|
||||
for oh in 0..out_height {
|
||||
// Get a mutable slice reference to the row we're looping over.
|
||||
// We explicitly define the bounds to 0..out_width so that rustc can make
|
||||
// the assumption that all accesses are in-bounds.
|
||||
let mut or = output.row_mut(oh);
|
||||
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for ow in 0..out_width {
|
||||
or[ow] += bias;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
output
|
||||
.to_shape([batch_size, out_channels, out_height, out_width])
|
||||
.unwrap()
|
||||
.into_dyn()
|
||||
.into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose2d<E: NdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
weight: SharedArray<E>,
|
||||
bias: Option<SharedArray<E>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> SharedArray<E> {
|
||||
let [dilation_height, dilation_width] = options.dilation;
|
||||
let [padding_height, padding_width] = options.padding;
|
||||
let [stride_height, stride_width] = options.stride;
|
||||
let [out_padding_height, out_padding_width] = options.padding_out;
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();
|
||||
let [in_channels, out_channels, kernel_height, kernel_width] =
|
||||
weight.shape().try_into().unwrap();
|
||||
|
||||
let out_height = calculate_conv_transpose_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
out_padding_height,
|
||||
dilation_height,
|
||||
in_height,
|
||||
);
|
||||
let out_width = calculate_conv_transpose_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
out_padding_width,
|
||||
dilation_width,
|
||||
in_width,
|
||||
);
|
||||
|
||||
let x = x;
|
||||
let mut output = Array4::zeros(Dim([
|
||||
batch_size,
|
||||
out_channels * options.groups,
|
||||
out_height,
|
||||
out_width,
|
||||
]));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
||||
let b = k / (out_channels * options.groups);
|
||||
let oc = k % out_channels;
|
||||
let g = (k / out_channels) % options.groups;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
let oc_out = oc + (out_channels * g);
|
||||
let ic_start = g * (in_channels / options.groups);
|
||||
let ic_end = ic_start + in_channels / options.groups;
|
||||
|
||||
for ic in ic_start..ic_end {
|
||||
for ih in 0..in_height {
|
||||
for iw in 0..in_width {
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
let oh = ih * stride_height + kh * dilation_height;
|
||||
let ow = iw * stride_width + kw * dilation_width;
|
||||
|
||||
if oh >= out_height + padding_height
|
||||
|| ow >= out_width + padding_width
|
||||
|| oh < padding_height
|
||||
|| ow < padding_width
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let oh = oh - padding_height;
|
||||
let ow = ow - padding_width;
|
||||
|
||||
output[[b, oc_out, oh, ow]] +=
|
||||
x[[b, ic, ih, iw]] * weight[[ic, oc, kh, kw]];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(bias) = &bias {
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
output[[b, oc_out, oh, ow]] += bias[oc_out];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn conv3d<E: NdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
weight: SharedArray<E>,
|
||||
bias: Option<SharedArray<E>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> SharedArray<E>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
{
|
||||
let [dilation_depth, dilation_height, dilation_width] = options.dilation;
|
||||
let [padding_depth, padding_height, padding_width] = options.padding;
|
||||
let [stride_depth, stride_height, stride_width] = options.stride;
|
||||
let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap();
|
||||
let [
|
||||
out_channels,
|
||||
in_channels,
|
||||
kernel_depth,
|
||||
kernel_height,
|
||||
kernel_width,
|
||||
] = weight.shape().try_into().unwrap();
|
||||
let out_c_per_group = out_channels / options.groups;
|
||||
|
||||
let out_depth = calculate_conv_output_size(
|
||||
kernel_depth,
|
||||
stride_depth,
|
||||
padding_depth,
|
||||
dilation_depth,
|
||||
in_depth,
|
||||
);
|
||||
let out_height = calculate_conv_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
dilation_height,
|
||||
in_height,
|
||||
);
|
||||
let out_width = calculate_conv_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
dilation_width,
|
||||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_5d::<E>(x, options.padding, 0i32.elem());
|
||||
|
||||
// Convert inputs from dynamic indexes to static to improve perf.
|
||||
let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();
|
||||
let weights = weight.into_dimensionality::<ndarray::Ix5>().unwrap();
|
||||
|
||||
let mut output = Array4::zeros(Dim([
|
||||
batch_size * out_channels,
|
||||
out_depth,
|
||||
out_height,
|
||||
out_width,
|
||||
]));
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(output.axis_iter_mut(Axis(0)))
|
||||
.enumerate()
|
||||
.for_each(
|
||||
#[inline(never)]
|
||||
|(k, mut output)| {
|
||||
let b = k / out_channels;
|
||||
let oc = k % out_channels;
|
||||
let g = oc / out_c_per_group;
|
||||
|
||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||
let weight_ic = ic - (g * in_channels);
|
||||
|
||||
let x = x.slice(s![b, ic, .., .., ..]);
|
||||
let k = weights.slice(s![oc, weight_ic, .., .., ..]);
|
||||
|
||||
for kd in 0..kernel_depth {
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
let k = k[[kd, kh, kw]];
|
||||
|
||||
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
|
||||
// in the case that the stride/dilation is 1.
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
if (1, 1, 1, 1, 1, 1)
|
||||
== (
|
||||
stride_width,
|
||||
stride_height,
|
||||
stride_depth,
|
||||
dilation_width,
|
||||
dilation_height,
|
||||
dilation_depth,
|
||||
)
|
||||
{
|
||||
conv3d_mad_inner(
|
||||
output.view_mut(),
|
||||
x.view(),
|
||||
k,
|
||||
(kd, kh, kw),
|
||||
(out_width, out_height, out_depth),
|
||||
(stride_width, stride_height, stride_depth),
|
||||
(dilation_width, dilation_height, dilation_depth),
|
||||
);
|
||||
} else {
|
||||
conv3d_mad_inner(
|
||||
output.view_mut(),
|
||||
x.view(),
|
||||
k,
|
||||
(kd, kh, kw),
|
||||
(out_width, out_height, out_depth),
|
||||
(stride_width, stride_height, stride_depth),
|
||||
(dilation_width, dilation_height, dilation_depth),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(bias) = &bias {
|
||||
let bias = bias[oc];
|
||||
|
||||
// Get a mutable iterator to the row we're looping over.
|
||||
let orows = output.rows_mut();
|
||||
for mut or in orows {
|
||||
// We explicitly define the bounds to 0..out_width so that rustc can make
|
||||
// the assumption that all accesses are in-bounds.
|
||||
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for ow in 0..out_width {
|
||||
or[ow] += bias;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
output
|
||||
.to_shape([batch_size, out_channels, out_depth, out_height, out_width])
|
||||
.unwrap()
|
||||
.into_dyn()
|
||||
.into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose3d<E: NdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
weight: SharedArray<E>,
|
||||
bias: Option<SharedArray<E>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> SharedArray<E> {
|
||||
let [dilation_depth, dilation_height, dilation_width] = options.dilation;
|
||||
let [padding_depth, padding_height, padding_width] = options.padding;
|
||||
let [stride_depth, stride_height, stride_width] = options.stride;
|
||||
let [out_padding_depth, out_padding_height, out_padding_width] = options.padding_out;
|
||||
let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap();
|
||||
let [
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_depth,
|
||||
kernel_height,
|
||||
kernel_width,
|
||||
] = weight.shape().try_into().unwrap();
|
||||
|
||||
let out_depth = calculate_conv_transpose_output_size(
|
||||
kernel_depth,
|
||||
stride_depth,
|
||||
padding_depth,
|
||||
out_padding_depth,
|
||||
dilation_depth,
|
||||
in_depth,
|
||||
);
|
||||
let out_height = calculate_conv_transpose_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
out_padding_height,
|
||||
dilation_height,
|
||||
in_height,
|
||||
);
|
||||
let out_width = calculate_conv_transpose_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
out_padding_width,
|
||||
dilation_width,
|
||||
in_width,
|
||||
);
|
||||
|
||||
let x = x;
|
||||
let mut output = Array5::zeros(Dim([
|
||||
batch_size,
|
||||
out_channels * options.groups,
|
||||
out_depth,
|
||||
out_height,
|
||||
out_width,
|
||||
]));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
||||
let b = k / (out_channels * options.groups);
|
||||
let oc = k % out_channels;
|
||||
let g = (k / out_channels) % options.groups;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
let oc_out = oc + (out_channels * g);
|
||||
let ic_start = g * (in_channels / options.groups);
|
||||
let ic_end = ic_start + in_channels / options.groups;
|
||||
|
||||
for ic in ic_start..ic_end {
|
||||
for id in 0..in_depth {
|
||||
for ih in 0..in_height {
|
||||
for iw in 0..in_width {
|
||||
for kd in 0..kernel_depth {
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
let od = id * stride_depth + kd * dilation_depth;
|
||||
let oh = ih * stride_height + kh * dilation_height;
|
||||
let ow = iw * stride_width + kw * dilation_width;
|
||||
|
||||
if od >= out_depth + padding_depth
|
||||
|| oh >= out_height + padding_height
|
||||
|| ow >= out_width + padding_width
|
||||
|| od < padding_depth
|
||||
|| oh < padding_height
|
||||
|| ow < padding_width
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let od = od - padding_depth;
|
||||
let oh = oh - padding_height;
|
||||
let ow = ow - padding_width;
|
||||
|
||||
output[[b, oc_out, od, oh, ow]] +=
|
||||
x[[b, ic, id, ih, iw]] * weight[[ic, oc, kd, kh, kw]];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(bias) = &bias {
|
||||
for od in 0..out_depth {
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
output[[b, oc_out, od, oh, ow]] += bias[oc_out];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
@@ -0,0 +1,662 @@
|
||||
use burn_backend::ops::{DeformConvOptions, conv::calculate_conv_output_size};
|
||||
use core::ops::AddAssign;
|
||||
use ndarray::{
|
||||
Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, Ix4,
|
||||
Zip, s,
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float;
|
||||
|
||||
use crate::{FloatNdArrayElement, NdArrayTensor, ShapeOps, SharedArray, iter_par, run_par};
|
||||
|
||||
use super::matmul::matmul;
|
||||
|
||||
#[inline(always)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn deform_im2col_kernel<F: FloatNdArrayElement>(
|
||||
out_y: usize,
|
||||
out_x: usize,
|
||||
input: ArrayView2<F>,
|
||||
offset: ArrayView3<F>,
|
||||
mask: Option<ArrayView2<F>>,
|
||||
mut columns: ArrayViewMut2<F>,
|
||||
args: DeformConvOptions<2>,
|
||||
(kernel_h, kernel_w): (usize, usize),
|
||||
) {
|
||||
// position shape: [in_channels, batch_size, out_h, out_w]
|
||||
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
|
||||
|
||||
let (height, width) = input.dim();
|
||||
|
||||
for kernel_y in 0..kernel_h {
|
||||
for kernel_x in 0..kernel_w {
|
||||
let mask_value = mask
|
||||
.map(|it| it[[kernel_y, kernel_x]])
|
||||
.unwrap_or_else(|| F::from_elem(1.0));
|
||||
|
||||
let offset = offset.slice(s![kernel_y, kernel_x, ..]);
|
||||
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
|
||||
- F::from_elem(args.padding[0])
|
||||
+ offset[0];
|
||||
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
|
||||
- F::from_elem(args.padding[1])
|
||||
+ offset[1];
|
||||
|
||||
let interpolated = bilinear_interpolate(input, height, width, y, x);
|
||||
|
||||
columns[[kernel_y, kernel_x]] = mask_value * interpolated;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn bilinear_interpolate<F: FloatNdArrayElement>(
|
||||
input: ArrayView2<F>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
y: F,
|
||||
x: F,
|
||||
) -> F {
|
||||
// To simplify code
|
||||
let y = y.to_f32();
|
||||
let x = x.to_f32();
|
||||
|
||||
let mut result = F::from_elem(0.0);
|
||||
if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
|
||||
let y_low = f32::floor(y);
|
||||
let x_low = f32::floor(x);
|
||||
let y_high = (y_low + 1.) as usize;
|
||||
let x_high = (x_low + 1.) as usize;
|
||||
|
||||
let zero = F::from_elem(0.0);
|
||||
let v1: F = if y_low >= 0. && x_low >= 0. {
|
||||
input[[y_low as usize, x_low as usize]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v2: F = if y_low >= 0. && x_high < width {
|
||||
input[[y_low as usize, x_high]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v3: F = if y_high < height && x_low >= 0. {
|
||||
input[[y_high, x_low as usize]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v4: F = if y_high < height && x_high < width {
|
||||
input[[y_high, x_high]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
|
||||
let l_y = y - y_low;
|
||||
let l_x = x - x_low;
|
||||
let h_y = 1.0 - l_y;
|
||||
let h_x = 1.0 - l_x;
|
||||
|
||||
let w1 = F::from_elem(h_y * h_x);
|
||||
let w2 = F::from_elem(h_y * l_x);
|
||||
let w3 = F::from_elem(l_y * h_x);
|
||||
let w4 = F::from_elem(l_y * l_x);
|
||||
|
||||
result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub(crate) fn deform_conv2d<F: FloatNdArrayElement>(
|
||||
input: SharedArray<F>,
|
||||
offset: SharedArray<F>,
|
||||
weight: SharedArray<F>,
|
||||
mask: Option<SharedArray<F>>,
|
||||
bias: Option<SharedArray<F>>,
|
||||
args: DeformConvOptions<2>,
|
||||
) -> SharedArray<F>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<F>>,
|
||||
{
|
||||
let [batch_size, _, in_height, in_width] = input.shape().dims();
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims();
|
||||
let groups = args.weight_groups;
|
||||
|
||||
let weight = weight.as_standard_layout();
|
||||
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
args.stride[0],
|
||||
args.padding[0],
|
||||
args.dilation[0],
|
||||
in_height,
|
||||
);
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
args.stride[1],
|
||||
args.padding[1],
|
||||
args.dilation[1],
|
||||
in_width,
|
||||
);
|
||||
let out_dims = (out_h, out_w);
|
||||
|
||||
let input = input.into_dimensionality::<Ix4>().unwrap();
|
||||
let offset = offset.into_dimensionality::<Ix4>().unwrap();
|
||||
let mask = mask.as_ref().map(|it| {
|
||||
it.to_shape((
|
||||
batch_size,
|
||||
args.offset_groups,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
out_h,
|
||||
out_w,
|
||||
))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let columns = deform_im2col(
|
||||
input.view(),
|
||||
offset.view(),
|
||||
mask.as_ref().map(|it| it.view()),
|
||||
args,
|
||||
out_dims,
|
||||
(kernel_h, kernel_w),
|
||||
);
|
||||
|
||||
let (col_size_0, col_size_1) = columns.dim();
|
||||
let col_size_0 = col_size_0 / groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let weight = weight
|
||||
.to_shape((groups, out_c_per_group, col_size_0))
|
||||
.unwrap();
|
||||
let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();
|
||||
let out = matmul(
|
||||
weight.to_owned().into_dyn().into_shared(),
|
||||
columns.to_owned().into_dyn().into_shared(),
|
||||
);
|
||||
|
||||
let mut out = out
|
||||
.into_shape_with_order((out_channels, batch_size, out_h, out_w))
|
||||
.unwrap();
|
||||
out.swap_axes(0, 1);
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let bias = bias.to_shape((1, out_channels, 1, 1)).unwrap();
|
||||
out.add_assign(&bias);
|
||||
}
|
||||
|
||||
out.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn deform_im2col<F: FloatNdArrayElement>(
|
||||
input: ArrayView4<F>,
|
||||
offset: ArrayView4<F>,
|
||||
mask: Option<ArrayView6<F>>,
|
||||
args: DeformConvOptions<2>,
|
||||
out_dims: (usize, usize),
|
||||
kernel_dims: (usize, usize),
|
||||
) -> Array2<F> {
|
||||
let (batch_size, in_channels, _, _) = input.dim();
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
let (out_h, out_w) = out_dims;
|
||||
let channels_per_offset_group = in_channels / args.offset_groups;
|
||||
|
||||
let mut columns = Array4::zeros(Dim([
|
||||
in_channels,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
batch_size * out_h * out_w,
|
||||
]));
|
||||
|
||||
let groups = args.offset_groups;
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(columns.axis_iter_mut(Axis(3)))
|
||||
.enumerate()
|
||||
.for_each(|(index, mut columns)| {
|
||||
let out_x = index % out_w;
|
||||
let out_y = (index / out_w) % out_h;
|
||||
let batch = (index / (out_w * out_h)) % batch_size;
|
||||
let offset = offset.slice(s![batch, .., out_y, out_x]);
|
||||
let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap();
|
||||
let mask = mask
|
||||
.as_ref()
|
||||
.map(|it| it.slice(s![batch, .., .., .., out_y, out_x]));
|
||||
columns
|
||||
.axis_iter_mut(Axis(0))
|
||||
.enumerate()
|
||||
.for_each(|(in_channel, mut columns)| {
|
||||
let group_index = in_channel / channels_per_offset_group;
|
||||
deform_im2col_kernel(
|
||||
out_y,
|
||||
out_x,
|
||||
input.slice(s![batch, in_channel, .., ..]),
|
||||
offset.slice(s![group_index, .., .., ..]),
|
||||
mask.as_ref().map(|it| it.slice(s![group_index, .., ..])),
|
||||
columns.view_mut(),
|
||||
args.clone(),
|
||||
kernel_dims,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
columns
|
||||
// Columns is created here, so we know it's contiguous
|
||||
.into_shape_with_order((
|
||||
in_channels * kernel_h * kernel_w,
|
||||
batch_size * out_h * out_w,
|
||||
))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub mod backward {
|
||||
#[cfg(target_has_atomic = "32")]
|
||||
use core::sync::atomic::Ordering;
|
||||
|
||||
use atomic_float::AtomicF32;
|
||||
use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(crate) type DeformConv2dBackward<F> = (
|
||||
SharedArray<F>,
|
||||
SharedArray<F>,
|
||||
SharedArray<F>,
|
||||
Option<SharedArray<F>>,
|
||||
Option<SharedArray<F>>,
|
||||
);
|
||||
|
||||
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
|
||||
pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement>(
|
||||
input: SharedArray<F>,
|
||||
offset: SharedArray<F>,
|
||||
weight: SharedArray<F>,
|
||||
mask: Option<SharedArray<F>>,
|
||||
bias: Option<SharedArray<F>>,
|
||||
out_grad: SharedArray<F>,
|
||||
args: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<F> {
|
||||
let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims();
|
||||
let [_, _, kernel_h, kernel_w] = weight.shape().dims();
|
||||
let groups = args.weight_groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
let col_shape_1 = batch_size * out_h * out_w;
|
||||
let mut out_grad = out_grad.into_dimensionality::<Ix4>().unwrap();
|
||||
|
||||
let gradient_bias = bias.map(|_| {
|
||||
let out_grad = out_grad
|
||||
.clone()
|
||||
.sum_axis(Axis(0))
|
||||
.sum_axis(Axis(1))
|
||||
.sum_axis(Axis(1));
|
||||
|
||||
out_grad.into_dyn().into_shared()
|
||||
});
|
||||
|
||||
out_grad.swap_axes(0, 1);
|
||||
let out_grad = out_grad
|
||||
.to_shape((groups, out_c_per_group, col_shape_1))
|
||||
.unwrap();
|
||||
|
||||
let input = input.into_dimensionality::<Ix4>().unwrap();
|
||||
let offset = offset.into_dimensionality::<Ix4>().unwrap();
|
||||
let mask = mask.map(|it| {
|
||||
it.into_shape_with_order((
|
||||
batch_size,
|
||||
args.offset_groups,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
out_h,
|
||||
out_w,
|
||||
))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
|
||||
input.view(),
|
||||
weight,
|
||||
offset.view(),
|
||||
mask.as_ref().map(|it| it.view()),
|
||||
out_grad.view(),
|
||||
&args,
|
||||
(kernel_h, kernel_w),
|
||||
);
|
||||
|
||||
let weight_grad = compute_weight_grad(
|
||||
input.view(),
|
||||
offset.view(),
|
||||
mask.as_ref().map(|it| it.view()),
|
||||
out_grad.view(),
|
||||
args,
|
||||
(kernel_h, kernel_w),
|
||||
(out_h, out_w),
|
||||
);
|
||||
|
||||
(
|
||||
input_gradient,
|
||||
offset_gradient,
|
||||
weight_grad,
|
||||
mask_gradient,
|
||||
gradient_bias,
|
||||
)
|
||||
}
|
||||
|
||||
fn compute_weight_grad<F: FloatNdArrayElement>(
|
||||
input: ArrayView4<F>,
|
||||
offset: ArrayView4<F>,
|
||||
mask: Option<ArrayView6<F>>,
|
||||
out_grad: ArrayView3<F>,
|
||||
options: DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
out_dims: (usize, usize),
|
||||
) -> SharedArray<F> {
|
||||
let in_channels = input.dim().1;
|
||||
let (groups, out_c_per_group, _) = out_grad.dim();
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
|
||||
let in_c_per_group = in_channels / groups;
|
||||
|
||||
let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);
|
||||
let (col_size_0, col_size_1) = columns.dim();
|
||||
let col_size_0 = col_size_0 / groups;
|
||||
|
||||
let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();
|
||||
columns.swap_axes(1, 2);
|
||||
|
||||
let grad_weight = matmul(
|
||||
out_grad.to_owned().into_dyn().into_shared(),
|
||||
columns.to_owned().into_dyn().into_shared(),
|
||||
);
|
||||
|
||||
let grad_weight = grad_weight
|
||||
.into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w))
|
||||
.unwrap();
|
||||
grad_weight.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Gradients returned by `backward_gradient_inputs`, in order: input gradient,
/// offset gradient, and mask gradient (optional).
type InputGradients<F> = (
    SharedArray<F>,
    SharedArray<F>,
    Option<SharedArray<F>>,
);
|
||||
|
||||
fn backward_gradient_inputs<F: FloatNdArrayElement>(
|
||||
image: ArrayView4<F>,
|
||||
weight: SharedArray<F>,
|
||||
offset: ArrayView4<F>,
|
||||
mask: Option<ArrayView6<F>>,
|
||||
out_grad: ArrayView3<F>,
|
||||
args: &DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
) -> InputGradients<F> {
|
||||
let input_shape = image.dim();
|
||||
let in_channels = input_shape.1;
|
||||
let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims();
|
||||
let (batch_size, _, out_h, out_w) = offset.dim();
|
||||
|
||||
let groups = args.weight_groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let col_shape_0 = in_c_per_group * kernel_h * kernel_w;
|
||||
|
||||
let mut weight = weight
|
||||
.to_shape((groups, out_c_per_group, col_shape_0))
|
||||
.unwrap();
|
||||
weight.swap_axes(1, 2);
|
||||
let columns = matmul(
|
||||
weight.to_owned().into_dyn().into_shared(),
|
||||
out_grad.to_owned().into_dyn().into_shared(),
|
||||
);
|
||||
|
||||
let columns = columns
|
||||
.to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w))
|
||||
.unwrap();
|
||||
|
||||
let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
|
||||
columns.view(),
|
||||
image.view(),
|
||||
offset,
|
||||
mask,
|
||||
args,
|
||||
kernel_dims,
|
||||
);
|
||||
|
||||
let input_gradient =
|
||||
compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape);
|
||||
|
||||
(input_gradient, offset_gradient, mask_gradient)
|
||||
}
|
||||
|
||||
fn compute_offset_and_mask_gradient<F: FloatNdArrayElement>(
|
||||
columns: ArrayView6<F>,
|
||||
image: ArrayView4<F>,
|
||||
offset: ArrayView4<F>,
|
||||
mask: Option<ArrayView6<F>>,
|
||||
args: &DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
) -> (SharedArray<F>, Option<SharedArray<F>>) {
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
let (_, in_channels, height, width) = image.dim();
|
||||
let (batch_size, offset_channels, out_h, out_w) = offset.dim();
|
||||
let offs_groups = args.offset_groups;
|
||||
let channels_per_offset_group = in_channels / args.offset_groups;
|
||||
|
||||
let mut grad_offset = Array5::zeros((
|
||||
offs_groups,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
2,
|
||||
batch_size * out_h * out_w,
|
||||
));
|
||||
let mut grad_mask =
|
||||
Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w));
|
||||
|
||||
grad_mask
|
||||
.axis_iter_mut(Axis(3))
|
||||
.zip(grad_offset.axis_iter_mut(Axis(4)))
|
||||
.enumerate()
|
||||
.for_each(|(index, (mut grad_mask, mut grad_offset))| {
|
||||
let out_x = index % out_w;
|
||||
let out_y = (index / out_w) % out_h;
|
||||
let batch = index / (out_w * out_h);
|
||||
let offset = offset.slice(s![batch, .., out_y, out_x]);
|
||||
let offset = offset
|
||||
.to_shape((offs_groups, kernel_h, kernel_w, 2))
|
||||
.unwrap();
|
||||
let mask: Option<ArrayView3<F>> = mask
|
||||
.as_ref()
|
||||
.map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x]));
|
||||
let columns = columns.slice(s![.., .., .., batch, out_y, out_x]);
|
||||
let image = image.slice(s![batch, .., .., ..]);
|
||||
|
||||
for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() {
|
||||
let grad_mask: &mut F = grad_mask;
|
||||
let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]);
|
||||
let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
|
||||
let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]);
|
||||
let columns = columns.slice(s![.., kernel_y, kernel_x]);
|
||||
let group_offset = group * channels_per_offset_group;
|
||||
let image = image.slice(s![group_offset.., .., ..]);
|
||||
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
|
||||
- F::from_elem(args.padding[0])
|
||||
+ offset[0];
|
||||
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
|
||||
- F::from_elem(args.padding[1])
|
||||
+ offset[1];
|
||||
for (i, grad_offset) in grad_offset.iter_mut().enumerate() {
|
||||
let is_y_direction = i % 2 == 0;
|
||||
let use_mask = mask.is_some();
|
||||
|
||||
for channel in 0..channels_per_offset_group {
|
||||
let mask = mask.unwrap_or_else(|| F::one());
|
||||
let image = image.index_axis(Axis(0), channel);
|
||||
let weight =
|
||||
get_coordinate_weight(image, height, width, y, x, is_y_direction);
|
||||
*grad_offset += mask * weight * columns[channel];
|
||||
if use_mask && is_y_direction {
|
||||
*grad_mask += columns[channel]
|
||||
* bilinear_interpolate(image, height, width, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mask_gradient = mask.map(|_| {
|
||||
let mut grad_mask = grad_mask
|
||||
.into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w))
|
||||
.unwrap();
|
||||
grad_mask.swap_axes(0, 1);
|
||||
grad_mask.into_dyn().into_shared()
|
||||
});
|
||||
let mut grad_offset = grad_offset
|
||||
.into_shape_with_order((offset_channels, batch_size, out_h, out_w))
|
||||
.unwrap();
|
||||
grad_offset.swap_axes(0, 1);
|
||||
let offset_gradient = grad_offset.into_dyn().into_shared();
|
||||
(offset_gradient, mask_gradient)
|
||||
}
|
||||
|
||||
fn get_coordinate_weight<F: FloatNdArrayElement>(
|
||||
input: ArrayView2<F>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
y: F,
|
||||
x: F,
|
||||
is_y_direction: bool,
|
||||
) -> F {
|
||||
let y = y.to_f32();
|
||||
let x = x.to_f32();
|
||||
|
||||
let y_low = f32::floor(y);
|
||||
let x_low = f32::floor(x);
|
||||
let y_high = y_low + 1.;
|
||||
let x_high = x_low + 1.;
|
||||
|
||||
let valid_y_low = y_low >= 0. && y_low < height as f32;
|
||||
let valid_y_high = y_high >= 0. && y_high < height as f32;
|
||||
let valid_x_low = x_low >= 0. && x_low < width as f32;
|
||||
let valid_x_high = x_high >= 0. && x_high < width as f32;
|
||||
|
||||
let bottom_left = if valid_y_low && valid_x_low {
|
||||
input[[y_low as usize, x_low as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
let bottom_right = if valid_y_low && valid_x_high {
|
||||
input[[y_low as usize, x_high as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
let top_left = if valid_y_high && valid_x_low {
|
||||
input[[y_high as usize, x_low as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
let top_right = if valid_y_high && valid_x_high {
|
||||
input[[y_high as usize, x_high as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
|
||||
if is_y_direction {
|
||||
let delta_x = F::from_elem(x - x_low);
|
||||
delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left)
|
||||
} else {
|
||||
let delta_y = F::from_elem(y - y_low);
|
||||
delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left)
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_input_grad<F: FloatNdArrayElement>(
|
||||
columns: ArrayView6<F>,
|
||||
offset: ArrayView4<F>,
|
||||
mask: Option<ArrayView6<F>>,
|
||||
args: &DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
input_shape: (usize, usize, usize, usize),
|
||||
) -> SharedArray<F> {
|
||||
let (batch_size, in_channels, height, width) = input_shape;
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
let offs_groups = args.offset_groups;
|
||||
let channels_per_offset_group = in_channels / offs_groups;
|
||||
|
||||
let grad_in =
|
||||
Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || {
|
||||
AtomicF32::new(0.0)
|
||||
});
|
||||
|
||||
let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| {
|
||||
let group = in_channel / channels_per_offset_group;
|
||||
let offset = offset.slice(s![batch, .., out_y, out_x]);
|
||||
let offset = offset
|
||||
.to_shape((offs_groups, kernel_h, kernel_w, 2))
|
||||
.unwrap();
|
||||
let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
|
||||
let offset = [offset[0], offset[1]];
|
||||
let mask = mask
|
||||
.as_ref()
|
||||
.map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
|
||||
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
|
||||
- F::from_elem(args.padding[0])
|
||||
+ offset[0];
|
||||
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
|
||||
- F::from_elem(args.padding[1])
|
||||
+ offset[1];
|
||||
let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
|
||||
deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
|
||||
};
|
||||
|
||||
// `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise
|
||||
#[cfg(feature = "multi-threads")]
|
||||
run_par!(|| {
|
||||
iter_par!(Zip::indexed(columns))
|
||||
.for_each(|(args0, args1)| compute_for_each(args0, args1))
|
||||
});
|
||||
|
||||
#[cfg(not(feature = "multi-threads"))]
|
||||
run_par!(|| { iter_par!(Zip::indexed(columns)).for_each(&compute_for_each) });
|
||||
|
||||
let grad_in: Array1<F> = grad_in
|
||||
.into_iter()
|
||||
.map(|it| F::from_elem(it.into_inner()))
|
||||
.collect();
|
||||
let grad_in = grad_in
|
||||
.into_shape_with_order((batch_size, in_channels, height, width))
|
||||
.unwrap();
|
||||
grad_in.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
fn deform_col2img_kernel(
|
||||
y: f32,
|
||||
x: f32,
|
||||
mask: Option<f32>,
|
||||
col: f32,
|
||||
grad_input: ArrayView2<AtomicF32>,
|
||||
) {
|
||||
let (height, width) = grad_input.dim();
|
||||
let mask_value = mask.unwrap_or(1.0);
|
||||
|
||||
for dy in -1..=1 {
|
||||
for dx in -1..=1 {
|
||||
let yp = f32::floor(y) + dy as f32;
|
||||
let xp = f32::floor(x) + dx as f32;
|
||||
|
||||
if yp >= 0.0
|
||||
&& yp < height as f32
|
||||
&& xp >= 0.0
|
||||
&& xp < width as f32
|
||||
&& f32::abs(y - yp) < 1.0
|
||||
&& f32::abs(x - xp) < 1.0
|
||||
{
|
||||
let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));
|
||||
|
||||
#[cfg_attr(not(target_has_atomic = "32"), allow(unused))]
|
||||
let value = mask_value * weight * col;
|
||||
|
||||
#[cfg(target_has_atomic = "32")]
|
||||
grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel);
|
||||
#[cfg(not(target_has_atomic = "32"))]
|
||||
panic!("Can't use deformable convolution backwards pass without atomics");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
use burn_backend::ElementConversion;
|
||||
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float;
|
||||
|
||||
use ndarray::Array4;
|
||||
|
||||
use crate::SharedArray;
|
||||
use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par};
|
||||
|
||||
/// Sample a tensor using grid-based sampling.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
|
||||
/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
|
||||
/// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
|
||||
/// * `options` - Grid sampling options (mode, padding_mode, align_corners)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with shape (N, C, H_out, W_out)
|
||||
pub(crate) fn grid_sample_2d<E: FloatNdArrayElement>(
|
||||
tensor: SharedArray<E>,
|
||||
grid: SharedArray<E>,
|
||||
options: GridSampleOptions,
|
||||
) -> SharedArray<E> {
|
||||
match options.mode {
|
||||
InterpolateMode::Bilinear => (),
|
||||
_ => todo!(
|
||||
"grid_sample_2d with {:?} mode is not implemented",
|
||||
options.mode
|
||||
),
|
||||
}
|
||||
|
||||
let tensor = tensor.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
let grid = grid.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
||||
let (batch_size, channels, height_in, width_in) = tensor.dim();
|
||||
let (b, height_out, width_out, d) = grid.dim();
|
||||
assert!(batch_size == b);
|
||||
assert!(2 == d);
|
||||
|
||||
let mut output = Array4::zeros((batch_size, channels, height_out, width_out));
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
let sample_count = batch_size * channels * height_out * width_out;
|
||||
let strides = (
|
||||
channels * height_out * width_out,
|
||||
height_out * width_out,
|
||||
width_out,
|
||||
);
|
||||
|
||||
let align = options.align_corners;
|
||||
let pad_mode = options.padding_mode;
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, sample_count).for_each(|id| {
|
||||
let (b, c, y, x) = (
|
||||
id / strides.0,
|
||||
id % strides.0 / strides.1,
|
||||
id % strides.1 / strides.2,
|
||||
id % strides.2,
|
||||
);
|
||||
|
||||
let sample_x = grid[(b, y, x, 0)].elem::<f64>();
|
||||
let sample_y = grid[(b, y, x, 1)].elem::<f64>();
|
||||
|
||||
// Convert normalized grid coordinates [-1, 1] to pixel coordinates
|
||||
let (px, py) = if align {
|
||||
// align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2
|
||||
// Maps -1 to 0 and 1 to width - 1
|
||||
let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0;
|
||||
let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0;
|
||||
(px, py)
|
||||
} else {
|
||||
// align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5
|
||||
// Maps -1 to -0.5 and 1 to width - 0.5
|
||||
let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5;
|
||||
let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5;
|
||||
(px, py)
|
||||
};
|
||||
|
||||
// Bilinear interpolation with the specified padding mode
|
||||
let val =
|
||||
bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align);
|
||||
|
||||
unsafe {
|
||||
let output = unsafe_shared_out.get();
|
||||
output[(b, c, y, x)] = val.elem();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Bilinear interpolation at a point with configurable padding mode.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn bilinear_interpolate<E, S>(
|
||||
source: &ndarray::ArrayBase<S, ndarray::Dim<[usize; 4]>>,
|
||||
b: usize,
|
||||
c: usize,
|
||||
x: f64,
|
||||
y: f64,
|
||||
width: usize,
|
||||
height: usize,
|
||||
padding_mode: GridSamplePaddingMode,
|
||||
align_corners: bool,
|
||||
) -> f64
|
||||
where
|
||||
E: FloatNdArrayElement,
|
||||
S: ndarray::Data<Elem = E>,
|
||||
{
|
||||
// Handle inf/nan coordinates
|
||||
if !x.is_finite() || !y.is_finite() {
|
||||
return match padding_mode {
|
||||
GridSamplePaddingMode::Zeros => 0.0,
|
||||
GridSamplePaddingMode::Border => {
|
||||
// Clamp to center of image for inf/nan
|
||||
let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64);
|
||||
let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64);
|
||||
source[(b, c, cy as usize, cx as usize)].elem::<f64>()
|
||||
}
|
||||
GridSamplePaddingMode::Reflection => 0.0, // Simplified: treat as zeros for inf/nan
|
||||
};
|
||||
}
|
||||
|
||||
// Apply padding mode to get actual sampling coordinates
|
||||
let (x, y) = match padding_mode {
|
||||
GridSamplePaddingMode::Border => {
|
||||
// Clamp coordinates to valid range [0, size-1]
|
||||
let x = x.clamp(0.0, (width - 1) as f64);
|
||||
let y = y.clamp(0.0, (height - 1) as f64);
|
||||
(x, y)
|
||||
}
|
||||
GridSamplePaddingMode::Reflection => {
|
||||
// Reflect coordinates at boundaries
|
||||
let x = reflect_coordinate(x, width, align_corners);
|
||||
let y = reflect_coordinate(y, height, align_corners);
|
||||
(x, y)
|
||||
}
|
||||
GridSamplePaddingMode::Zeros => (x, y), // Keep as-is, handle out-of-bounds in read
|
||||
};
|
||||
|
||||
// Get the four corner indices
|
||||
let x0 = x.floor() as i64;
|
||||
let y0 = y.floor() as i64;
|
||||
let x1 = x0.saturating_add(1);
|
||||
let y1 = y0.saturating_add(1);
|
||||
|
||||
// Compute interpolation weights (fractional part)
|
||||
let x_frac = x - x.floor();
|
||||
let y_frac = y - y.floor();
|
||||
|
||||
// Helper to read a value based on padding mode
|
||||
let read_value = |xi: i64, yi: i64| -> f64 {
|
||||
match padding_mode {
|
||||
GridSamplePaddingMode::Zeros => {
|
||||
// Return 0 for out-of-bounds
|
||||
if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 {
|
||||
source[(b, c, yi as usize, xi as usize)].elem::<f64>()
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => {
|
||||
// Coordinates should already be in valid range after clamping/reflection
|
||||
let xi = xi.clamp(0, (width - 1) as i64) as usize;
|
||||
let yi = yi.clamp(0, (height - 1) as i64) as usize;
|
||||
source[(b, c, yi, xi)].elem::<f64>()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Read the four corners
|
||||
let v00 = read_value(x0, y0);
|
||||
let v01 = read_value(x0, y1);
|
||||
let v10 = read_value(x1, y0);
|
||||
let v11 = read_value(x1, y1);
|
||||
|
||||
// Bilinear interpolation weights
|
||||
let w00 = (1.0 - x_frac) * (1.0 - y_frac);
|
||||
let w01 = (1.0 - x_frac) * y_frac;
|
||||
let w10 = x_frac * (1.0 - y_frac);
|
||||
let w11 = x_frac * y_frac;
|
||||
|
||||
v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11
|
||||
}
|
||||
|
||||
/// Mirrors `coord` back into the valid sampling range using a triangle wave.
///
/// The range is `[0, size - 1]` when `align_corners` is true and `[-0.5, size - 0.5]`
/// otherwise. If the range is empty or a single point, its lower bound is returned.
fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 {
    let extent = size as f64;
    let lower = if align_corners { 0.0 } else { -0.5 };
    let upper = if align_corners {
        extent - 1.0
    } else {
        extent - 0.5
    };

    let span = upper - lower;
    if span > 0.0 {
        // Triangle wave with period 2 * span: span - |(d mod 2*span) - span|
        let period = 2.0 * span;
        let distance = (coord - lower).abs();
        let wrapped = distance - (distance / period).floor() * period;
        span - (wrapped - span).abs() + lower
    } else {
        lower
    }
}
|
||||
@@ -0,0 +1,497 @@
|
||||
// Language
|
||||
use crate::rand::get_seeded_rng;
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::backend::ExecutionError;
|
||||
use burn_backend::ops::IntTensorOps;
|
||||
use burn_backend::tensor::{FloatTensor, IntTensor};
|
||||
use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
|
||||
|
||||
use burn_backend::ElementConversion;
|
||||
|
||||
// Current crate
|
||||
use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
|
||||
use crate::{NdArrayDevice, SEED, slice};
|
||||
use crate::{SharedArray, element::QuantElement};
|
||||
use crate::{cat_with_dtype, execute_with_float_dtype};
|
||||
use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
|
||||
use crate::{element::IntNdArrayElement, execute_with_int_dtype};
|
||||
|
||||
// Workspace crates
|
||||
use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
|
||||
use burn_backend::{DType, Shape, TensorData, backend::Backend};
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
|
||||
if data.dtype.is_int() || data.dtype.is_uint() {
|
||||
NdArrayTensor::from_data(data)
|
||||
} else {
|
||||
unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the tensor into `TensorData`; this implementation always returns `Ok`.
async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
    let data = tensor.into_data();
    Ok(data)
}
|
||||
|
||||
/// Returns the tensor unchanged; the requested device is ignored.
fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
    tensor
}
|
||||
|
||||
fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
|
||||
}
|
||||
|
||||
fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
|
||||
slice!(tensor, slices)
|
||||
}
|
||||
|
||||
fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
||||
fn int_empty(
|
||||
shape: Shape,
|
||||
device: &<NdArray<E> as Backend>::Device,
|
||||
dtype: IntDType,
|
||||
) -> NdArrayTensor {
|
||||
Self::int_zeros(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
execute_with_int_dtype!((lhs, rhs), matmul)
|
||||
}
|
||||
|
||||
fn int_mask_where(
|
||||
tensor: NdArrayTensor,
|
||||
mask: NdArrayTensor,
|
||||
source: NdArrayTensor,
|
||||
) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((tensor, source), |tensor, source| {
|
||||
NdArrayOps::mask_where(tensor, mask.bool(), source)
|
||||
})
|
||||
}
|
||||
|
||||
fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
|
||||
array,
|
||||
mask.bool(),
|
||||
value.elem()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: NdArrayTensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: NdArrayTensor,
|
||||
) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
|
||||
tensor, slices, value
|
||||
))
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
|
||||
cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
|
||||
}
|
||||
|
||||
fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
|
||||
array,
|
||||
rhs.elem()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
|
||||
array,
|
||||
rhs.elem()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
|
||||
array,
|
||||
rhs.elem()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
|
||||
array.view()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
|
||||
}
|
||||
|
||||
fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(
|
||||
tensor,
|
||||
E,
|
||||
|array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
|
||||
)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
|
||||
}
|
||||
|
||||
fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(
|
||||
tensor,
|
||||
E,
|
||||
|array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
|
||||
)
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
|
||||
}
|
||||
|
||||
fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
|
||||
array.view()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
|
||||
array.view()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
|
||||
}
|
||||
|
||||
fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
|
||||
}
|
||||
|
||||
fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
|
||||
}
|
||||
|
||||
fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
|
||||
}
|
||||
|
||||
fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
|
||||
execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
|
||||
dim, array, idx_array
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn int_scatter_add(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor,
|
||||
indices: NdArrayTensor,
|
||||
value: NdArrayTensor,
|
||||
) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
|
||||
execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
|
||||
dim, tensor, idx_array, value
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
|
||||
execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
|
||||
array, dim, idx_array
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn int_select_add(
|
||||
tensor: NdArrayTensor,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor,
|
||||
value: NdArrayTensor,
|
||||
) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
|
||||
execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
|
||||
tensor, dim, idx_array, value
|
||||
))
|
||||
})
|
||||
}
|
||||
fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
|
||||
NdArrayMathOps::argmax_view::<I>(array.view(), dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
|
||||
NdArrayMathOps::argmin_view::<I>(array.view(), dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
|
||||
}
|
||||
|
||||
fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
|
||||
}
|
||||
|
||||
fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
|
||||
array,
|
||||
min.elem(),
|
||||
max.elem()
|
||||
))
|
||||
}
|
||||
|
||||
fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
match tensor.dtype() {
|
||||
DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
|
||||
execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
|
||||
I64 => i64, I32 => i32, I16 => i16, I8 => i8
|
||||
])
|
||||
}
|
||||
// Already unsigned
|
||||
DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
|
||||
other => panic!("Unsupported dtype: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
|
||||
execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| array
|
||||
.mapv(|a: IntElem| a.elem::<E>())
|
||||
.into_shared())
|
||||
}
|
||||
|
||||
fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
|
||||
}
|
||||
|
||||
fn int_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &NdArrayDevice,
|
||||
) -> NdArrayTensor {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
|
||||
|
||||
let effective_distribution = if distribution == Distribution::Default {
|
||||
Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant
|
||||
} else {
|
||||
distribution
|
||||
};
|
||||
|
||||
let tensor = Self::int_from_data(
|
||||
TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
|
||||
device,
|
||||
);
|
||||
*seed = Some(rng);
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
|
||||
lhs,
|
||||
rhs,
|
||||
|a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
|
||||
))
|
||||
}
|
||||
|
||||
fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, I, |array| -> NdArrayTensor {
|
||||
execute_with_float_dtype!(rhs, E, |rhs_array| {
|
||||
NdArrayMathOps::elementwise_op(array, rhs_array, |a: &I, b: &E| {
|
||||
(a.elem::<i64>().pow(*b as u32)).elem()
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, I, |array| {
|
||||
NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
|
||||
(a.elem::<i64>().pow(rhs.elem())).elem()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
|
||||
}
|
||||
|
||||
fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
|
||||
}
|
||||
|
||||
fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
match tensor.dtype() {
|
||||
DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
|
||||
execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
|
||||
I64 => i64, I32 => i32, I16 => i16, I8 => i8
|
||||
])
|
||||
}
|
||||
DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
|
||||
Self::int_greater_elem(tensor, 0.into())
|
||||
}
|
||||
other => panic!("Unsupported dtype: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
|
||||
}
|
||||
|
||||
fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
|
||||
}
|
||||
|
||||
fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
|
||||
}
|
||||
|
||||
fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
|
||||
}
|
||||
|
||||
fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
|
||||
}
|
||||
|
||||
fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
|
||||
(a.elem::<i64>() << (b.elem::<u32>())).elem()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, I, |array| {
|
||||
NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
|
||||
(a.elem::<i64>() << rhs.elem::<u32>()).elem()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
|
||||
execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
|
||||
(a.elem::<i64>() >> (b.elem::<u32>())).elem()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_int_dtype!(lhs, I, |array| {
|
||||
NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
|
||||
(a.elem::<i64>() >> rhs.elem::<u32>()).elem()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
|
||||
}
|
||||
|
||||
fn int_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
use burn_backend::ElementConversion;
|
||||
use ndarray::{Array4, ArrayBase, DataOwned};
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float;
|
||||
|
||||
use crate::{FloatNdArrayElement, ShapeOps, SharedArray, UnsafeSharedRef, iter_range_par, run_par};
|
||||
|
||||
pub(crate) fn nearest_interpolate<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
output_size: [usize; 2],
|
||||
) -> SharedArray<E> {
|
||||
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
||||
let (batch_size, channels, in_height, in_width) = x.dim();
|
||||
let [out_height, out_width] = output_size;
|
||||
|
||||
let y_ratio = (in_height as f64) / (out_height as f64);
|
||||
let x_ratio = (in_width as f64) / (out_width as f64);
|
||||
|
||||
let out_element_num = batch_size * channels * out_height * out_width;
|
||||
let strides = (
|
||||
channels * out_height * out_width,
|
||||
out_height * out_width,
|
||||
out_width,
|
||||
);
|
||||
|
||||
let mut output = Array4::zeros((batch_size, channels, out_height, out_width));
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, out_element_num).for_each(|id| {
|
||||
let (b, c, h, w) = (
|
||||
id / strides.0,
|
||||
id % strides.0 / strides.1,
|
||||
id % strides.1 / strides.2,
|
||||
id % strides.2,
|
||||
);
|
||||
|
||||
let y_in = (y_ratio * h as f64).floor() as usize;
|
||||
let x_in = (x_ratio * w as f64).floor() as usize;
|
||||
|
||||
unsafe {
|
||||
let output = unsafe_shared_out.get();
|
||||
output[(b, c, h, w)] = x[(b, c, y_in, x_in)];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn nearest_interpolate_backward<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
grad: SharedArray<E>,
|
||||
output_size: [usize; 2],
|
||||
) -> SharedArray<E> {
|
||||
let [batch_size, channels, input_height, input_width] = x.shape().dims();
|
||||
let [output_height, output_width] = output_size;
|
||||
|
||||
let mut output_grad =
|
||||
Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output_grad = unsafe_shared_out.get();
|
||||
|
||||
for oh in 0..output_height {
|
||||
for ow in 0..output_width {
|
||||
let ih = start_index(oh, output_height, input_height);
|
||||
let iw = start_index(ow, output_width, input_width);
|
||||
|
||||
output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]]
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
output_grad.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Maps an output index back to the input index it samples from:
/// `floor(output_size_index * input_size / output_size)`.
///
/// The computation is done in exact integer arithmetic. A single-precision
/// float cannot represent every integer above 2^24, so a float-based version
/// returns the wrong index once the index or the product grows past that.
///
/// * `output_size_index` - Position along the output axis.
/// * `output_size` - Length of the output axis.
/// * `input_size` - Length of the input axis.
///
/// Returns `0` when `output_size` is `0` (there is nothing to map).
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    if output_size == 0 {
        return 0;
    }

    // Widen to u128 so the product cannot overflow on any supported target.
    let scaled = output_size_index as u128 * input_size as u128;
    let index = scaled / output_size as u128;

    // Saturate rather than wrap if the caller passes an index beyond the axis.
    usize::try_from(index).unwrap_or(usize::MAX)
}
|
||||
|
||||
// Round `frac` up to the next whole number, but never beyond `max`, so that
// floating-point imprecision cannot push the result out of bounds.
pub(crate) fn ceil_clamp(frac: f64, max: usize) -> f64 {
    let upper_bound = max as f64;
    let rounded_up = frac.ceil();
    rounded_up.min(upper_bound)
}
|
||||
|
||||
pub(crate) fn bilinear_interpolate<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
output_size: [usize; 2],
|
||||
align_corners: bool,
|
||||
) -> SharedArray<E> {
|
||||
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
||||
let (batch_size, channels, in_height, in_width) = x.dim();
|
||||
let [out_height, out_width] = output_size;
|
||||
|
||||
let out_element_num = batch_size * channels * out_height * out_width;
|
||||
let strides = (
|
||||
channels * out_height * out_width,
|
||||
out_height * out_width,
|
||||
out_width,
|
||||
);
|
||||
|
||||
let mut output = Array4::zeros((batch_size, channels, out_height, out_width));
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, out_element_num).for_each(|id| {
|
||||
let (b, c, h, w) = (
|
||||
id / strides.0,
|
||||
id % strides.0 / strides.1,
|
||||
id % strides.1 / strides.2,
|
||||
id % strides.2,
|
||||
);
|
||||
|
||||
let (y_frac, x_frac) = if align_corners {
|
||||
let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);
|
||||
let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);
|
||||
(y_ratio * h as f64, x_ratio * w as f64)
|
||||
} else {
|
||||
let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;
|
||||
let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;
|
||||
(
|
||||
y_frac.clamp(0.0, (in_height - 1) as f64),
|
||||
x_frac.clamp(0.0, (in_width - 1) as f64),
|
||||
)
|
||||
};
|
||||
let val =
|
||||
bilinear_interpolate_single(&x, b, c, x_frac, y_frac, in_width - 1, in_height - 1);
|
||||
|
||||
unsafe {
|
||||
let output = unsafe_shared_out.get();
|
||||
output[(b, c, h, w)] = val.elem();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn bicubic_interpolate<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
output_size: [usize; 2],
|
||||
align_corners: bool,
|
||||
) -> SharedArray<E> {
|
||||
fn cubic_interp1d(x0: f64, x1: f64, x2: f64, x3: f64, t: f64) -> f64 {
|
||||
fn cubic_convolution1(x: f64, a: f64) -> f64 {
|
||||
((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
|
||||
}
|
||||
|
||||
fn cubic_convolution2(x: f64, a: f64) -> f64 {
|
||||
((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a
|
||||
}
|
||||
|
||||
let coeffs = [
|
||||
cubic_convolution2(t + 1.0, -0.75),
|
||||
cubic_convolution1(t, -0.75),
|
||||
cubic_convolution1(1.0 - t, -0.75),
|
||||
cubic_convolution2(2.0 - t, -0.75),
|
||||
];
|
||||
|
||||
x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]
|
||||
}
|
||||
|
||||
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
||||
let (batch_size, channels, in_height, in_width) = x.dim();
|
||||
let [out_height, out_width] = output_size;
|
||||
|
||||
let out_element_num = batch_size * channels * out_height * out_width;
|
||||
let strides = (
|
||||
channels * out_height * out_width,
|
||||
out_height * out_width,
|
||||
out_width,
|
||||
);
|
||||
|
||||
let mut output = Array4::zeros((batch_size, channels, out_height, out_width));
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, out_element_num).for_each(|id| {
|
||||
let (b, c, h, w) = (
|
||||
id / strides.0,
|
||||
id % strides.0 / strides.1,
|
||||
id % strides.1 / strides.2,
|
||||
id % strides.2,
|
||||
);
|
||||
|
||||
let (y_frac, x_frac) = if align_corners {
|
||||
let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);
|
||||
let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);
|
||||
(y_ratio * h as f64, x_ratio * w as f64)
|
||||
} else {
|
||||
let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;
|
||||
let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;
|
||||
(y_frac, x_frac)
|
||||
};
|
||||
let y0 = y_frac.floor();
|
||||
let yw = y_frac - y0;
|
||||
let y_in = y0 as isize;
|
||||
|
||||
let x0 = x_frac.floor();
|
||||
let xw = x_frac - x0;
|
||||
let x_in = x0 as isize;
|
||||
|
||||
let max_h = (in_height - 1) as isize;
|
||||
let max_w = (in_width - 1) as isize;
|
||||
|
||||
let ys_in = [
|
||||
(y_in - 1).clamp(0, max_h) as usize,
|
||||
y_in.clamp(0, max_h) as usize,
|
||||
(y_in + 1).clamp(0, max_h) as usize,
|
||||
(y_in + 2).clamp(0, max_h) as usize,
|
||||
];
|
||||
|
||||
let xs_in = [
|
||||
(x_in - 1).clamp(0, max_w) as usize,
|
||||
x_in.clamp(0, max_w) as usize,
|
||||
(x_in + 1).clamp(0, max_w) as usize,
|
||||
(x_in + 2).clamp(0, max_w) as usize,
|
||||
];
|
||||
|
||||
let coefficients = ys_in.map(|y| {
|
||||
cubic_interp1d(
|
||||
x[(b, c, y, xs_in[0])].elem(),
|
||||
x[(b, c, y, xs_in[1])].elem(),
|
||||
x[(b, c, y, xs_in[2])].elem(),
|
||||
x[(b, c, y, xs_in[3])].elem(),
|
||||
xw,
|
||||
)
|
||||
});
|
||||
|
||||
let result = cubic_interp1d(
|
||||
coefficients[0],
|
||||
coefficients[1],
|
||||
coefficients[2],
|
||||
coefficients[3],
|
||||
yw,
|
||||
)
|
||||
.elem();
|
||||
|
||||
unsafe {
|
||||
let output = unsafe_shared_out.get();
|
||||
output[(b, c, h, w)] = result;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Sample an element of the source array with bilinear interpolation
|
||||
///
|
||||
/// * `source` - The tensor to read from. Has shape (batch_size, channels, height, width)
|
||||
/// * `b` - The batch to read from
|
||||
/// * `c` - The channel to read from
|
||||
/// * `x` - The x position to read in the array
|
||||
/// * `y` - The y position to read in the array
|
||||
/// * `x_max` - The max x position (inclusive)
|
||||
/// * `y_max` - The max y position (inclusive)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The interpolated value read from the array
|
||||
pub(crate) fn bilinear_interpolate_single<E, S>(
|
||||
source: &ArrayBase<S, ndarray::Dim<[usize; 4]>>,
|
||||
b: usize,
|
||||
c: usize,
|
||||
x: f64,
|
||||
y: f64,
|
||||
x_max: usize,
|
||||
y_max: usize,
|
||||
) -> f64
|
||||
where
|
||||
E: FloatNdArrayElement,
|
||||
S: DataOwned<Elem = E>,
|
||||
{
|
||||
let y0 = y.floor();
|
||||
let y1 = ceil_clamp(y, y_max);
|
||||
let yw = y - y0;
|
||||
|
||||
let x0 = x.floor();
|
||||
let x1 = ceil_clamp(x, x_max);
|
||||
let xw = x - x0;
|
||||
|
||||
let (x0, x1, y0, y1) = (x0 as usize, x1 as usize, y0 as usize, y1 as usize);
|
||||
|
||||
let p_a = source[(b, c, y0, x0)].elem::<f64>() * (1.0 - xw) * (1.0 - yw);
|
||||
let p_b = source[(b, c, y0, x1)].elem::<f64>() * xw * (1.0 - yw);
|
||||
let p_c = source[(b, c, y1, x0)].elem::<f64>() * (1.0 - xw) * yw;
|
||||
let p_d = source[(b, c, y1, x1)].elem::<f64>() * xw * yw;
|
||||
|
||||
p_a + p_b + p_c + p_d
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
// Reduce `$self` along `$dim` and reshape the result so that the reduced
// dimension is kept with size 1.
//
// The third argument selects the reduction: `mean`, `sum` or `prod`.
// Both `$self` and `$dim` appear twice in each expansion: once to read the
// shape and once in the reduction call. The call site must have `E`,
// `SharedArray`, `NdArrayOps` and the matching `*_dim` function in scope.
macro_rules! keepdim {
    (
        $dim:expr,
        $self:expr,
        mean
    ) => {{
        // Read the shape through a reference before the tensor is moved into
        // the reduction, so no clone is needed.
        let mut kept_shape = $self.shape().into_shape();
        kept_shape[$dim] = 1;
        let reduced: SharedArray<E> = mean_dim($self, $dim);
        NdArrayOps::reshape(reduced, kept_shape)
    }};
    (
        $dim:expr,
        $self:expr,
        sum
    ) => {{
        // Read the shape through a reference before the tensor is moved into
        // the reduction, so no clone is needed.
        let mut kept_shape = $self.shape().into_shape();
        kept_shape[$dim] = 1;
        let reduced: SharedArray<E> = sum_dim($self, $dim);
        NdArrayOps::reshape(reduced, kept_shape)
    }};
    (
        $dim:expr,
        $self:expr,
        prod
    ) => {{
        // Read the shape through a reference before the tensor is moved into
        // the reduction, so no clone is needed.
        let mut kept_shape = $self.shape().into_shape();
        kept_shape[$dim] = 1;
        let reduced: SharedArray<E> = prod_dim($self, $dim);
        NdArrayOps::reshape(reduced, kept_shape)
    }};
}
|
||||
|
||||
use burn_backend::ElementConversion;
|
||||
pub(crate) use keepdim;
|
||||
use ndarray::{Axis, Zip};
|
||||
|
||||
use crate::{SharedArray, element::NdArrayElement};
|
||||
|
||||
pub(crate) fn mean_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
|
||||
tensor.mean_axis(Axis(dim)).unwrap().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn sum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
|
||||
tensor.sum_axis(Axis(dim)).into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn prod_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
|
||||
tensor
|
||||
.fold_axis(Axis(dim), 1.elem::<E>(), |acc, &x| acc.mul(x.elem()))
|
||||
.into_shared()
|
||||
}
|
||||
|
||||
/// Generic cumulative operation function with closure-based operation.
|
||||
pub(crate) fn cumulative_with_op<E, F>(tensor: SharedArray<E>, dim: usize, op: F) -> SharedArray<E>
|
||||
where
|
||||
E: NdArrayElement,
|
||||
F: Fn(&mut E, &E),
|
||||
{
|
||||
let axis = Axis(dim);
|
||||
let shape = tensor.shape().to_vec();
|
||||
// Use into_owned() instead of to_owned() - only copies if shared, avoids copy if unique
|
||||
let mut result = tensor.into_owned();
|
||||
let dim_size = shape[dim];
|
||||
|
||||
for i in 1..dim_size {
|
||||
let prev = result.index_axis(axis, i - 1).to_owned();
|
||||
let mut current = result.index_axis_mut(axis, i);
|
||||
Zip::from(&mut current).and(&prev).for_each(&op);
|
||||
}
|
||||
|
||||
result.into_shared()
|
||||
}
|
||||
|
||||
// Define all cumulative operation functions using the generic function
|
||||
pub(crate) fn cumsum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
|
||||
cumulative_with_op(tensor, dim, |c, &p| *c = c.add(p.elem()))
|
||||
}
|
||||
|
||||
pub(crate) fn cumprod_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
|
||||
cumulative_with_op(tensor, dim, |c, &p| *c = c.mul(p.elem()))
|
||||
}
|
||||
|
||||
pub(crate) fn cummin_dim<E: NdArrayElement + core::cmp::PartialOrd<E>>(
|
||||
tensor: SharedArray<E>,
|
||||
dim: usize,
|
||||
) -> SharedArray<E> {
|
||||
cumulative_with_op(tensor, dim, |c, &p| {
|
||||
if p < *c {
|
||||
*c = p;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn cummax_dim<E: NdArrayElement + core::cmp::PartialOrd<E>>(
|
||||
tensor: SharedArray<E>,
|
||||
dim: usize,
|
||||
) -> SharedArray<E> {
|
||||
cumulative_with_op(tensor, dim, |c, &p| {
|
||||
if p > *c {
|
||||
*c = p;
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
use crate::UnsafeSharedRef;
|
||||
use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};
|
||||
|
||||
use alloc::{vec, vec::Vec};
|
||||
use burn_backend::ElementConversion;
|
||||
use burn_backend::Shape;
|
||||
use ndarray::{IxDyn, s};
|
||||
|
||||
pub(crate) fn matmul<E: NdArrayElement>(
|
||||
lhs: SharedArray<E>,
|
||||
rhs: SharedArray<E>,
|
||||
) -> SharedArray<E> {
|
||||
let shape_lhs = lhs.shape();
|
||||
let shape_rhs = rhs.shape();
|
||||
let ndims = shape_lhs.num_dims();
|
||||
let m = shape_lhs[ndims - 2]; // # of left rows
|
||||
let k = shape_rhs[ndims - 2]; // # of left cols and right rows
|
||||
let n = shape_rhs[ndims - 1]; // # of right cols
|
||||
|
||||
let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs);
|
||||
let l_mat_size = m * k; // size of matrix component of left array
|
||||
let r_mat_size = k * n; // size of matrix component of right array
|
||||
let out_mat_size = m * n; // size of matrix component of output array
|
||||
|
||||
let num_l_batches = shape_lhs.num_elements() / l_mat_size;
|
||||
let num_r_batches = shape_rhs.num_elements() / r_mat_size;
|
||||
let num_out_batches = out_shape.num_elements() / out_mat_size;
|
||||
|
||||
let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k]));
|
||||
let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n]));
|
||||
|
||||
let alpha: E = 1.0.elem();
|
||||
let beta: E = 0.0.elem();
|
||||
|
||||
let out = run_par!(|| {
|
||||
let mut out_array = ndarray::Array3::<E>::zeros((num_out_batches, m, n));
|
||||
let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);
|
||||
|
||||
iter_range_par!(0, num_out_batches).for_each(|out_batch| {
|
||||
// Here, we:
|
||||
// 1. Un-flatten the output batch into a component-based batch index.
|
||||
// 2. Use the strides for left and right batch indices to convert it to a flattened
|
||||
// batch for left and right.
|
||||
let out_index = strides_out.unflatten(out_batch);
|
||||
let l_batch = strides_lhs.flatten(&out_index);
|
||||
let r_batch = strides_rhs.flatten(&out_index);
|
||||
|
||||
let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));
|
||||
let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));
|
||||
|
||||
unsafe {
|
||||
let mut out_slice = unsafe_shared_out_array
|
||||
.get()
|
||||
.slice_mut(s!(out_batch, .., ..));
|
||||
|
||||
ndarray::linalg::general_mat_mul(
|
||||
alpha,
|
||||
&lhs_slice,
|
||||
&rhs_slice,
|
||||
beta,
|
||||
&mut out_slice,
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
out_array.into_shared().into_dyn()
|
||||
});
|
||||
|
||||
NdArrayOps::reshape(out, out_shape)
|
||||
}
|
||||
|
||||
/// Per-dimension strides used to convert between a flat batch index and a
/// multi-dimensional batch coordinate.
///
/// A stride of `0` marks a broadcast dimension: the coordinate along it does
/// not contribute to the flat index.
#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}

impl Strides {
    /// Wraps a list of strides, ordered from the outermost dimension inwards.
    fn new(strides: Vec<usize>) -> Self {
        Strides { strides }
    }

    /// Splits `linear_index` into one coordinate per stride.
    ///
    /// Each stride is divided out in turn and the remainder carried on to the
    /// next one. Every stride must be non-zero.
    fn unflatten(&self, linear_index: usize) -> Vec<usize> {
        let mut rem = linear_index;
        self.strides
            .iter()
            .map(|&stride| {
                let coord = rem / stride;
                rem %= stride;
                coord
            })
            .collect()
    }

    /// Combines a coordinate into a flat index: the sum of `stride * coordinate`
    /// over all dimensions.
    ///
    /// # Panics
    ///
    /// Panics if `index` does not have exactly one entry per stride.
    fn flatten(&self, index: &[usize]) -> usize {
        assert_eq!(self.strides.len(), index.len());
        self.strides
            .iter()
            .zip(index)
            .map(|(stride, coord)| stride * coord)
            .sum()
    }
}
|
||||
|
||||
/// Compute the (broadcasted) output shape of matrix multiplication, along with strides for
|
||||
/// the non-matrix dimensions of all arrays.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `lsh`: Shape of the first (left-hand) matrix multiplication argument.
|
||||
/// * `rsh`: Shape of the second (right-hand) matrix multiplication argument.
|
||||
///
|
||||
/// # Panics
|
||||
/// * If `D` is not at least 2.
|
||||
/// * If the matrix multiplication dimensions (last 2) are incompatible.
|
||||
/// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where
|
||||
/// one dim is equal to 1 is broadcast.)
|
||||
fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
|
||||
let ndims = lsh.num_dims();
|
||||
if ndims < 2 {
|
||||
panic!("Matrix multiplication requires an array with at least 2 dimensions.");
|
||||
}
|
||||
|
||||
// Fetch matrix dimensions and check compatibility.
|
||||
let l_rows = lsh[ndims - 2];
|
||||
let l_cols = lsh[ndims - 1];
|
||||
let r_rows = rsh[ndims - 2];
|
||||
let r_cols = rsh[ndims - 1];
|
||||
if l_cols != r_rows {
|
||||
panic!("Dimensions are incompatible for matrix multiplication.");
|
||||
}
|
||||
// Set matrix dimensions of the output shape.
|
||||
let mut osh = vec![0; ndims];
|
||||
osh[ndims - 2] = l_rows;
|
||||
osh[ndims - 1] = r_cols;
|
||||
|
||||
// Set other array dimensions, broadcasting as necessary.
|
||||
// Compute the strides inline.
|
||||
let mut cur_l_stride: usize = 1;
|
||||
let mut cur_r_stride: usize = 1;
|
||||
let mut cur_o_stride: usize = 1;
|
||||
let mut l_strides = Vec::with_capacity(ndims - 2);
|
||||
let mut r_strides = Vec::with_capacity(ndims - 2);
|
||||
let mut o_strides = Vec::with_capacity(ndims - 2);
|
||||
for i in (0..ndims - 2).rev() {
|
||||
let l_dim = lsh[i];
|
||||
let r_dim = rsh[i];
|
||||
|
||||
// Compatible dimensions are:
|
||||
// 1. Both dimensions are equal.
|
||||
// 2. One of the dimensions is equal to 1.
|
||||
let o_dim: usize;
|
||||
if l_dim == r_dim {
|
||||
o_dim = l_dim; // both dimensions are equal
|
||||
l_strides.push(cur_l_stride);
|
||||
r_strides.push(cur_r_stride);
|
||||
} else if l_dim == 1 {
|
||||
o_dim = r_dim; // broadcast the left
|
||||
l_strides.push(0);
|
||||
r_strides.push(cur_r_stride);
|
||||
} else if r_dim == 1 {
|
||||
o_dim = l_dim; // broadcast the right
|
||||
l_strides.push(cur_l_stride);
|
||||
r_strides.push(0);
|
||||
} else {
|
||||
panic!("Dimensions differ and cannot be broadcasted.");
|
||||
}
|
||||
osh[i] = o_dim;
|
||||
o_strides.push(cur_o_stride);
|
||||
cur_o_stride *= o_dim;
|
||||
|
||||
cur_l_stride *= l_dim;
|
||||
cur_r_stride *= r_dim;
|
||||
}
|
||||
l_strides.reverse();
|
||||
r_strides.reverse();
|
||||
o_strides.reverse();
|
||||
|
||||
(
|
||||
Shape::from(osh),
|
||||
Strides::new(l_strides),
|
||||
Strides::new(r_strides),
|
||||
Strides::new(o_strides),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn cross<E: NdArrayElement>(
|
||||
lhs: SharedArray<E>,
|
||||
rhs: SharedArray<E>,
|
||||
dim: usize,
|
||||
) -> SharedArray<E> {
|
||||
let shape_lhs = lhs.shape();
|
||||
let shape_rhs = rhs.shape();
|
||||
let ndims = shape_lhs.num_dims();
|
||||
|
||||
// Broadcast the shapes except along dim
|
||||
let mut broadcast_shape = vec![0; ndims];
|
||||
for i in 0..ndims {
|
||||
if i == dim {
|
||||
broadcast_shape[i] = shape_lhs[i]; // already checked to be 3
|
||||
} else {
|
||||
let l = shape_lhs[i];
|
||||
let r = shape_rhs[i];
|
||||
if l == r {
|
||||
broadcast_shape[i] = l;
|
||||
} else if l == 1 {
|
||||
broadcast_shape[i] = r;
|
||||
} else if r == 1 {
|
||||
broadcast_shape[i] = l;
|
||||
} else {
|
||||
panic!("Tensors are not broadcastable along dimension {}", i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast lhs and rhs
|
||||
let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() {
|
||||
lhs
|
||||
} else {
|
||||
NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone()))
|
||||
};
|
||||
let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() {
|
||||
rhs
|
||||
} else {
|
||||
NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone()))
|
||||
};
|
||||
|
||||
// Now, move dim to the last dimension
|
||||
let mut perm = (0..ndims).collect::<Vec<_>>();
|
||||
perm.remove(dim);
|
||||
perm.push(dim);
|
||||
|
||||
let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm);
|
||||
let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm);
|
||||
|
||||
// Reshape to (*, 3)
|
||||
let total_elements = lhs_permuted.shape().num_elements();
|
||||
let batch_size = total_elements / 3;
|
||||
let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3]));
|
||||
let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3]));
|
||||
|
||||
// Compute cross product
|
||||
let mut result = ndarray::ArrayD::<E>::zeros(IxDyn(&[batch_size, 3]));
|
||||
for i in 0..batch_size {
|
||||
let a1 = lhs_reshaped[IxDyn(&[i, 0])];
|
||||
let a2 = lhs_reshaped[IxDyn(&[i, 1])];
|
||||
let a3 = lhs_reshaped[IxDyn(&[i, 2])];
|
||||
let b1 = rhs_reshaped[IxDyn(&[i, 0])];
|
||||
let b2 = rhs_reshaped[IxDyn(&[i, 1])];
|
||||
let b3 = rhs_reshaped[IxDyn(&[i, 2])];
|
||||
result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2));
|
||||
result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3));
|
||||
result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1));
|
||||
}
|
||||
|
||||
let result_shared = result.into_shared();
|
||||
|
||||
// Reshape back to the broadcast shape with dim at the end
|
||||
let mut result_shape = broadcast_shape;
|
||||
result_shape.remove(dim);
|
||||
result_shape.push(3);
|
||||
let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape));
|
||||
|
||||
// Permute back
|
||||
let mut inv_perm = vec![0; ndims];
|
||||
for (i, &p) in perm.iter().enumerate() {
|
||||
inv_perm[p] = i;
|
||||
}
|
||||
NdArrayOps::permute(result_reshaped, &inv_perm)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    impl Strides {
        /// Strides for an array with no batch dimensions.
        fn empty() -> Self {
            Strides {
                strides: Vec::with_capacity(0),
            }
        }
    }

    #[test]
    fn test_output_shape() {
        // plain matrix multiply
        let expected = (
            Shape::from([5, 7]),
            Strides::empty(),
            Strides::empty(),
            Strides::empty(),
        );
        assert_eq!(output_shape(&[5, 3], &[3, 7]), expected);

        // matrix multiply with one extra stack dimension
        let expected = (
            Shape::from([4, 5, 7]),
            Strides::new(vec![1]),
            Strides::new(vec![1]),
            Strides::new(vec![1]),
        );
        assert_eq!(output_shape(&[4, 5, 3], &[4, 3, 7]), expected);

        // rank 3, broadcast left
        let expected = (
            Shape::from([4, 5, 7]),
            Strides::new(vec![0]),
            Strides::new(vec![1]),
            Strides::new(vec![1]),
        );
        assert_eq!(output_shape(&[1, 5, 3], &[4, 3, 7]), expected);

        // rank 3, broadcast right
        let expected = (
            Shape::from([4, 5, 7]),
            Strides::new(vec![1]),
            Strides::new(vec![0]),
            Strides::new(vec![1]),
        );
        assert_eq!(output_shape(&[4, 5, 3], &[1, 3, 7]), expected);

        // rank 4, multi broadcast
        let expected = (
            Shape::from([8, 4, 5, 7]),
            Strides::new(vec![0, 1]),
            Strides::new(vec![1, 0]),
            Strides::new(vec![4, 1]),
        );
        assert_eq!(output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]), expected);

        // rank 5, multi-broadcast
        let expected = (
            Shape::from([8, 3, 4, 5, 7]),
            Strides::new(vec![0, 4, 1]),
            Strides::new(vec![3, 1, 0]),
            Strides::new(vec![12, 4, 1]),
        );
        assert_eq!(output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]), expected);
    }

    #[test]
    #[should_panic(
        expected = "Matrix multiplication requires an array with at least 2 dimensions."
    )]
    fn test_output_shape_too_small() {
        output_shape(&[4], &[4]);
    }

    #[test]
    #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
    fn test_output_shape_bad_matrix_dims() {
        output_shape(&[5, 3], &[4, 7]);
    }

    #[test]
    #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
    fn test_output_shape_non_broadcast() {
        output_shape(&[4, 5, 3], &[2, 3, 7]);
    }
}
|
||||
@@ -0,0 +1,247 @@
|
||||
use crate::{
|
||||
ShapeOps, SharedArray,
|
||||
element::{FloatNdArrayElement, IntNdArrayElement},
|
||||
iter_range_par,
|
||||
ops::padding::apply_padding_4d,
|
||||
run_par,
|
||||
sharing::UnsafeSharedRef,
|
||||
};
|
||||
|
||||
use burn_backend::ElementConversion;
|
||||
use burn_backend::ops::conv::calculate_pool_output_size;
|
||||
use ndarray::Array4;
|
||||
|
||||
pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> SharedArray<E> {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().dims();
|
||||
let inf = (-f32::INFINITY).elem::<E>();
|
||||
|
||||
let out_height = calculate_pool_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
dilation_height,
|
||||
x_height,
|
||||
ceil_mode,
|
||||
);
|
||||
let out_width = calculate_pool_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
dilation_width,
|
||||
x_width,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
// Calculate extra padding needed for ceil_mode
|
||||
// The maximum input position accessed is: (out_size - 1) * stride + (kernel_size - 1) * dilation
|
||||
// This must be < input_size + 2 * total_padding
|
||||
let max_ih =
|
||||
(out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;
|
||||
let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;
|
||||
let padded_height = x_height + 2 * padding_height;
|
||||
let padded_width = x_width + 2 * padding_width;
|
||||
let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));
|
||||
let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));
|
||||
let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];
|
||||
|
||||
let x = apply_padding_4d::<E>(x, total_padding, inf);
|
||||
|
||||
// Offset to account for extra padding (extra_pad is added on both sides by apply_padding_4d)
|
||||
let offset_h = extra_pad_h;
|
||||
let offset_w = extra_pad_w;
|
||||
|
||||
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
let mut max_val = inf;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = offset_h + oh * stride_height + kh * dilation_height;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = offset_w + ow * stride_width + kw * dilation_width;
|
||||
|
||||
let val = x[[b, c, ih, iw]];
|
||||
|
||||
if val > max_val {
|
||||
max_val = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[[b, c, oh, ow]] = max_val;
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, I: IntNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> (SharedArray<E>, SharedArray<I>) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().dims();
|
||||
let inf = (-f32::INFINITY).elem::<E>();
|
||||
|
||||
let out_height = calculate_pool_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
dilation_height,
|
||||
x_height,
|
||||
ceil_mode,
|
||||
);
|
||||
let out_width = calculate_pool_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
dilation_width,
|
||||
x_width,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
// Calculate extra padding needed for ceil_mode
|
||||
let max_ih =
|
||||
(out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;
|
||||
let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;
|
||||
let padded_height = x_height + 2 * padding_height;
|
||||
let padded_width = x_width + 2 * padding_width;
|
||||
let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));
|
||||
let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));
|
||||
let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];
|
||||
|
||||
let x = apply_padding_4d::<E>(x, total_padding, inf);
|
||||
|
||||
// Offset to account for extra padding
|
||||
let offset_h = extra_pad_h;
|
||||
let offset_w = extra_pad_w;
|
||||
|
||||
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
|
||||
let mut indices = Array4::<I>::zeros((batch_size, channels, out_height, out_width));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
let indices = unsafe_shared_indices.get();
|
||||
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
let mut max_val = inf;
|
||||
let mut index = 0;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = offset_h + oh * stride_height + kh * dilation_height;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = offset_w + ow * stride_width + kw * dilation_width;
|
||||
let val = x[[b, c, ih, iw]];
|
||||
|
||||
if val > max_val {
|
||||
max_val = val;
|
||||
|
||||
// Calculate index in original (unpadded) input
|
||||
let ih_orig = ih as i64 - (total_padding[0]) as i64;
|
||||
let iw_orig = iw as i64 - (total_padding[1]) as i64;
|
||||
|
||||
// Clamp to valid range for index calculation
|
||||
let ih_clamped = ih_orig.max(0).min(x_height as i64 - 1);
|
||||
let iw_clamped = iw_orig.max(0).min(x_width as i64 - 1);
|
||||
|
||||
index = ih_clamped * x_width as i64 + iw_clamped;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[[b, c, oh, ow]] = max_val;
|
||||
indices[[b, c, oh, ow]] = index.elem();
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let output = output.into_dyn().into_shared();
|
||||
let indices = indices.into_dyn().into_shared();
|
||||
|
||||
(output, indices)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_dilation: [usize; 2],
|
||||
_ceil_mode: bool,
|
||||
output_grad: SharedArray<E>,
|
||||
indices: SharedArray<I>,
|
||||
) -> SharedArray<E> {
|
||||
let [_batch_size, _channels, height, width] = output_grad.shape().dims();
|
||||
let [batch_size, channels, height_x, width_x] = x.shape().dims();
|
||||
|
||||
let output_grad = output_grad;
|
||||
let indices = indices;
|
||||
|
||||
let mut output = Array4::zeros((batch_size, channels, height_x, width_x));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
for h in 0..height {
|
||||
for w in 0..width {
|
||||
let index = indices[[b, c, h, w]].elem::<i64>();
|
||||
let grad = output_grad[[b, c, h, w]];
|
||||
|
||||
let index_h = index as usize / width_x;
|
||||
let index_w = index as usize % width_x;
|
||||
|
||||
output[[b, c, index_h, index_w]] += grad;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
mod activation;
|
||||
mod base;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
#[cfg(feature = "simd")]
|
||||
mod simd;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
|
||||
pub(crate) mod adaptive_avgpool;
|
||||
pub(crate) mod avgpool;
|
||||
pub(crate) mod conv;
|
||||
pub(crate) mod deform_conv;
|
||||
pub(crate) mod grid_sample;
|
||||
pub(crate) mod interpolate;
|
||||
pub(crate) mod macros;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod maxpool;
|
||||
pub(crate) mod padding;
|
||||
pub(crate) mod quantization;
|
||||
|
||||
pub(crate) use base::*;
|
||||
@@ -0,0 +1,367 @@
|
||||
use super::{
|
||||
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
|
||||
avgpool::{avg_pool2d, avg_pool2d_backward},
|
||||
conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
|
||||
deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
|
||||
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
|
||||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
|
||||
};
|
||||
#[cfg(feature = "simd")]
|
||||
use crate::ops::simd::{
|
||||
avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
|
||||
};
|
||||
use crate::{
|
||||
NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
|
||||
tensor::NdArrayTensor,
|
||||
};
|
||||
use crate::{
|
||||
element::{IntNdArrayElement, QuantElement},
|
||||
ops::interpolate::nearest_interpolate_backward,
|
||||
};
|
||||
use burn_backend::{
|
||||
ElementConversion, TensorMetadata,
|
||||
ops::{attention::attention_fallback, *},
|
||||
tensor::FloatTensor,
|
||||
};
|
||||
|
||||
// Dispatches a module operation on the float variant of its tensor arguments.
//
// `inp(...)` lists the required tensors and `opt(...)` the `Option`al ones. All required
// inputs must be the same variant (`F32` or `F64`); each is converted to a shared array
// and passed to `$op`, required inputs first, followed by the optional ones. `$element`
// is defined as a type alias for the matched float type around the call.
//
// Panics with "Data type mismatch" when the required inputs are not all `F32` or all
// `F64`, and with "Optional argument type mismatch" when a provided optional tensor is
// a different variant than the required inputs.
macro_rules! module_op {
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|tensor| match tensor {
                        NdArrayTensor::F32(storage) => storage.into_shared(),
                        _ => panic!("Optional argument type mismatch"),
                    }))*
                )
            }
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|tensor| match tensor {
                        NdArrayTensor::F64(storage) => storage.into_shared(),
                        _ => panic!("Optional argument type mismatch"),
                    }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
fn conv2d(
|
||||
x: NdArrayTensor,
|
||||
weight: NdArrayTensor,
|
||||
bias: Option<NdArrayTensor>,
|
||||
options: ConvOptions<2>,
|
||||
) -> NdArrayTensor {
|
||||
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
|
||||
#[cfg(feature = "simd")]
|
||||
let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
|
||||
Ok(out) => return out.into(),
|
||||
Err(args) => args,
|
||||
};
|
||||
conv2d::<E>(x, weight, bias, options).into()
|
||||
})
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(
|
||||
inp(x, offset, weight),
|
||||
opt(mask, bias),
|
||||
E,
|
||||
|x, offset, weight, mask, bias| deform_conv2d::<E>(
|
||||
x, offset, weight, mask, bias, options
|
||||
)
|
||||
.into()
|
||||
)
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
module_op!(
|
||||
inp(x, offset, weight, output_grad),
|
||||
opt(mask, bias),
|
||||
E,
|
||||
|x, offset, weight, output_grad, mask, bias| {
|
||||
let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
|
||||
x,
|
||||
offset,
|
||||
weight,
|
||||
mask,
|
||||
bias,
|
||||
output_grad,
|
||||
options,
|
||||
);
|
||||
DeformConv2dBackward::new(
|
||||
x.into(),
|
||||
offset.into(),
|
||||
weight.into(),
|
||||
mask.map(|m| m.into()),
|
||||
bias.map(|b| b.into()),
|
||||
)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
|
||||
conv_transpose2d::<E>(x, weight, bias, options).into()
|
||||
})
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x), opt(), E, |x| {
|
||||
#[cfg(feature = "simd")]
|
||||
let x = match if ceil_mode {
|
||||
// SIMD path doesn't support ceil_mode yet, skip it
|
||||
Err(x)
|
||||
} else {
|
||||
try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
|
||||
} {
|
||||
Ok(out) => return out.into(),
|
||||
Err(x) => x,
|
||||
};
|
||||
avg_pool2d::<E>(
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
)
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
|
||||
x,
|
||||
grad,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x), opt(), E, |x| {
|
||||
#[cfg(feature = "simd")]
|
||||
let x = match if ceil_mode {
|
||||
// SIMD path doesn't support ceil_mode yet, skip it
|
||||
Err(x)
|
||||
} else {
|
||||
try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
|
||||
} {
|
||||
Ok(out) => return out.into(),
|
||||
Err(x) => x,
|
||||
};
|
||||
max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
|
||||
})
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
|
||||
module_op!(inp(x), opt(), E, |x| {
|
||||
let (output, indices) = max_pool2d_with_indices::<E, I>(
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
);
|
||||
MaxPool2dWithIndices::new(output.into(), indices.into())
|
||||
})
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: NdArrayTensor,
|
||||
) -> MaxPool2dBackward<NdArray<E, I, Q>> {
|
||||
execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
|
||||
// Convert indices from runtime dtype to the expected I type
|
||||
// (pool indices are bounded by tensor dimensions, so conversion is safe)
|
||||
let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
|
||||
module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
|
||||
let output = max_pool2d_backward::<E, I>(
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
output_grad,
|
||||
indices,
|
||||
);
|
||||
MaxPool2dBackward::new(output.into())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
|
||||
module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
|
||||
x,
|
||||
output_size
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x, grad), opt(), E, |x, grad| {
|
||||
adaptive_avg_pool2d_backward::<E>(x, grad).into()
|
||||
})
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
x: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
match options.mode {
|
||||
InterpolateMode::Nearest => {
|
||||
module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
|
||||
x,
|
||||
output_size
|
||||
)
|
||||
.into())
|
||||
}
|
||||
InterpolateMode::Bilinear => {
|
||||
let align_corners = options.align_corners;
|
||||
module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
|
||||
x,
|
||||
output_size,
|
||||
align_corners
|
||||
)
|
||||
.into())
|
||||
}
|
||||
InterpolateMode::Bicubic => {
|
||||
let align_corners = options.align_corners;
|
||||
module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
|
||||
x,
|
||||
output_size,
|
||||
align_corners
|
||||
)
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn interpolate_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
match options.mode {
|
||||
InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
|
||||
nearest_interpolate_backward::<E>(x, grad, output_size).into()
|
||||
}),
|
||||
InterpolateMode::Bilinear => {
|
||||
panic!("bilinear interpolation backward is not supported for ndarray backend")
|
||||
}
|
||||
InterpolateMode::Bicubic => {
|
||||
panic!("bicubic interpolation backward is not supported for ndarray backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
|
||||
x, weight, bias, options
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
|
||||
conv_transpose3d::<E>(x, weight, bias, options).into()
|
||||
})
|
||||
}
|
||||
|
||||
fn attention(
|
||||
query: FloatTensor<Self>,
|
||||
key: FloatTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
mask: Option<burn_backend::tensor::BoolTensor<Self>>,
|
||||
attn_bias: Option<FloatTensor<Self>>,
|
||||
options: AttentionModuleOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
use crate::{NdArrayElement, SharedArray};
|
||||
use ndarray::{Array4, Array5};
|
||||
|
||||
use super::NdArrayOps;
|
||||
|
||||
pub(crate) fn apply_padding_4d<E: NdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
padding: [usize; 2],
|
||||
elem: E,
|
||||
) -> SharedArray<E> {
|
||||
let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap();
|
||||
let [padding_height, padding_width] = padding;
|
||||
let padded_height = height + 2 * padding_height;
|
||||
let padded_width = width + 2 * padding_width;
|
||||
|
||||
let x_new = Array4::from_elem(
|
||||
(batch_size, input_channels, padded_height, padded_width),
|
||||
elem,
|
||||
);
|
||||
let mut x_new = x_new.into_shared().into_dyn();
|
||||
|
||||
x_new = NdArrayOps::slice_assign(
|
||||
x_new,
|
||||
&[
|
||||
burn_backend::Slice::from(0..batch_size),
|
||||
burn_backend::Slice::from(0..input_channels),
|
||||
burn_backend::Slice::from(padding_height..height + padding_height),
|
||||
burn_backend::Slice::from(padding_width..width + padding_width),
|
||||
],
|
||||
x,
|
||||
);
|
||||
|
||||
x_new
|
||||
}
|
||||
|
||||
pub(crate) fn apply_padding_5d<E: NdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
padding: [usize; 3],
|
||||
elem: E,
|
||||
) -> SharedArray<E> {
|
||||
let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap();
|
||||
let [padding_depth, padding_height, padding_width] = padding;
|
||||
let padded_depth = depth + 2 * padding_depth;
|
||||
let padded_height = height + 2 * padding_height;
|
||||
let padded_width = width + 2 * padding_width;
|
||||
|
||||
let x_new = Array5::from_elem(
|
||||
(
|
||||
batch_size,
|
||||
input_channels,
|
||||
padded_depth,
|
||||
padded_height,
|
||||
padded_width,
|
||||
),
|
||||
elem,
|
||||
);
|
||||
let mut x_new = x_new.into_shared().into_dyn();
|
||||
|
||||
x_new = NdArrayOps::slice_assign(
|
||||
x_new,
|
||||
&[
|
||||
burn_backend::Slice::from(0..batch_size),
|
||||
burn_backend::Slice::from(0..input_channels),
|
||||
burn_backend::Slice::from(padding_depth..depth + padding_depth),
|
||||
burn_backend::Slice::from(padding_height..height + padding_height),
|
||||
burn_backend::Slice::from(padding_width..width + padding_width),
|
||||
],
|
||||
x,
|
||||
);
|
||||
|
||||
x_new
|
||||
}
|
||||
@@ -0,0 +1,346 @@
|
||||
use alloc::{vec, vec::Vec};
|
||||
|
||||
use burn_backend::{
|
||||
DType, ExecutionError, Shape, TensorData, TensorMetadata,
|
||||
ops::QTensorOps,
|
||||
quantization::{
|
||||
QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
|
||||
QuantizationParametersPrimitive, QuantizedBytes,
|
||||
},
|
||||
tensor::{FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
|
||||
element::{IntNdArrayElement, QuantElement},
|
||||
execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype, slice,
|
||||
};
|
||||
|
||||
use super::quantization::{QuantizationStrategy, SymmetricQuantization};
|
||||
use super::{NdArrayMathOps, NdArrayOps};
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
|
||||
match data.dtype {
|
||||
DType::QFloat(scheme) => {
|
||||
let shape = data.shape.clone();
|
||||
let num_elements = data.num_elements();
|
||||
let q_bytes = QuantizedBytes {
|
||||
bytes: data.into_bytes(),
|
||||
scheme,
|
||||
num_elements,
|
||||
};
|
||||
|
||||
match scheme {
|
||||
QuantScheme {
|
||||
level: QuantLevel::Tensor | QuantLevel::Block(_),
|
||||
mode: QuantMode::Symmetric,
|
||||
value: QuantValue::Q8F | QuantValue::Q8S,
|
||||
..
|
||||
} => {
|
||||
// We can load QuantStore::U32 w/ QuantizedBytes impl
|
||||
let (values, qparams) = q_bytes.into_vec_i8();
|
||||
let data = TensorData::new(values, shape);
|
||||
// Overwrite storage
|
||||
let scheme = scheme.with_store(QuantStore::Native);
|
||||
|
||||
let qparams = qparams
|
||||
.scales
|
||||
.into_iter()
|
||||
.map(|scales| QParams { scales })
|
||||
.collect();
|
||||
|
||||
NdArrayQTensor {
|
||||
qtensor: NdArrayTensor::from_data(data),
|
||||
scheme,
|
||||
qparams,
|
||||
}
|
||||
}
|
||||
QuantScheme {
|
||||
value:
|
||||
QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S
|
||||
| QuantValue::E2M1
|
||||
| QuantValue::E4M3
|
||||
| QuantValue::E5M2,
|
||||
..
|
||||
} => unimplemented!("from_data not supported for scheme {scheme:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!(
|
||||
"Invalid dtype (expected DType::QFloat, got {:?})",
|
||||
data.dtype
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
tensor: FloatTensor<Self>,
|
||||
scheme: &QuantScheme,
|
||||
qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let shape = tensor.shape();
|
||||
let data_f = tensor.into_data();
|
||||
let scales = qparams.scales.into_data().convert::<f32>();
|
||||
|
||||
// Implement with ndarray instead of QuantizationStrategy?
|
||||
let (data, qparams) = match scheme {
|
||||
QuantScheme {
|
||||
level: QuantLevel::Tensor,
|
||||
mode: QuantMode::Symmetric,
|
||||
#[cfg(not(feature = "export_tests"))]
|
||||
value: QuantValue::Q8F | QuantValue::Q8S,
|
||||
// For tests, "native" sub-byte quant serves as a reference for value equality.
|
||||
// Values are stored as i8 regardless.
|
||||
#[cfg(feature = "export_tests")]
|
||||
value:
|
||||
QuantValue::Q8F
|
||||
| QuantValue::Q8S
|
||||
| QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S,
|
||||
store: QuantStore::Native,
|
||||
..
|
||||
} => {
|
||||
let scales = scales.iter().next().unwrap();
|
||||
let strategy = QuantizationStrategy::PerTensorSymmetric(
|
||||
SymmetricQuantization::init(scales, scheme.value),
|
||||
);
|
||||
let values = strategy.quantize(data_f.as_slice().unwrap());
|
||||
(
|
||||
TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
|
||||
vec![QParams { scales }],
|
||||
)
|
||||
}
|
||||
QuantScheme {
|
||||
level: QuantLevel::Block(block_size),
|
||||
mode: QuantMode::Symmetric,
|
||||
#[cfg(not(feature = "export_tests"))]
|
||||
value: QuantValue::Q8F | QuantValue::Q8S,
|
||||
#[cfg(feature = "export_tests")]
|
||||
value:
|
||||
QuantValue::Q8F
|
||||
| QuantValue::Q8S
|
||||
| QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S,
|
||||
store: QuantStore::Native,
|
||||
..
|
||||
} => {
|
||||
let scales = scales.as_slice().unwrap();
|
||||
let (strategy, qparams) = scales
|
||||
.iter()
|
||||
.map(|&s| {
|
||||
(
|
||||
SymmetricQuantization::init(s, scheme.value),
|
||||
QParams { scales: s },
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
|
||||
let values = strategy.quantize(data_f.as_slice().unwrap());
|
||||
(
|
||||
TensorData::quantized(values, shape.clone(), *scheme, scales),
|
||||
qparams,
|
||||
)
|
||||
}
|
||||
scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
|
||||
};
|
||||
|
||||
let num_elements = data.num_elements();
|
||||
let q_bytes = QuantizedBytes {
|
||||
bytes: data.into_bytes(),
|
||||
scheme: *scheme,
|
||||
num_elements,
|
||||
};
|
||||
let (values, _) = q_bytes.into_vec_i8();
|
||||
let data = TensorData::new(values, shape).convert::<Q>();
|
||||
|
||||
NdArrayQTensor {
|
||||
qtensor: NdArrayTensor::from_data(data),
|
||||
scheme: *scheme,
|
||||
qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
let strategy = tensor.strategy();
|
||||
let scheme = tensor.scheme;
|
||||
let shape = tensor.shape();
|
||||
let data = match tensor.qtensor {
|
||||
NdArrayTensor::I8(storage) => {
|
||||
let data = storage.into_shared().into_iter().collect();
|
||||
dequantize(data, shape, scheme, &strategy)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
_device: &NdArrayDevice,
|
||||
) -> QuantizedTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
NdArrayQTensor {
|
||||
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayOps::reshape(array, shape)
|
||||
}),
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
let shape = tensor.qtensor.shape();
|
||||
let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
|
||||
Ok(execute_with_numeric_dtype!(
|
||||
tensor.qtensor,
|
||||
E,
|
||||
|array: SharedArray<E>| {
|
||||
let values = array.into_iter().collect();
|
||||
TensorData::quantized(values, shape, tensor.scheme, &scales)
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
NdArrayQTensor {
|
||||
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayOps::swap_dims(array, dim1, dim2)
|
||||
}),
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
NdArrayQTensor {
|
||||
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayOps::permute(array, axes)
|
||||
}),
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
NdArrayQTensor {
|
||||
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayOps::flip(array, axes)
|
||||
}),
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn q_gather(
|
||||
dim: usize,
|
||||
tensor: QuantizedTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
|
||||
IntElem,
|
||||
>|
|
||||
-> NdArrayTensor {
|
||||
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayOps::gather(dim, array, idx_array)
|
||||
})
|
||||
});
|
||||
NdArrayQTensor {
|
||||
qtensor,
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
|
||||
IntElem,
|
||||
>|
|
||||
-> NdArrayTensor {
|
||||
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayMathOps::select(array, dim, idx_array)
|
||||
})
|
||||
});
|
||||
NdArrayQTensor {
|
||||
qtensor,
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn q_slice(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
slices: &[burn_backend::Slice],
|
||||
) -> QuantizedTensor<Self> {
|
||||
NdArrayQTensor {
|
||||
qtensor: slice!(tensor.qtensor, slices),
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayMathOps::argmax::<I>(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayMathOps::argmin::<I>(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
NdArrayQTensor {
|
||||
qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
|
||||
NdArrayOps::expand(array, shape)
|
||||
}),
|
||||
scheme: tensor.scheme,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dequantize<Q: QuantElement>(
|
||||
data: Vec<Q>,
|
||||
shape: Shape,
|
||||
scheme: QuantScheme,
|
||||
strategy: &QuantizationStrategy,
|
||||
) -> TensorData {
|
||||
let qparams = match strategy {
|
||||
QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
|
||||
QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
|
||||
quant.iter().map(|q| q.scale).collect()
|
||||
}
|
||||
};
|
||||
let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
|
||||
let (values, _qparams) = q_bytes.into_vec_i8();
|
||||
TensorData::new(strategy.dequantize(&values), shape)
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
use alloc::vec::Vec;
|
||||
use num_traits::{Float, PrimInt};
|
||||
|
||||
use burn_backend::quantization::{BlockSize, QuantValue};
|
||||
|
||||
// NOTE: this mainly serves as a simple reference implementation.
|
||||
// The de/quantization ops should be refactored to use ndarray.
|
||||
|
||||
/// Quantization strategy.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum QuantizationStrategy {
|
||||
/// Per-tensor symmetric quantization.
|
||||
PerTensorSymmetric(SymmetricQuantization<f32>),
|
||||
/// Per-block symmetric quantization.
|
||||
PerBlockSymmetric(Vec<SymmetricQuantization<f32>>, BlockSize),
|
||||
}
|
||||
|
||||
impl QuantizationStrategy {
|
||||
/// Quantize the values to a lower precision data type.
|
||||
pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
|
||||
match self {
|
||||
QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values),
|
||||
QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
|
||||
let block_elems = block_size.num_elements();
|
||||
let num_blocks = strategy.len();
|
||||
let numel = values.len();
|
||||
assert_eq!(
|
||||
numel / block_elems,
|
||||
num_blocks,
|
||||
"Invalid per-block quantization with num blocks {num_blocks} and {numel} values"
|
||||
);
|
||||
values
|
||||
.chunks(block_elems)
|
||||
.enumerate()
|
||||
.flat_map(|(block_id, block)| strategy[block_id].quantize(block))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dequantize the values to a higher precision data type.
|
||||
pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
|
||||
match self {
|
||||
QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values),
|
||||
QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
|
||||
let block_elems = block_size.num_elements();
|
||||
let num_blocks = strategy.len();
|
||||
let numel = values.len();
|
||||
assert_eq!(
|
||||
numel / block_elems,
|
||||
num_blocks,
|
||||
"Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values"
|
||||
);
|
||||
values
|
||||
.chunks(block_elems)
|
||||
.enumerate()
|
||||
.flat_map(|(block_id, block)| strategy[block_id].dequantize(block))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
|
||||
/// data type `Q` and vice-versa.
|
||||
pub trait Quantization<E: Float + Send + Sync> {
|
||||
/// Returns the quantization range `[a, b]`.
|
||||
fn range(&self) -> (E, E);
|
||||
/// Convert the values to a lower precision data type.
|
||||
fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q>;
|
||||
/// Convert a single value to a lower precision data type.
|
||||
fn quantize_one<Q: PrimInt>(&self, value: E) -> Q;
|
||||
/// Convert the values back to a higher precision data type.
|
||||
fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E>;
|
||||
/// Convert a single value back to a higher precision data type.
|
||||
fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E;
|
||||
}
|
||||
|
||||
fn valid_scale<E: Float>(mut scale: E) -> E {
|
||||
// If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
|
||||
// scale to 0.1 to avoid division by zero.
|
||||
if scale.eq(&E::zero()) {
|
||||
scale = E::from(0.1).unwrap();
|
||||
}
|
||||
scale
|
||||
}
|
||||
|
||||
/// Symmetric quantization scheme.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SymmetricQuantization<E: Float + Send + Sync> {
|
||||
/// The scaling factor.
|
||||
pub scale: E,
|
||||
// The quantization value data type.
|
||||
value: QuantValue,
|
||||
}
|
||||
|
||||
impl<E: Float + Send + Sync> SymmetricQuantization<E> {
|
||||
/// Initialize a symmetric quantization scheme with the given parameters.
|
||||
pub fn init(scale: E, value: QuantValue) -> Self {
|
||||
Self {
|
||||
scale: valid_scale(scale),
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Create a new quantization scheme for an input range `[alpha, beta]`.
|
||||
fn new(alpha: E, beta: E, value: QuantValue) -> Self {
|
||||
let (a, b) = value.range();
|
||||
let a = E::from(a).unwrap();
|
||||
let b = E::from(b).unwrap();
|
||||
|
||||
// Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
|
||||
let alpha = alpha.abs().max(beta.abs());
|
||||
let scale = valid_scale((alpha + alpha) / (b - a));
|
||||
Self { scale, value }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Float + Send + Sync> Quantization<E> for SymmetricQuantization<E> {
|
||||
fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q> {
|
||||
values.iter().map(|x| self.quantize_one(*x)).collect()
|
||||
}
|
||||
|
||||
fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E> {
|
||||
values.iter().map(|x_q| self.dequantize_one(*x_q)).collect()
|
||||
}
|
||||
|
||||
fn quantize_one<Q: PrimInt>(&self, value: E) -> Q {
|
||||
let (a, b) = self.range();
|
||||
|
||||
// x_q = clamp(round(x / scale), a, b)
|
||||
Q::from(value.div(self.scale).round().clamp(a, b)).unwrap()
|
||||
}
|
||||
|
||||
fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E {
|
||||
// x = scale * x_q
|
||||
self.scale * E::from(value).unwrap()
|
||||
}
|
||||
|
||||
fn range(&self) -> (E, E) {
|
||||
let (a, b) = self.value.range();
|
||||
let a = E::from(a).unwrap();
|
||||
let b = E::from(b).unwrap();
|
||||
(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Float + Send + Sync> PartialEq for SymmetricQuantization<E> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.scale == other.scale
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Float + Send + Sync> Eq for SymmetricQuantization<E> {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn_backend::TensorData;

    use super::*;
    use alloc::vec;

    /// Per-tensor symmetric int8 round trip on a small input.
    #[test]
    fn test_int8_symmetric_quantization() {
        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35];
        let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);

        let q: Vec<i8> = symmetric.quantize(&x);
        assert_eq!(q, expected_q);

        let d = symmetric.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    /// Per-block symmetric int8 round trip: two blocks of four values each.
    #[test]
    fn test_int8_symmetric_quantization_per_block() {
        let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35];
        let expected_d = vec![
            -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063,
        ];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);
        let strategy = QuantizationStrategy::PerBlockSymmetric(
            vec![symmetric, symmetric],
            BlockSize::new([4]),
        );

        let q: Vec<i8> = strategy.quantize(&x);
        assert_eq!(q, expected_q);

        // Dequantize through the per-block strategy so the per-block path is exercised,
        // not just the per-tensor scheme.
        let d = strategy.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    /// Dequantization through the per-tensor strategy with an explicit scale.
    #[test]
    fn should_support_dequantize() {
        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization {
            scale: 0.1,
            value: QuantValue::Q8S,
        });

        let output = strategy.dequantize(&[-127i8, -77, -26, 25, 76, 127]);

        let output = TensorData::new(output, [2, 3]);

        output.assert_approx_eq::<f32>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Default::default(),
        );
    }
}
|
||||
@@ -0,0 +1,443 @@
|
||||
use core::{marker::PhantomData, mem::transmute};
|
||||
|
||||
use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};
|
||||
|
||||
use burn_backend::DType;
|
||||
use burn_backend::{Element, ElementConversion};
|
||||
use bytemuck::Zeroable;
|
||||
use macerator::{Simd, VAdd, VDiv};
|
||||
use ndarray::{Array4, s};
|
||||
use nhwc::avg_pool_nhwc;
|
||||
|
||||
use super::should_use_simd;
|
||||
|
||||
#[macerator::with_simd]
|
||||
fn is_accelerated<S: Simd, T: VAdd + VDiv>(_x: PhantomData<T>) -> bool {
|
||||
<T as VAdd>::is_accelerated::<S>() && <T as VDiv>::is_accelerated::<S>()
|
||||
}
|
||||
|
||||
pub(crate) fn try_avg_pool2d_simd<E: Element>(
|
||||
x: SharedArray<E>,
|
||||
ksize: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
with_pad: bool,
|
||||
) -> Result<SharedArray<E>, SharedArray<E>> {
|
||||
// Strides must be unit, dilation isn't supported, rows must be contiguous
|
||||
if x.strides()[1] != 1 || !should_use_simd(x.shape()[1]) {
|
||||
return Err(x);
|
||||
}
|
||||
|
||||
match E::dtype() {
|
||||
DType::F64 if is_accelerated::<f64>(PhantomData) => Ok(cast(avg_pool_nhwc::<f64>(
|
||||
cast(x),
|
||||
ksize,
|
||||
stride,
|
||||
padding,
|
||||
with_pad,
|
||||
))),
|
||||
DType::F32 if is_accelerated::<f32>(PhantomData) => Ok(cast(avg_pool_nhwc::<f32>(
|
||||
cast(x),
|
||||
ksize,
|
||||
stride,
|
||||
padding,
|
||||
with_pad,
|
||||
))),
|
||||
_ => Err(x),
|
||||
}
|
||||
}
|
||||
|
||||
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
|
||||
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
|
||||
}
|
||||
|
||||
mod nhwc {
|
||||
use itertools::Itertools;
|
||||
use macerator::{Simd, Vector, vload_unaligned, vstore_unaligned};
|
||||
use ndarray::{ArrayView3, ArrayViewMut3};
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::ops::simd::lanes;
|
||||
|
||||
use super::*;
|
||||
|
||||
// Until you can use associated constants as array size, we need to hardcode this.
|
||||
// The most common config (x86-v3) has 16 registers, so use half of them for accumulators.
|
||||
const BLOCK_REGISTERS: usize = 8;
|
||||
|
||||
pub(crate) fn avg_pool_nhwc<E: Element + VAdd + VDiv>(
|
||||
x: SharedArray<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
with_pad: bool,
|
||||
) -> SharedArray<E> {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
|
||||
let lanes = lanes::<E>();
|
||||
|
||||
let ch_block = lanes * BLOCK_REGISTERS;
|
||||
|
||||
let out_height = ((x_height + 2 * pad_h - (kernel_height - 1) - 1) / stride_height) + 1;
|
||||
let out_width = ((x_width + 2 * pad_w - (kernel_width - 1) - 1) / stride_width) + 1;
|
||||
|
||||
let mut output = unsafe {
|
||||
Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()
|
||||
};
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
let x = x.view();
|
||||
let x = x.permuted_axes(vec![0, 2, 3, 1]);
|
||||
|
||||
// Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`.
|
||||
// An exclusive loop will always have `lanes * blocking factor` elements in bounds.
|
||||
let blocks = channels / ch_block;
|
||||
let blocks_end = blocks * ch_block;
|
||||
// Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An
|
||||
// exclusive loop will always have `lanes` elements in bounds.
|
||||
let simd_end = channels / lanes * lanes;
|
||||
let num_simd_unblocked = (simd_end - blocks_end) / lanes;
|
||||
let remainder = channels - simd_end;
|
||||
|
||||
run_par!(|| {
|
||||
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
|
||||
iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {
|
||||
let block = k % blocks;
|
||||
let b = k / blocks;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = output.slice_mut(s![b, .., .., ..]);
|
||||
|
||||
loop_blocked(x, out, kernel_size, stride, padding, with_pad, block);
|
||||
});
|
||||
// SAFETY: See `loop_unblocked`
|
||||
iter_range_par!(0, batch_size * num_simd_unblocked).for_each(|k| unsafe {
|
||||
let ch = (k % num_simd_unblocked) * lanes + blocks_end;
|
||||
let b = k / num_simd_unblocked;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = output.slice_mut(s![b, .., .., ..]);
|
||||
|
||||
loop_unblocked(x, out, kernel_size, stride, padding, with_pad, ch);
|
||||
});
|
||||
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
|
||||
iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {
|
||||
let ch = (k % remainder) + simd_end;
|
||||
let b = k / remainder;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = output.slice_mut(s![b, .., .., ..]);
|
||||
|
||||
loop_scalar(x, out, kernel_size, stride, padding, with_pad, ch);
|
||||
});
|
||||
});
|
||||
|
||||
output = output.permuted_axes([0, 3, 1, 2]);
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Execute the blocked (unrolled) portion of the pool.
|
||||
#[allow(
|
||||
clippy::too_many_arguments,
|
||||
clippy::erasing_op,
|
||||
clippy::identity_op,
|
||||
unused_mut
|
||||
)]
|
||||
#[macerator::with_simd]
|
||||
fn loop_blocked<'a, S: Simd, E: Element + VAdd + VDiv>(
|
||||
x: ArrayView3<'a, E>,
|
||||
mut out: ArrayViewMut3<'a, E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
with_pad: bool,
|
||||
block: usize,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
|
||||
let (x_height, x_width, _) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
let lanes = E::lanes::<S>();
|
||||
|
||||
let ch_block = lanes * BLOCK_REGISTERS;
|
||||
|
||||
// If pixels are more than `padding` from the edges, the in pixel cannot be out of bounds
|
||||
for oh in pad_h..out_height.saturating_sub(pad_h) {
|
||||
for ow in pad_w..out_width.saturating_sub(pad_w) {
|
||||
seq!(N in 0..8 {
|
||||
let mut sum~N: Vector<S, E> = Zeroable::zeroed();
|
||||
});
|
||||
let ch = block * ch_block;
|
||||
let ch_end = ch + ch_block;
|
||||
let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw - pad_w;
|
||||
let x = x.slice(s![ih, iw, ch..ch_end]);
|
||||
|
||||
seq!(N in 0..8 {
|
||||
// SAFETY:
|
||||
// Load a full vector from x[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
sum~N += unsafe { vload_unaligned(&x[N * lanes]) };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let count = kernel_height * kernel_width;
|
||||
let count = (count as u64).elem::<E>();
|
||||
let count_v = count.splat();
|
||||
seq!(N in 0..8 {
|
||||
let s~N = sum~N / count_v;
|
||||
// SAFETY:
|
||||
// Store a full vector to out[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Border pixels need bounds checks
|
||||
if (pad_h, pad_w) != (0, 0) {
|
||||
let v_borders = (0..pad_h)
|
||||
.chain(out_height.saturating_sub(pad_h)..out_height)
|
||||
.cartesian_product(0..out_width);
|
||||
let h_borders = (0..out_height)
|
||||
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
|
||||
|
||||
for (oh, ow) in v_borders.chain(h_borders) {
|
||||
seq!(N in 0..8 {
|
||||
let mut sum~N: Vector<S, E> = Zeroable::zeroed();
|
||||
});
|
||||
let mut count: usize = 0;
|
||||
let ch = block * ch_block;
|
||||
let ch_end = ch + ch_block;
|
||||
let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh;
|
||||
if ih < pad_h || ih >= x_height + pad_h {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw;
|
||||
if iw < pad_w || iw >= x_width + pad_w {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
count += 1;
|
||||
|
||||
let x = x.slice(s![ih, iw, ch..ch_end]);
|
||||
|
||||
seq!(N in 0..8 {
|
||||
// SAFETY:
|
||||
// Load a full vector from x[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
sum~N += unsafe { vload_unaligned(&x[N * lanes]) };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if with_pad {
|
||||
count = kernel_height * kernel_width;
|
||||
}
|
||||
|
||||
let count = (count as u64).elem::<E>();
|
||||
let count_v = count.splat();
|
||||
seq!(N in 0..8 {
|
||||
let s~N = sum~N / count_v;
|
||||
// SAFETY:
|
||||
// Store a full vector to out[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the unblocked (not unrolled) portion of the pool.
|
||||
///
|
||||
/// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`.
|
||||
#[allow(clippy::too_many_arguments, unused_mut)]
|
||||
#[macerator::with_simd]
|
||||
unsafe fn loop_unblocked<'a, S: Simd, E: Element + VAdd + VDiv>(
|
||||
x: ArrayView3<'a, E>,
|
||||
mut out: ArrayViewMut3<'a, E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
with_pad: bool,
|
||||
ch: usize,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
|
||||
let (x_height, x_width, _) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
|
||||
// If pixels are not within padding range, bounds checks are always true
|
||||
for oh in pad_h..out_height - pad_h {
|
||||
for ow in pad_w..out_width - pad_w {
|
||||
let mut sum: Vector<S, E> = Zeroable::zeroed();
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw - pad_w;
|
||||
// Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`
|
||||
let s0 = unsafe { vload_unaligned(&x[[ih, iw, ch]]) };
|
||||
sum += s0;
|
||||
}
|
||||
}
|
||||
|
||||
let count = kernel_height * kernel_width;
|
||||
let count: E = (count as u64).elem();
|
||||
let count_v = count.splat();
|
||||
let s0 = sum / count_v;
|
||||
// Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.
|
||||
unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) };
|
||||
}
|
||||
}
|
||||
|
||||
// Border pixels need bounds checks
|
||||
if (pad_h, pad_w) != (0, 0) {
|
||||
let v_borders = (0..pad_h)
|
||||
.chain(out_height.saturating_sub(pad_h)..out_height)
|
||||
.cartesian_product(0..out_width);
|
||||
let h_borders = (0..out_height)
|
||||
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
|
||||
|
||||
for (oh, ow) in v_borders.chain(h_borders) {
|
||||
let mut sum: Vector<S, E> = Zeroable::zeroed();
|
||||
let mut count: usize = 0;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh;
|
||||
if ih < pad_h || ih >= x_height + pad_h {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw;
|
||||
if iw < pad_w || iw >= x_width + pad_w {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
count += 1;
|
||||
|
||||
// Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`
|
||||
sum += unsafe { vload_unaligned(&x[[ih, iw, ch]]) };
|
||||
}
|
||||
}
|
||||
|
||||
if with_pad {
|
||||
count = kernel_height * kernel_width;
|
||||
}
|
||||
|
||||
let count = (count as u64).elem::<E>();
|
||||
let count_v = count.splat();
|
||||
let s0 = sum / count_v;
|
||||
// Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.
|
||||
unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute scalar portion of the pooling
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn loop_scalar<E: Element + VAdd + VDiv>(
|
||||
x: ArrayView3<'_, E>,
|
||||
mut out: ArrayViewMut3<'_, E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
with_pad: bool,
|
||||
ch: usize,
|
||||
) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
|
||||
let (x_height, x_width, _) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
|
||||
// If pixels are not within padding range, bounds checks are always true
|
||||
for oh in pad_h..out_height.saturating_sub(pad_h) {
|
||||
for ow in pad_w..out_width.saturating_sub(pad_w) {
|
||||
let mut sum: E = Zeroable::zeroed();
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw - pad_w;
|
||||
sum = sum + x[[ih, iw, ch]];
|
||||
}
|
||||
}
|
||||
|
||||
let count = (kernel_height * kernel_width) as u64;
|
||||
out[[oh, ow, ch]] = sum / count.elem();
|
||||
}
|
||||
}
|
||||
|
||||
// Border pixels need bounds checks
|
||||
if (pad_h, pad_w) != (0, 0) {
|
||||
let v_borders = (0..pad_h)
|
||||
.chain(out_height.saturating_sub(pad_h)..out_height)
|
||||
.cartesian_product(0..out_width);
|
||||
let h_borders = (0..out_height)
|
||||
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
|
||||
|
||||
for (oh, ow) in v_borders.chain(h_borders) {
|
||||
let mut sum: E = Zeroable::zeroed();
|
||||
let mut count: usize = 0;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh;
|
||||
if ih < pad_h || ih >= x_height + pad_h {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw;
|
||||
if iw < pad_w || iw >= x_width + pad_w {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
count += 1;
|
||||
sum = sum + x[[ih, iw, ch]];
|
||||
}
|
||||
}
|
||||
|
||||
if with_pad {
|
||||
count = kernel_height * kernel_width;
|
||||
}
|
||||
|
||||
out[[oh, ow, ch]] = sum / (count as u64).elem();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
use core::{marker::PhantomData, mem::MaybeUninit};
|
||||
|
||||
use macerator::{Arch, Scalar, Simd};
|
||||
use ndarray::{ArcArray, ArrayD, IxDyn, ShapeBuilder};
|
||||
|
||||
/// Whether SIMD instructions are worth using
///
/// - In test builds the answer is always `true`.
/// - On `x86`, `x86_64`, `aarch64`, `wasm32` and `loongarch64` targets the answer is `true`
///   once `len` reaches 32.
/// - On every other target the answer is always `false`.
pub fn should_use_simd(len: usize) -> bool {
    if cfg!(test) {
        true
    } else if cfg!(any(
        target_arch = "x86",
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "wasm32",
        target_arch = "loongarch64"
    )) {
        len >= 32
    } else {
        false
    }
}
|
||||
|
||||
pub(crate) fn lanes<E: Scalar>() -> usize {
|
||||
#[allow(non_camel_case_types)]
|
||||
struct lanes<__T0>(__T0);
|
||||
|
||||
impl<E: Scalar> ::macerator::WithSimd for lanes<PhantomData<E>> {
|
||||
type Output = usize;
|
||||
#[inline(always)]
|
||||
fn with_simd<__S: ::macerator::Simd>(self) -> <Self as ::macerator::WithSimd>::Output {
|
||||
let Self(__ty) = self;
|
||||
#[allow(unused_unsafe)]
|
||||
unsafe {
|
||||
lanes_simd::<__S, E>(__ty)
|
||||
}
|
||||
}
|
||||
}
|
||||
(Arch::new()).dispatch(lanes(PhantomData::<E>))
|
||||
}
|
||||
|
||||
/// Returns the number of `E` lanes for the specific SIMD backend `S`.
///
/// The `PhantomData` argument only carries the element type.
fn lanes_simd<S: Simd, E: Scalar>(_ty: PhantomData<E>) -> usize {
    let lane_count = E::lanes::<S>();
    lane_count
}
|
||||
|
||||
pub(crate) fn uninit_array_like<In, Out>(reference: &ArcArray<In, IxDyn>) -> ArrayD<Out> {
|
||||
let shape = reference.raw_dim();
|
||||
let strides = reference.strides();
|
||||
let strides = strides.iter().map(|it| *it as usize).collect::<Vec<_>>();
|
||||
let shape_strides = shape.strides(IxDyn(&strides));
|
||||
let size = reference.len();
|
||||
let mut out_data: Vec<MaybeUninit<Out>> = Vec::with_capacity(size);
|
||||
unsafe { out_data.set_len(size) };
|
||||
unsafe { ArrayD::from_shape_vec_unchecked(shape_strides, out_data).assume_init() }
|
||||
}
|
||||
|
||||
/// Scalar minimum / maximum with a common signature for the integer and float types it is
/// implemented for.
pub trait MinMax {
    /// Returns the smaller of `self` and `other`.
    fn min(self, other: Self) -> Self;

    /// Returns the larger of `self` and `other`.
    fn max(self, other: Self) -> Self;
}
|
||||
|
||||
macro_rules! impl_minmax {
|
||||
($ty: ty) => {
|
||||
impl MinMax for $ty {
|
||||
fn min(self, other: Self) -> Self {
|
||||
Ord::min(self, other)
|
||||
}
|
||||
fn max(self, other: Self) -> Self {
|
||||
Ord::max(self, other)
|
||||
}
|
||||
}
|
||||
};
|
||||
($($ty: ty),*) => {
|
||||
$(impl_minmax!($ty);)*
|
||||
}
|
||||
}
|
||||
|
||||
impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64);
|
||||
|
||||
impl MinMax for f32 {
|
||||
fn min(self, other: Self) -> Self {
|
||||
self.min(other)
|
||||
}
|
||||
|
||||
fn max(self, other: Self) -> Self {
|
||||
self.max(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl MinMax for f64 {
|
||||
fn min(self, other: Self) -> Self {
|
||||
self.min(other)
|
||||
}
|
||||
|
||||
fn max(self, other: Self) -> Self {
|
||||
self.max(other)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
use core::{marker::PhantomData, slice};
|
||||
|
||||
use burn_backend::Element;
|
||||
use macerator::{
|
||||
Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload_unaligned,
|
||||
vstore_unaligned,
|
||||
};
|
||||
use ndarray::ArrayD;
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
|
||||
|
||||
use super::{
|
||||
MinMax,
|
||||
binary_elemwise::{
|
||||
VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecDiv, VecMax, VecMin, VecMul, VecSub,
|
||||
},
|
||||
should_use_simd,
|
||||
};
|
||||
|
||||
pub trait SimdBinop<T: Scalar, Out: Scalar> {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, Out>;
|
||||
fn apply(lhs: T, rhs: T) -> Out;
|
||||
fn is_accelerated<S: Simd>() -> bool;
|
||||
}
|
||||
|
||||
impl<T: VAdd> SimdBinop<T, T> for VecAdd {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs + rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs + rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VAdd>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VDiv> SimdBinop<T, T> for VecDiv {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs / rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs / rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VDiv>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VMul> SimdBinop<T, T> for VecMul {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs * rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs * rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VMul>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VSub> SimdBinop<T, T> for VecSub {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs - rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs - rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VSub>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd + MinMax> SimdBinop<T, T> for VecMin {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs.min(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
MinMax::min(lhs, rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_min_max_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd + MinMax> SimdBinop<T, T> for VecMax {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs.max(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
MinMax::max(lhs, rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_min_max_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VBitAnd> SimdBinop<T, T> for VecBitAnd {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs & rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs.bitand(rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitAnd>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VBitOr> SimdBinop<T, T> for VecBitOr {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs | rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs.bitor(rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitOr>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VBitXor> SimdBinop<T, T> for VecBitXor {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {
|
||||
lhs ^ rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs.bitxor(rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitXor>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
#[macerator::with_simd]
|
||||
fn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: SimdBinop<T, Out>>(
|
||||
_x: PhantomData<(T, Out, Op)>,
|
||||
) -> bool {
|
||||
Op::is_accelerated::<S>()
|
||||
}
|
||||
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn try_binary_simd<
|
||||
E: Element,
|
||||
EOut: Element,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdBinop<T, Out>,
|
||||
>(
|
||||
lhs: SharedArray<E>,
|
||||
rhs: SharedArray<E>,
|
||||
) -> Result<SharedArray<EOut>, (SharedArray<E>, SharedArray<E>)> {
|
||||
let lhs_len = lhs.len();
|
||||
let rhs_len = rhs.len();
|
||||
if !should_use_simd(lhs_len.max(rhs_len))
|
||||
|| !lhs.is_standard_layout()
|
||||
|| !rhs.is_standard_layout()
|
||||
|| lhs.shape() != rhs.shape()
|
||||
|| !is_accelerated::<T, Out, Op>(PhantomData)
|
||||
{
|
||||
return Err((lhs, rhs));
|
||||
}
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let lhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs) };
|
||||
let rhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs) };
|
||||
let out = binary_simd_same::<T, Out, Op>(lhs, rhs);
|
||||
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let out = unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(out) };
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn binary_simd_same<
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdBinop<T, Out>,
|
||||
>(
|
||||
lhs: SharedArray<T>,
|
||||
rhs: SharedArray<T>,
|
||||
) -> SharedArray<Out> {
|
||||
let out = if lhs.is_unique() {
|
||||
let mut buf = lhs.into_owned();
|
||||
let lhs = buf.as_slice_mut().unwrap();
|
||||
let rhs = rhs.as_slice().unwrap();
|
||||
let out =
|
||||
unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(lhs)) };
|
||||
binary(lhs, rhs, out, PhantomData::<Op>);
|
||||
unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buf) }
|
||||
} else if rhs.is_unique() {
|
||||
let mut buf = rhs.into_owned();
|
||||
let lhs = lhs.as_slice().unwrap();
|
||||
let rhs = buf.as_slice_mut().unwrap();
|
||||
let out =
|
||||
unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(rhs)) };
|
||||
binary(lhs, rhs, out, PhantomData::<Op>);
|
||||
unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buf) }
|
||||
} else {
|
||||
let mut out = uninit_array_like(&lhs);
|
||||
let lhs = lhs.as_slice().unwrap();
|
||||
let rhs = rhs.as_slice().unwrap();
|
||||
let out_slice = out.as_slice_mut().unwrap();
|
||||
binary(lhs, rhs, out_slice, PhantomData::<Op>);
|
||||
out
|
||||
};
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
#[allow(clippy::erasing_op, clippy::identity_op)]
|
||||
#[macerator::with_simd]
|
||||
fn binary<
|
||||
'a,
|
||||
S: Simd,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdBinop<T, Out>,
|
||||
>(
|
||||
lhs: &'a [T],
|
||||
rhs: &'a [T],
|
||||
out: &'a mut [Out],
|
||||
_op: PhantomData<Op>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let lanes = T::lanes::<S>();
|
||||
let mut chunks_lhs = lhs.chunks_exact(8 * lanes);
|
||||
let mut chunks_rhs = rhs.chunks_exact(8 * lanes);
|
||||
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
|
||||
while let Some(((lhs, rhs), out)) = chunks_lhs
|
||||
.next()
|
||||
.zip(chunks_rhs.next())
|
||||
.zip(chunks_out.next())
|
||||
{
|
||||
seq!(N in 0..8 {
|
||||
// Load one full vector from `lhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };
|
||||
// Load one full vector from `rhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };
|
||||
let s~N = Op::apply_vec(lhs~N, rhs~N);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);
|
||||
let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);
|
||||
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
|
||||
while let Some(((lhs, rhs), out)) = chunks_lhs
|
||||
.next()
|
||||
.zip(chunks_rhs.next())
|
||||
.zip(chunks_out.next())
|
||||
{
|
||||
// Load one full vector from `lhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };
|
||||
// Load one full vector from `rhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };
|
||||
let s0 = Op::apply_vec(lhs0, rhs0);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };
|
||||
}
|
||||
|
||||
for ((lhs, rhs), out) in chunks_lhs
|
||||
.remainder()
|
||||
.iter()
|
||||
.zip(chunks_rhs.remainder())
|
||||
.zip(chunks_out.into_remainder())
|
||||
{
|
||||
*out = Op::apply(*lhs, *rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a second mutable slice over the same memory as `slice`, detached from the original
/// borrow, so one buffer can be passed as both input and output of an in-place operation.
///
/// The result aliases `slice` and its lifetime `'a` is chosen by the caller; keeping the two
/// views used soundly is entirely the caller's responsibility.
fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {
    let raw: *mut [T] = slice;
    // SAFETY: `raw` was just derived from a live `&mut [T]`, so it is non-null, aligned and
    // covers initialized elements. Aliasing and lifetime are the caller's obligation.
    unsafe { &mut *raw }
}
|
||||
@@ -0,0 +1,419 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use bytemuck::cast;
|
||||
use macerator::{
|
||||
Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload,
|
||||
vload_unaligned, vstore, vstore_unaligned,
|
||||
};
|
||||
use ndarray::ArrayD;
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
|
||||
|
||||
use super::{MinMax, should_use_simd};
|
||||
|
||||
pub trait ScalarSimdBinop<T: Scalar, Out: Scalar> {
|
||||
type Rhs: Copy;
|
||||
type RhsVec<S: Simd>: Copy;
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S>;
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, Out>;
|
||||
fn apply(lhs: T, rhs: Self::Rhs) -> Out;
|
||||
fn is_accelerated<S: Simd>() -> bool;
|
||||
}
|
||||
|
||||
/// Marker for `lhs + rhs`.
pub struct VecAdd;
/// Marker for `lhs / rhs`.
pub struct VecDiv;
/// Marker for `lhs * rhs`.
pub struct VecMul;
/// Marker for `lhs - rhs`.
pub struct VecSub;
/// Marker for the minimum of `lhs` and `rhs`.
pub struct VecMin;
/// Marker for the maximum of `lhs` and `rhs`.
pub struct VecMax;
/// Marker for clamping `lhs` between a `(min, max)` pair.
pub struct VecClamp;
/// Marker for `lhs & rhs`.
pub struct VecBitAnd;
/// Marker for `lhs | rhs`.
pub struct VecBitOr;
/// Marker for `lhs ^ rhs`.
pub struct VecBitXor;
|
||||
|
||||
impl<T: VAdd> ScalarSimdBinop<T, T> for VecAdd {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs + rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs + rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VAdd>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VDiv> ScalarSimdBinop<T, T> for VecDiv {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs / rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs / rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VDiv>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VMul> ScalarSimdBinop<T, T> for VecMul {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs * rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs * rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VMul>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VSub> ScalarSimdBinop<T, T> for VecSub {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs - rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs - rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VSub>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecMin {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs.min(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs.min(rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_min_max_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecMax {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs.max(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> T {
|
||||
lhs.max(rhs)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_min_max_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecClamp {
|
||||
type Rhs = (T, T);
|
||||
type RhsVec<S: Simd> = (Vector<S, T>, Vector<S, T>);
|
||||
|
||||
fn splat<S: Simd>((min, max): Self::Rhs) -> Self::RhsVec<S> {
|
||||
(min.splat(), max.splat())
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, (min, max): Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs.min(max).max(min)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, (min, max): Self::Rhs) -> T {
|
||||
lhs.min(max).max(min)
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_min_max_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VBitAnd> ScalarSimdBinop<T, T> for VecBitAnd {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs & rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: Self::Rhs) -> T {
|
||||
lhs & rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitAnd>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VBitOr> ScalarSimdBinop<T, T> for VecBitOr {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs | rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: Self::Rhs) -> T {
|
||||
lhs | rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitOr>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VBitXor> ScalarSimdBinop<T, T> for VecBitXor {
|
||||
type Rhs = T;
|
||||
type RhsVec<S: Simd> = Vector<S, T>;
|
||||
|
||||
fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {
|
||||
rhs.splat()
|
||||
}
|
||||
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {
|
||||
lhs ^ rhs
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: Self::Rhs) -> T {
|
||||
lhs ^ rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitXor>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
#[macerator::with_simd]
|
||||
fn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: ScalarSimdBinop<T, Out>>(
|
||||
_x: PhantomData<(T, Out, Op)>,
|
||||
) -> bool {
|
||||
Op::is_accelerated::<S>()
|
||||
}
|
||||
|
||||
pub fn try_binary_scalar_simd<
|
||||
E: NdArrayElement,
|
||||
EOut: NdArrayElement,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: ScalarSimdBinop<T, Out>,
|
||||
>(
|
||||
input: SharedArray<E>,
|
||||
elem: Op::Rhs,
|
||||
) -> Result<SharedArray<EOut>, SharedArray<E>> {
|
||||
if !should_use_simd(input.len())
|
||||
|| input.as_slice_memory_order().is_none()
|
||||
|| !is_accelerated::<T, Out, Op>(PhantomData)
|
||||
{
|
||||
return Err(input);
|
||||
}
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };
|
||||
let out = if size_of::<T>() == size_of::<Out>()
|
||||
&& align_of::<T>() >= align_of::<Out>()
|
||||
&& input.is_unique()
|
||||
{
|
||||
unsafe { binary_scalar_simd_inplace::<T, Out, Op>(input, elem) }
|
||||
} else {
|
||||
binary_scalar_simd_owned::<T, Out, Op>(input, elem)
|
||||
};
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let out = unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(out) };
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Execute operation in place on an owned tensor
|
||||
/// SAFETY:
|
||||
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
|
||||
unsafe fn binary_scalar_simd_inplace<
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: ScalarSimdBinop<T, Out>,
|
||||
>(
|
||||
input: SharedArray<T>,
|
||||
elem: Op::Rhs,
|
||||
) -> SharedArray<Out> {
|
||||
let mut buffer = input.into_owned();
|
||||
let slice = buffer.as_slice_memory_order_mut().unwrap();
|
||||
unsafe { binary_scalar_slice_inplace::<T, Out, Op>(slice, elem, PhantomData) };
|
||||
// Buffer has the same elem size and is filled with the operation output, so this is safe
|
||||
let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buffer) };
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
/// Create a new copy of the tensor as the output
|
||||
fn binary_scalar_simd_owned<
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: ScalarSimdBinop<T, Out>,
|
||||
>(
|
||||
input: SharedArray<T>,
|
||||
elem: Op::Rhs,
|
||||
) -> SharedArray<Out> {
|
||||
let mut out = uninit_array_like(&input);
|
||||
let input = input.as_slice_memory_order().unwrap();
|
||||
let out_slice = out.as_slice_memory_order_mut().unwrap();
|
||||
binary_scalar_slice::<T, Out, Op>(input, out_slice, elem, PhantomData);
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[allow(clippy::erasing_op, clippy::identity_op)]
|
||||
#[macerator::with_simd]
|
||||
fn binary_scalar_slice<
|
||||
'a,
|
||||
S: Simd,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: ScalarSimdBinop<T, Out>,
|
||||
>(
|
||||
input: &'a [T],
|
||||
out: &'a mut [Out],
|
||||
rhs: Op::Rhs,
|
||||
_op: PhantomData<Op>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let lanes = T::lanes::<S>();
|
||||
let mut chunks_input = input.chunks_exact(8 * lanes);
|
||||
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
|
||||
let rhs_vec = Op::splat::<S>(rhs);
|
||||
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
|
||||
seq!(N in 0..8 {
|
||||
// Load one full vector from `input`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
|
||||
let s~N = Op::apply_vec(s~N, rhs_vec);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
|
||||
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
|
||||
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
|
||||
// Load one full vector from `input`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
|
||||
let s0 = Op::apply_vec(s0, rhs_vec);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };
|
||||
}
|
||||
|
||||
for (input, out) in chunks_input
|
||||
.remainder()
|
||||
.iter()
|
||||
.zip(chunks_out.into_remainder())
|
||||
{
|
||||
*out = Op::apply(*input, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute operation in line.
|
||||
/// SAFETY:
|
||||
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
|
||||
#[inline(always)]
|
||||
#[macerator::with_simd]
|
||||
unsafe fn binary_scalar_slice_inplace<
|
||||
'a,
|
||||
S: Simd,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: ScalarSimdBinop<T, Out>,
|
||||
>(
|
||||
buf: &'a mut [T],
|
||||
rhs: Op::Rhs,
|
||||
_op: PhantomData<(Out, Op)>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
|
||||
for elem in head.iter_mut().chain(tail) {
|
||||
*elem = cast(Op::apply(*elem, rhs));
|
||||
}
|
||||
let mut chunks = main.chunks_exact_mut(8);
|
||||
let rhs = Op::splat::<S>(rhs);
|
||||
for elem in chunks.by_ref() {
|
||||
seq!(N in 0..8 {
|
||||
// Load a full vector from the aligned portion of the buffer.
|
||||
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
|
||||
// always a full vector in bounds.
|
||||
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
|
||||
let s~N = Op::apply_vec(s~N, rhs);
|
||||
// Store a full vector at the same position as the input. Cast is safe because `Out` is
|
||||
// size and align compatible
|
||||
unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) };
|
||||
});
|
||||
}
|
||||
|
||||
for elem in chunks.into_remainder() {
|
||||
// Load a full vector from the aligned portion of the buffer.
|
||||
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
|
||||
// always a full vector in bounds.
|
||||
let s0 = unsafe { vload(elem as *const _ as *const T) };
|
||||
|
||||
let s0 = Op::apply_vec(s0, rhs);
|
||||
// Store a full vector at the same position as the input. Cast is safe because `Out` is
|
||||
// size and align compatible
|
||||
unsafe { vstore(elem as *mut _ as *mut Out, s0) };
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use core::{marker::PhantomData, slice};
|
||||
|
||||
use burn_backend::Element;
|
||||
use macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned};
|
||||
use ndarray::ArrayD;
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
|
||||
|
||||
use super::should_use_simd;
|
||||
|
||||
pub trait SimdCmpOp<T: Scalar> {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T>;
|
||||
fn apply(lhs: T, rhs: T) -> bool;
|
||||
fn is_accelerated<S: Simd>() -> bool;
|
||||
}
|
||||
|
||||
pub struct VecEquals;
|
||||
|
||||
impl<T: VEq> SimdCmpOp<T> for VecEquals {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
|
||||
lhs.eq(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> bool {
|
||||
lhs == rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VEq>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VecGreater;
|
||||
|
||||
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecGreater {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
|
||||
lhs.gt(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> bool {
|
||||
lhs > rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_cmp_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VecGreaterEq;
|
||||
|
||||
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecGreaterEq {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
|
||||
lhs.ge(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> bool {
|
||||
lhs >= rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_cmp_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VecLowerEq;
|
||||
|
||||
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecLowerEq {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
|
||||
lhs.le(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> bool {
|
||||
lhs <= rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_cmp_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VecLower;
|
||||
|
||||
impl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecLower {
|
||||
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {
|
||||
lhs.lt(rhs)
|
||||
}
|
||||
|
||||
fn apply(lhs: T, rhs: T) -> bool {
|
||||
lhs < rhs
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VOrd>::is_cmp_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
#[macerator::with_simd]
|
||||
fn is_accelerated<S: Simd, T: Scalar, Op: SimdCmpOp<T>>(_x: PhantomData<(T, Op)>) -> bool {
|
||||
Op::is_accelerated::<S>()
|
||||
}
|
||||
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn try_cmp_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
lhs: SharedArray<E>,
|
||||
rhs: SharedArray<E>,
|
||||
) -> Result<SharedArray<bool>, (SharedArray<E>, SharedArray<E>)> {
|
||||
let lhs_len = lhs.len();
|
||||
let rhs_len = rhs.len();
|
||||
if !should_use_simd(lhs_len.max(rhs_len))
|
||||
|| !lhs.is_standard_layout()
|
||||
|| !rhs.is_standard_layout()
|
||||
|| lhs.shape() != rhs.shape()
|
||||
|| !is_accelerated::<T, Op>(PhantomData)
|
||||
{
|
||||
return Err((lhs, rhs));
|
||||
}
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let lhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs) };
|
||||
let rhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs) };
|
||||
let out = cmp_simd_same::<T, Op>(lhs, rhs);
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn cmp_simd_same<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
lhs: SharedArray<T>,
|
||||
rhs: SharedArray<T>,
|
||||
) -> SharedArray<bool> {
|
||||
let out = if lhs.is_unique() && size_of::<T>() == size_of::<bool>() {
|
||||
let mut buf = lhs.into_owned();
|
||||
let lhs = buf.as_slice_mut().unwrap();
|
||||
let rhs = rhs.as_slice().unwrap();
|
||||
let out =
|
||||
unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(lhs)) };
|
||||
cmp(lhs, rhs, out, PhantomData::<Op>);
|
||||
unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(buf) }
|
||||
} else if rhs.is_unique() && size_of::<T>() == size_of::<bool>() {
|
||||
let mut buf = rhs.into_owned();
|
||||
let lhs = lhs.as_slice().unwrap();
|
||||
let rhs = buf.as_slice_mut().unwrap();
|
||||
let out =
|
||||
unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(rhs)) };
|
||||
cmp(lhs, rhs, out, PhantomData::<Op>);
|
||||
unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(buf) }
|
||||
} else {
|
||||
let mut out = uninit_array_like(&lhs);
|
||||
let lhs = lhs.as_slice().unwrap();
|
||||
let rhs = rhs.as_slice().unwrap();
|
||||
let out_slice = out.as_slice_mut().unwrap();
|
||||
cmp(lhs, rhs, out_slice, PhantomData::<Op>);
|
||||
out
|
||||
};
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
#[allow(clippy::erasing_op, clippy::identity_op)]
|
||||
#[macerator::with_simd]
|
||||
fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
lhs: &'a [T],
|
||||
rhs: &'a [T],
|
||||
out: &'a mut [bool],
|
||||
_op: PhantomData<Op>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let lanes = T::lanes::<S>();
|
||||
let mut chunks_lhs = lhs.chunks_exact(8 * lanes);
|
||||
let mut chunks_rhs = rhs.chunks_exact(8 * lanes);
|
||||
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
|
||||
while let Some(((lhs, rhs), out)) = chunks_lhs
|
||||
.next()
|
||||
.zip(chunks_rhs.next())
|
||||
.zip(chunks_out.next())
|
||||
{
|
||||
seq!(N in 0..8 {
|
||||
// Load one full vector from `lhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };
|
||||
// Load one full vector from `rhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };
|
||||
let s~N = Op::apply_vec(lhs~N, rhs~N);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);
|
||||
let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);
|
||||
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
|
||||
while let Some(((lhs, rhs), out)) = chunks_lhs
|
||||
.next()
|
||||
.zip(chunks_rhs.next())
|
||||
.zip(chunks_out.next())
|
||||
{
|
||||
// Load one full vector from `lhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };
|
||||
// Load one full vector from `rhs`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };
|
||||
let s0 = Op::apply_vec(lhs0, rhs0);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };
|
||||
}
|
||||
|
||||
for ((lhs, rhs), out) in chunks_lhs
|
||||
.remainder()
|
||||
.iter()
|
||||
.zip(chunks_rhs.remainder())
|
||||
.zip(chunks_out.into_remainder())
|
||||
{
|
||||
*out = Op::apply(*lhs, *rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a second mutable slice over the same memory as `slice`, detached from the original
/// borrow, so one buffer can be passed as both input and output of an in-place operation.
///
/// # Safety
///
/// The returned slice aliases `slice` and its lifetime `'a` is not tied to any borrow. The
/// caller must ensure that:
///
/// - the returned slice is not used after the underlying buffer is moved, freed or
///   reallocated, and
/// - accesses through the two views are ordered so that every element is read through one
///   view before it is overwritten through the other.
unsafe fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {
    let ptr = slice.as_mut_ptr();
    let len = slice.len();
    // SAFETY: `ptr` and `len` come from a live `&mut [T]`, so they describe a non-null,
    // aligned, initialized region. Aliasing and lifetime are the caller's obligation, as
    // stated in this function's safety contract.
    unsafe { slice::from_raw_parts_mut(ptr, len) }
}
|
||||
|
||||
pub use elemwise::try_cmp_scalar_simd;
|
||||
|
||||
mod elemwise {
|
||||
use bytemuck::cast;
|
||||
use macerator::vload;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn try_cmp_scalar_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
input: SharedArray<E>,
|
||||
elem: T,
|
||||
) -> Result<SharedArray<bool>, SharedArray<E>> {
|
||||
if !should_use_simd(input.len())
|
||||
|| input.as_slice_memory_order().is_none()
|
||||
|| !is_accelerated::<T, Op>(PhantomData)
|
||||
{
|
||||
return Err(input);
|
||||
}
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };
|
||||
let out = if size_of::<T>() == size_of::<bool>()
|
||||
&& align_of::<T>() >= align_of::<bool>()
|
||||
&& input.is_unique()
|
||||
{
|
||||
unsafe { cmp_scalar_simd_inplace::<T, Op>(input, elem) }
|
||||
} else {
|
||||
cmp_scalar_simd_owned::<T, Op>(input, elem)
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Execute operation in place on an owned tensor
|
||||
/// SAFETY:
|
||||
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
|
||||
unsafe fn cmp_scalar_simd_inplace<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
input: SharedArray<T>,
|
||||
elem: T,
|
||||
) -> SharedArray<bool> {
|
||||
let mut buffer = input.into_owned();
|
||||
let slice = buffer.as_slice_memory_order_mut().unwrap();
|
||||
unsafe { cmp_scalar_slice_inplace::<T, Op>(slice, elem, PhantomData) };
|
||||
// Buffer has the same elem size and is filled with the operation output, so this is safe
|
||||
let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(buffer) };
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
/// Create a new copy of the tensor as the output
|
||||
fn cmp_scalar_simd_owned<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
input: SharedArray<T>,
|
||||
elem: T,
|
||||
) -> SharedArray<bool> {
|
||||
let mut out = uninit_array_like(&input);
|
||||
let input = input.as_slice_memory_order().unwrap();
|
||||
let out_slice = out.as_slice_memory_order_mut().unwrap();
|
||||
cmp_scalar_slice::<T, Op>(input, out_slice, elem, PhantomData);
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[allow(clippy::erasing_op, clippy::identity_op)]
|
||||
#[macerator::with_simd]
|
||||
fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
input: &'a [T],
|
||||
out: &'a mut [bool],
|
||||
rhs: T,
|
||||
_op: PhantomData<Op>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let lanes = T::lanes::<S>();
|
||||
let mut chunks_input = input.chunks_exact(8 * lanes);
|
||||
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
|
||||
let rhs_vec = rhs.splat::<S>();
|
||||
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
|
||||
seq!(N in 0..8 {
|
||||
// Load one full vector from `input`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
|
||||
let s~N = Op::apply_vec(s~N, rhs_vec);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
|
||||
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
|
||||
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
|
||||
// Load one full vector from `input`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
|
||||
let s0 = Op::apply_vec(s0, rhs_vec);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };
|
||||
}
|
||||
|
||||
for (input, out) in chunks_input
|
||||
.remainder()
|
||||
.iter()
|
||||
.zip(chunks_out.into_remainder())
|
||||
{
|
||||
*out = Op::apply(*input, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute operation in line.
|
||||
/// SAFETY:
|
||||
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
|
||||
#[inline(always)]
|
||||
#[macerator::with_simd]
|
||||
unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
|
||||
buf: &'a mut [T],
|
||||
rhs: T,
|
||||
_op: PhantomData<Op>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
|
||||
for elem in head.iter_mut().chain(tail) {
|
||||
*elem = cast(Op::apply(*elem, rhs));
|
||||
}
|
||||
let mut chunks = main.chunks_exact_mut(8);
|
||||
let rhs = rhs.splat::<S>();
|
||||
for elem in chunks.by_ref() {
|
||||
seq!(N in 0..8 {
|
||||
// Load a full vector from the aligned portion of the buffer.
|
||||
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
|
||||
// always a full vector in bounds.
|
||||
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
|
||||
let s~N = Op::apply_vec(s~N, rhs);
|
||||
// Store a full vector at the same position as the input. Cast is safe because `Out` is
|
||||
// size and align compatible
|
||||
unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) };
|
||||
});
|
||||
}
|
||||
|
||||
for elem in chunks.into_remainder() {
|
||||
// Load a full vector from the aligned portion of the buffer.
|
||||
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
|
||||
// always a full vector in bounds.
|
||||
let s0 = unsafe { vload(elem as *const _ as *const T) };
|
||||
|
||||
let s0 = Op::apply_vec(s0, rhs);
|
||||
// Store a full vector at the same position as the input. Cast is safe because `Out` is
|
||||
// size and align compatible
|
||||
unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) };
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,494 @@
|
||||
use core::{marker::PhantomData, mem::transmute};
|
||||
|
||||
use burn_backend::{
|
||||
DType, Element,
|
||||
ops::{ConvOptions, conv::calculate_conv_output_size},
|
||||
};
|
||||
use bytemuck::Zeroable;
|
||||
use macerator::{Simd, VMulAdd, Vector, vload_unaligned, vstore_unaligned};
|
||||
use ndarray::{
|
||||
ArcArray1, Array4, ArrayView3, ArrayView4, ArrayViewMut2, ArrayViewMut3, Dim, Ix1, Ix4, s,
|
||||
};
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::{FloatNdArrayElement, SharedArray, UnsafeSharedRef, iter_range_par, run_par};
|
||||
|
||||
/// The `(x, weight, bias)` inputs of a convolution, handed back unchanged in the `Err` case
/// when the SIMD path cannot be used.
type Args<E> = (SharedArray<E>, SharedArray<E>, Option<SharedArray<E>>);
|
||||
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn try_conv2d_simd<E: FloatNdArrayElement>(
|
||||
x: SharedArray<E>,
|
||||
weight: SharedArray<E>,
|
||||
bias: Option<SharedArray<E>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> Result<SharedArray<E>, Args<E>> {
|
||||
match E::dtype() {
|
||||
DType::F64 => conv2d::<f64, _>(x, weight, bias, options, PhantomData),
|
||||
DType::F32 => conv2d::<f32, _>(x, weight, bias, options, PhantomData),
|
||||
DType::I64 => conv2d::<i64, _>(x, weight, bias, options, PhantomData),
|
||||
DType::I32 => conv2d::<i32, _>(x, weight, bias, options, PhantomData),
|
||||
DType::I16 => conv2d::<i16, _>(x, weight, bias, options, PhantomData),
|
||||
DType::U64 => conv2d::<u64, _>(x, weight, bias, options, PhantomData),
|
||||
DType::U32 => conv2d::<u32, _>(x, weight, bias, options, PhantomData),
|
||||
DType::U16 => conv2d::<u16, _>(x, weight, bias, options, PhantomData),
|
||||
_ => Err((x, weight, bias)),
|
||||
}
|
||||
}
|
||||
|
||||
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
|
||||
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
|
||||
}
|
||||
|
||||
/// Out-channel last SIMD accelerated direct convolution. Loop order and register blocking based on
|
||||
/// E. Georganas, S. Avancha, K. Banerjee, D. Kalamkar, G. Henry, H. Pabst, A. Heinecke (2018).
|
||||
/// Anatomy Of High-Performance Deep Learning Convolutions On SIMD Architectures.
|
||||
/// SC '18, Article 6, pp. 1-12. arXiv:1808.05567. <https://arxiv.org/abs/1808.05567>.
|
||||
#[allow(clippy::result_large_err)]
|
||||
fn conv2d<E: VMulAdd + Element, T: Element>(
|
||||
x: SharedArray<T>,
|
||||
weight: SharedArray<T>,
|
||||
bias: Option<SharedArray<T>>,
|
||||
options: ConvOptions<2>,
|
||||
_ty: PhantomData<E>,
|
||||
) -> Result<SharedArray<T>, Args<T>> {
|
||||
let [out_channels, _, k_height, k_width] = weight.shape().try_into().unwrap();
|
||||
let channels_per_group = out_channels / options.groups;
|
||||
|
||||
#[macerator::with_simd]
|
||||
fn precheck<S: Simd, E: VMulAdd>(_ty: PhantomData<E>) -> (usize, bool) {
|
||||
(E::lanes::<S>(), E::is_accelerated::<S>())
|
||||
}
|
||||
|
||||
let (lanes, accelerated) = precheck::<E>(PhantomData);
|
||||
|
||||
if !accelerated || !channels_per_group.is_multiple_of(lanes) {
|
||||
return Err((x, weight, bias));
|
||||
}
|
||||
|
||||
let x = cast::<_, E>(x);
|
||||
let weight = cast::<_, E>(weight);
|
||||
let bias = bias.map(|bias| cast::<_, E>(bias));
|
||||
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();
|
||||
let [dilate_h, dilate_w] = options.dilation;
|
||||
let [stride_h, stride_w] = options.stride;
|
||||
let [pad_h, pad_w] = options.padding;
|
||||
let padded = options.padding != [0, 0];
|
||||
let strided = options.stride != [1, 1] || options.dilation != [1, 1];
|
||||
let grouped = options.groups != 1;
|
||||
|
||||
let out_height = calculate_conv_output_size(k_height, stride_h, pad_h, dilate_h, in_height);
|
||||
let out_width = calculate_conv_output_size(k_width, stride_w, pad_w, dilate_w, in_width);
|
||||
|
||||
let x = x.into_dimensionality::<Ix4>().unwrap();
|
||||
let weights = weight.into_dimensionality::<Ix4>().unwrap();
|
||||
let weights = weights.permuted_axes([1, 2, 3, 0]);
|
||||
let weights = weights.as_standard_layout();
|
||||
let bias = bias.map(|bias| bias.into_dimensionality::<Ix1>().unwrap());
|
||||
// floor division means `(oc_blocks - 1) * lanes` can never be greater than `out_channels - lanes`.
|
||||
let oc_blocks = out_channels / lanes;
|
||||
|
||||
let mut out = unsafe {
|
||||
Array4::<E>::uninit(Dim([batch_size, out_height, out_width, out_channels])).assume_init()
|
||||
};
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut out);
|
||||
|
||||
run_par!(|| {
|
||||
// SAFETY: Slices are guaranteed to be non-overlapping, so having an unsafe shared reference
|
||||
// is safe. `oc_blocks * lanes` must be `<= out_channels` to satisfy safety of inner function.
|
||||
iter_range_par!(0, batch_size * oc_blocks).for_each(|k| unsafe {
|
||||
let b = k / oc_blocks;
|
||||
let ob = k % oc_blocks;
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = unsafe_shared_out.get();
|
||||
let mut out = out.slice_mut(s![b, .., .., ..]);
|
||||
let w = weights.view();
|
||||
|
||||
match (padded, strided, grouped) {
|
||||
(true, true, true) => {
|
||||
conv2d_launch::<E, true, true, true>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(true, false, true) => {
|
||||
conv2d_launch::<E, true, false, true>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(false, true, true) => {
|
||||
conv2d_launch::<E, false, true, true>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(false, false, true) => {
|
||||
conv2d_launch::<E, false, false, true>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(true, true, false) => {
|
||||
conv2d_launch::<E, true, true, false>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(true, false, false) => {
|
||||
conv2d_launch::<E, true, false, false>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(false, true, false) => {
|
||||
conv2d_launch::<E, false, true, false>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
(false, false, false) => {
|
||||
conv2d_launch::<E, false, false, false>(x, w, &bias, &mut out, &options, ob)
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
let output = out.permuted_axes([0, 3, 1, 2]);
|
||||
Ok(cast(output.into_dyn().into_shared()))
|
||||
}
|
||||
|
||||
/// Size of register blocks, we need to hardcode this because Rust and the `seq` macro don't support
|
||||
/// using associated constants as constant parameters. 8 works for all semi-modern CPUs but might
|
||||
/// not be perfectly optimized for AVX-512 capable CPUs (which probably should use 16).
|
||||
/// This should always be conservative, since oversizing it will cause register spills and that's
|
||||
/// **much** worse than the performance lost with lower values.
///
/// `conv2d_launch` advances over the output width in steps of this many pixels, while the kernels
/// generated by `inner_with_register_blocking_size!` write as many pixels as the literal they are
/// given. The literal passed to the macro below must therefore be kept equal to this constant.
|
||||
const REGISTER_BLOCK: usize = 8;
|
||||
// Generates `conv2d_inner_nopad` and `conv2d_inner_nopad_nostride`, unrolled over 8 output pixels.
inner_with_register_blocking_size!(8);
|
||||
|
||||
/// Run a loop of conv2d.
|
||||
/// # SAFETY
|
||||
/// See `conv2d_inner_nopad`, `conv2d_inner_nopad_nostride`, `conv2d_remainder`.
|
||||
/// Required preconditions: `ob * simd_lanes` must be `<= out_channels - simd_lanes`, `weights` and
|
||||
/// `out` must have unit stride for the out channels.
|
||||
#[inline(always)]
|
||||
#[macerator::with_simd]
|
||||
unsafe fn conv2d_launch<
|
||||
'a,
|
||||
S: Simd,
|
||||
E: VMulAdd,
|
||||
const PAD: bool,
|
||||
const STRIDE: bool,
|
||||
const GROUPS: bool,
|
||||
>(
|
||||
x: ArrayView3<'a, E>,
|
||||
weights: ArrayView4<'a, E>,
|
||||
bias: &'a Option<ArcArray1<E>>,
|
||||
out: &'a mut ArrayViewMut3<'a, E>,
|
||||
options: &'a ConvOptions<2>,
|
||||
ob: usize,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let (in_channels, k_height, k_width, out_channels) = weights.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
let channels_per_group = out_channels / options.groups;
|
||||
let lanes = E::lanes::<S>();
|
||||
|
||||
let [mut pad_h, mut pad_w] = options.padding;
|
||||
let [stride_h, stride_w] = options.stride;
|
||||
let [dilate_h, dilate_w] = options.dilation;
|
||||
|
||||
// Trick compiler into inlining 0 to padding
|
||||
if !PAD {
|
||||
pad_h = 0;
|
||||
pad_w = 0;
|
||||
}
|
||||
|
||||
let oc_b = channels_per_group.min(lanes);
|
||||
let ow_b = REGISTER_BLOCK;
|
||||
|
||||
let ow_start = pad_w;
|
||||
let ow_width = out_width.saturating_sub(2 * pad_w);
|
||||
let oh_start = pad_h;
|
||||
let oh_end = out_height.saturating_sub(pad_h);
|
||||
|
||||
let ow_blocks = ow_width / ow_b;
|
||||
let oc = ob * oc_b;
|
||||
let group = oc / channels_per_group;
|
||||
let mut ic_off = group * in_channels;
|
||||
if !GROUPS {
|
||||
ic_off = 0;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let bias = if let Some(bias) = &bias {
|
||||
vload_unaligned::<S, _>(&bias[oc])
|
||||
} else {
|
||||
Zeroable::zeroed()
|
||||
};
|
||||
|
||||
for oh in oh_start..oh_end {
|
||||
let mut out = out.slice_mut(s![oh, .., ..]);
|
||||
for ow_block in 0..ow_blocks {
|
||||
let ow = ow_block * ow_b + ow_start;
|
||||
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
if STRIDE {
|
||||
conv2d_inner_nopad(
|
||||
&x, &weights, &mut out, bias, oh, ow, oc, ic_off, stride_h, stride_w,
|
||||
dilate_h, dilate_w, k_height, k_width, pad_h, pad_w,
|
||||
);
|
||||
} else {
|
||||
conv2d_inner_nopad_nostride(
|
||||
&x, &weights, &mut out, bias, oh, ow, oc, ic_off, k_height, k_width, pad_h,
|
||||
pad_w,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
conv2d_remainder(
|
||||
x,
|
||||
weights,
|
||||
out,
|
||||
bias,
|
||||
oc,
|
||||
ic_off,
|
||||
ow_blocks * ow_b,
|
||||
stride_h,
|
||||
stride_w,
|
||||
dilate_h,
|
||||
dilate_w,
|
||||
pad_h,
|
||||
pad_w,
|
||||
k_height,
|
||||
k_width,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the non-unrolled and/or padded portion of the convolution. This has more checks and is
|
||||
/// much slower, so we want to minimize the amount of pixels that need to be processed by this
|
||||
///
|
||||
/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector
|
||||
/// is in bounds. Weights and `out` must be channels last (with `stride == 1`).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[inline(always)]
|
||||
unsafe fn conv2d_remainder<S: Simd, E: VMulAdd>(
|
||||
x: ArrayView3<E>,
|
||||
weights: ArrayView4<E>,
|
||||
out: &mut ArrayViewMut3<E>,
|
||||
bias: Vector<S, E>,
|
||||
oc: usize,
|
||||
ic_off: usize,
|
||||
owb_end: usize,
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
dilate_h: usize,
|
||||
dilate_w: usize,
|
||||
pad_h: usize,
|
||||
pad_w: usize,
|
||||
k_height: usize,
|
||||
k_width: usize,
|
||||
) {
|
||||
let in_channels = weights.shape()[0];
|
||||
let (_, in_height, in_width) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
let oh_start = pad_h;
|
||||
let oh_end = out_height.saturating_sub(pad_h);
|
||||
let ow_start = pad_w;
|
||||
|
||||
let height1 = in_height + pad_h;
|
||||
let width1 = in_width + pad_w;
|
||||
|
||||
for oh in (0..oh_start).chain(oh_end..out_height) {
|
||||
for ow in 0..out_width {
|
||||
let mut acc = bias;
|
||||
|
||||
for ic in 0..in_channels {
|
||||
for kh in 0..k_height {
|
||||
let ih = oh * stride_h + kh * dilate_h;
|
||||
if (ih < pad_h) | (ih >= height1) {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..k_width {
|
||||
let iw = ow * stride_w + kw * dilate_w;
|
||||
if (iw < pad_w) | (iw >= width1) {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
|
||||
// Load a full vector from the weights. This is guaranteed to be in bounds
|
||||
// as long as `oc <= out_channels - simd_lanes` and out channels are last.
|
||||
// We need to ensure the weights are reshaped appropriately.
|
||||
let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };
|
||||
|
||||
// The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the
|
||||
// compiler can't prove this. We can't use `as_slice` with fixed bounds
|
||||
// because we want to support arbitrary input layouts. So an unchecked load
|
||||
// is used.
|
||||
let i0 = unsafe { x.uget([ic, ih, iw]) }.splat::<S>();
|
||||
acc = i0.mul_add(f0, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store a full vector from the output. This is guaranteed to be in bounds
|
||||
// as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with
|
||||
// channels last, so this always holds.
|
||||
unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };
|
||||
}
|
||||
}
|
||||
for ow in (0..ow_start).chain(owb_end..out_width) {
|
||||
for oh in 0..out_height {
|
||||
let mut acc = bias;
|
||||
|
||||
for ic in 0..in_channels {
|
||||
for kh in 0..k_height {
|
||||
let ih = oh * stride_h + kh * dilate_h;
|
||||
if (ih < pad_h) | (ih >= height1) {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..k_width {
|
||||
let iw = ow * stride_w + kw * dilate_w;
|
||||
if (iw < pad_w) | (iw >= width1) {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
|
||||
// Load a full vector from the weights. This is guaranteed to be in bounds
|
||||
// as long as `oc <= out_channels - simd_lanes` and out channels are last.
|
||||
// We need to ensure the weights are reshaped appropriately.
|
||||
let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };
|
||||
|
||||
// The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the
|
||||
// compiler can't prove this. We can't use `as_slice` with fixed bounds
|
||||
// because we want to support arbitrary input layouts. So an unchecked load
|
||||
// is used.
|
||||
let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::<S>();
|
||||
acc = i0.mul_add(f0, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store a full vector from the output. This is guaranteed to be in bounds
|
||||
// as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with
|
||||
// channels last, so this always holds.
|
||||
unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the two unrolled inner kernels, each producing `$rb` horizontally adjacent output
/// pixels per call. `$rb` has to be a literal because it drives the `seq!` unrolling.
macro_rules! inner_with_register_blocking_size {
    ($rb: literal) => {
        /// Unrolled, bounds-check-free kernel for strided and/or dilated convolutions.
        ///
        /// Writes the `$rb` output pixels `ow..ow + $rb` of the row view `out` for the output
        /// channels starting at `oc`. Any pixel that is more than `pad_h` away from the horizontal
        /// border, and `pad_w` away from the vertical border is guaranteed to always be in bounds
        /// (because of the way out size is calculated).
        ///
        /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector
        /// is in bounds. Weights and `out` must be channels last (with `stride == 1`).
        #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)]
        #[inline(always)]
        unsafe fn conv2d_inner_nopad<S: Simd, E: VMulAdd>(
            x: &ArrayView3<E>,
            weights: &ArrayView4<E>,
            out: &mut ArrayViewMut2<E>,
            bias: Vector<S, E>,
            oh: usize,
            ow: usize,
            oc: usize,
            ic_off: usize,
            stride_h: usize,
            stride_w: usize,
            dilate_h: usize,
            dilate_w: usize,
            k_height: usize,
            k_width: usize,
            pad_h: usize,
            pad_w: usize,
        ) {
            let group_in_channels = weights.shape()[0];

            // One accumulator register per output pixel, all starting from the bias.
            seq!(N in 0..$rb {
                let mut sum~N = bias;
            });

            for ic in 0..group_in_channels {
                for kh in 0..k_height {
                    let row = oh * stride_h + kh * dilate_h - pad_h;

                    for kw in 0..k_width {
                        // Full vector of weights for this tap; in bounds as long as
                        // `oc <= out_channels - simd_lanes` and out channels are last.
                        let tap = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };
                        let col = ow * stride_w + kw * dilate_w - pad_w;

                        seq!(N in 0..$rb {
                            // Unchecked load: the caller only passes pixels whose whole receptive
                            // field is in bounds, and arbitrary input layouts rule out `as_slice`.
                            let px~N = unsafe { x.uget([ic + ic_off, row, col + N * stride_w]) }.splat::<S>();
                        });
                        seq!(N in 0..$rb {
                            sum~N = px~N.mul_add(tap, sum~N);
                        });
                    }
                }
            }

            seq!(N in 0..$rb {
                // Full-vector store; in bounds as long as `oc <= out_channels - simd_lanes` and
                // the out channel stride is 1.
                unsafe { vstore_unaligned(&mut out[[ow + N, oc]], sum~N) };
            });
        }

        /// Unrolled, bounds-check-free kernel for convolutions with unit stride and dilation.
        ///
        /// Same contract as `conv2d_inner_nopad`, with the stride and dilation multiplications
        /// removed from the index arithmetic. Any pixel that is more than `pad_h` away from the
        /// horizontal border, and `pad_w` away from the vertical border is guaranteed to always be
        /// in bounds (because of the way out size is calculated).
        ///
        /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector
        /// is in bounds. Weights and `out` must be channels last (with `stride == 1`).
        #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)]
        #[inline(always)]
        unsafe fn conv2d_inner_nopad_nostride<S: Simd, E: VMulAdd>(
            x: &ArrayView3<E>,
            weights: &ArrayView4<E>,
            out: &mut ArrayViewMut2<E>,
            bias: Vector<S, E>,
            oh: usize,
            ow: usize,
            oc: usize,
            ic_off: usize,
            k_height: usize,
            k_width: usize,
            pad_h: usize,
            pad_w: usize,
        ) {
            let group_in_channels = weights.shape()[0];

            // One accumulator register per output pixel, all starting from the bias.
            seq!(N in 0..$rb {
                let mut sum~N = bias;
            });

            for ic in 0..group_in_channels {
                for kh in 0..k_height {
                    let row = oh + kh - pad_h;

                    for kw in 0..k_width {
                        // Full vector of weights for this tap; in bounds as long as
                        // `oc <= out_channels - simd_lanes` and out channels are last.
                        let tap = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };
                        let col = ow + kw - pad_w;

                        seq!(N in 0..$rb {
                            // Unchecked load: the caller only passes pixels whose whole receptive
                            // field is in bounds, and arbitrary input layouts rule out `as_slice`.
                            let px~N = unsafe { x.uget([ic + ic_off, row, col + N]) }.splat::<S>();
                        });
                        seq!(N in 0..$rb {
                            sum~N = px~N.mul_add(tap, sum~N);
                        });
                    }
                }
            }

            seq!(N in 0..$rb {
                // Full-vector store; in bounds as long as `oc <= out_channels - simd_lanes` and
                // the out channel stride is 1.
                unsafe { vstore_unaligned(&mut out[[ow + N, oc]], sum~N) };
            });
        }
    };
}
|
||||
pub(crate) use inner_with_register_blocking_size;
|
||||
@@ -0,0 +1,394 @@
|
||||
use core::{marker::PhantomData, mem::transmute};
|
||||
|
||||
use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};
|
||||
|
||||
use burn_backend::{DType, Element, quantization::QuantValue};
|
||||
use macerator::{Simd, VOrd};
|
||||
use ndarray::{Array4, s};
|
||||
use nhwc::max_pool2d_nhwc;
|
||||
|
||||
use super::{MinMax, should_use_simd};
|
||||
|
||||
#[macerator::with_simd]
|
||||
fn is_accelerated_impl<S: Simd, T: VOrd>(_x: PhantomData<T>) -> bool {
|
||||
<T as VOrd>::is_min_max_accelerated::<S>()
|
||||
}
|
||||
|
||||
fn is_accelerated<T: VOrd>() -> bool {
|
||||
is_accelerated_impl::<T>(PhantomData)
|
||||
}
|
||||
|
||||
/// Runs `$func` on `$x` with the concrete element type matching the runtime dtype of `$ty`.
///
/// Evaluates to `Ok(result)` when the dtype has a supported concrete type whose min/max is
/// accelerated, and to `Err($x)` otherwise. `Bool` is processed as `u8`, and `QFloat` values of
/// kind `Q8F`/`Q8S` are processed as `i8`.
macro_rules! launch_kernel {
    // Internal rule: reinterpret the input as `$elem`, run the kernel, reinterpret the result.
    (@run $elem: ty, $func: ident, $x: expr, $($arg: expr),*) => {
        Ok(cast($func::<$elem>(cast($x), $($arg),*)))
    };
    ($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => {
        match <$ty as Element>::dtype() {
            DType::F64 if is_accelerated::<f64>() => {
                launch_kernel!(@run f64, $func, $x, $($arg),*)
            }
            DType::F32 if is_accelerated::<f32>() => {
                launch_kernel!(@run f32, $func, $x, $($arg),*)
            }
            DType::I64 if is_accelerated::<i64>() => {
                launch_kernel!(@run i64, $func, $x, $($arg),*)
            }
            DType::I32 if is_accelerated::<i32>() => {
                launch_kernel!(@run i32, $func, $x, $($arg),*)
            }
            DType::I16 if is_accelerated::<i16>() => {
                launch_kernel!(@run i16, $func, $x, $($arg),*)
            }
            DType::I8 if is_accelerated::<i8>() => {
                launch_kernel!(@run i8, $func, $x, $($arg),*)
            }
            DType::U64 if is_accelerated::<u64>() => {
                launch_kernel!(@run u64, $func, $x, $($arg),*)
            }
            DType::U32 if is_accelerated::<u32>() => {
                launch_kernel!(@run u32, $func, $x, $($arg),*)
            }
            DType::U16 if is_accelerated::<u16>() => {
                launch_kernel!(@run u16, $func, $x, $($arg),*)
            }
            DType::U8 if is_accelerated::<u8>() => {
                launch_kernel!(@run u8, $func, $x, $($arg),*)
            }
            DType::Bool if is_accelerated::<u8>() => {
                launch_kernel!(@run u8, $func, $x, $($arg),*)
            }
            DType::QFloat(scheme) => match scheme.value {
                QuantValue::Q8F | QuantValue::Q8S if is_accelerated::<i8>() => {
                    launch_kernel!(@run i8, $func, $x, $($arg),*)
                }
                _ => Err($x),
            },
            _ => Err($x),
        }
    };
}
|
||||
|
||||
pub(crate) fn try_max_pool2d_simd<E: Element>(
|
||||
x: SharedArray<E>,
|
||||
ksize: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> Result<SharedArray<E>, SharedArray<E>> {
|
||||
let [_, c, _, _] = x.shape().try_into().unwrap();
|
||||
if !should_use_simd(c) || x.strides()[1] != 1 {
|
||||
return Err(x);
|
||||
}
|
||||
|
||||
launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation)
|
||||
}
|
||||
|
||||
/// Reinterprets the element type of `tensor` from `T` to `E` without touching the data.
///
/// Nothing is checked here. In this file it is only reached through `launch_kernel!`, which
/// picks the target type by matching on the runtime `DType` of the source element type.
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
|
||||
// NOTE(review): soundness relies on callers pairing `T` and `E` with identical layout — confirm
// when adding new call sites.
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
|
||||
}
|
||||
|
||||
mod nhwc {
|
||||
use itertools::Itertools;
|
||||
use macerator::{Simd, vload_unaligned, vstore_unaligned};
|
||||
use ndarray::{ArrayView3, ArrayViewMut3, Ix4};
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::ops::simd::lanes;
|
||||
|
||||
use super::*;
|
||||
|
||||
// Until you can use associated constants as array size, we need to hardcode this.
|
||||
// The most common config (x86-v3) has 16 registers, so use half of them for accumulators.
// `loop_blocked` unrolls with a literal `seq!(N in 0..8 ...)`, so that literal has to be updated
// together with this constant.
|
||||
const BLOCK_REGISTERS: usize = 8;
|
||||
|
||||
pub(crate) fn max_pool2d_nhwc<E: Element + VOrd + MinMax>(
|
||||
x: SharedArray<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> SharedArray<E> {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
|
||||
let lanes = lanes::<E>();
|
||||
|
||||
let ch_block = lanes * BLOCK_REGISTERS;
|
||||
|
||||
let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1)
|
||||
/ stride_height)
|
||||
+ 1;
|
||||
let out_width =
|
||||
((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1;
|
||||
|
||||
let mut output = unsafe {
|
||||
Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()
|
||||
};
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
let x = x.into_dimensionality::<Ix4>().unwrap();
|
||||
let x = x.view();
|
||||
let x = x.permuted_axes([0, 2, 3, 1]);
|
||||
|
||||
// Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`.
|
||||
// An exclusive loop will always have `lanes * blocking factor` elements in bounds.
|
||||
let blocks = channels / ch_block;
|
||||
let blocks_end = blocks * ch_block;
|
||||
// Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An
|
||||
// exclusive loop will always have `lanes` elements in bounds.
|
||||
let simd_end = channels / lanes * lanes;
|
||||
let simd_unblocked = (simd_end - blocks_end) / lanes;
|
||||
let remainder = channels - simd_end;
|
||||
|
||||
run_par!(|| {
|
||||
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
|
||||
iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {
|
||||
let block = k % blocks;
|
||||
let b = k / blocks;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = output.slice_mut(s![b, .., .., ..]);
|
||||
loop_blocked(x, out, kernel_size, stride, padding, dilation, block);
|
||||
});
|
||||
// SAFETY: See `loop_unblocked`
|
||||
iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe {
|
||||
let ch = (k % simd_unblocked) * lanes + blocks_end;
|
||||
let b = k / simd_unblocked;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = output.slice_mut(s![b, .., .., ..]);
|
||||
loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch);
|
||||
});
|
||||
// SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.
|
||||
iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {
|
||||
let ch = (k % remainder) + simd_end;
|
||||
let b = k / remainder;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
let x = x.slice(s![b, .., .., ..]);
|
||||
let out = output.slice_mut(s![b, .., .., ..]);
|
||||
loop_scalar(x, out, kernel_size, stride, padding, dilation, ch);
|
||||
});
|
||||
});
|
||||
|
||||
output = output.permuted_axes([0, 3, 1, 2]);
|
||||
|
||||
output.into_dyn().into_shared()
|
||||
}
|
||||
|
||||
/// Execute the blocked (unrolled) portion of the pool.
|
||||
#[allow(
|
||||
clippy::too_many_arguments,
|
||||
clippy::erasing_op,
|
||||
clippy::identity_op,
|
||||
unused_mut
|
||||
)]
|
||||
#[inline(always)]
|
||||
#[macerator::with_simd]
|
||||
fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>(
|
||||
x: ArrayView3<'a, E>,
|
||||
mut out: ArrayViewMut3<'a, E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
block: usize,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
|
||||
let (x_height, x_width, _) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
let lanes = E::lanes::<S>();
|
||||
let ch_block = lanes * BLOCK_REGISTERS;
|
||||
|
||||
let min = E::MIN.splat::<S>();
|
||||
// If outside padding area, kernels are guaranteed to be in bounds
|
||||
for oh in pad_h..out_height.saturating_sub(pad_h) {
|
||||
for ow in pad_w..out_width.saturating_sub(pad_w) {
|
||||
seq!(N in 0..8 {
|
||||
let mut acc~N = min;
|
||||
});
|
||||
let ch = block * ch_block;
|
||||
let ch_end = ch + ch_block;
|
||||
let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh * dilation_height - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw * dilation_width - pad_w;
|
||||
let x = x.slice(s![ih, iw, ch..ch_end]);
|
||||
|
||||
seq!(N in 0..8 {
|
||||
// SAFETY:
|
||||
// Load a full vector from x[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
seq!(N in 0..8 {
|
||||
// SAFETY:
|
||||
// Store a full vector to out[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Border pixels need bounds checks
|
||||
if (pad_h, pad_w) != (0, 0) {
|
||||
let v_borders = (0..pad_h)
|
||||
.chain(out_height.saturating_sub(pad_h)..out_height)
|
||||
.cartesian_product(0..out_width);
|
||||
let h_borders = (0..out_height)
|
||||
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
|
||||
|
||||
for (oh, ow) in v_borders.chain(h_borders) {
|
||||
seq!(N in 0..8 {
|
||||
let mut acc~N = min;
|
||||
});
|
||||
let ch = block * ch_block;
|
||||
let ch_end = ch + ch_block;
|
||||
let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh * dilation_height;
|
||||
if ih < pad_h || ih >= x_height + pad_h {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw * dilation_width;
|
||||
if iw < pad_w || iw >= x_width + pad_w {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
|
||||
let x = x.slice(s![ih, iw, ch..ch_end]);
|
||||
|
||||
seq!(N in 0..8 {
|
||||
// SAFETY:
|
||||
// Load a full vector from x[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
seq!(N in 0..8 {
|
||||
// SAFETY:
|
||||
// Store a full vector to out[N * lanes]. This is bounds checked by the
|
||||
// slice above.
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the unblocked (not unrolled) portion of the pool.
|
||||
///
|
||||
/// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`.
|
||||
#[allow(clippy::too_many_arguments, unused_mut)]
|
||||
#[inline(always)]
|
||||
#[macerator::with_simd]
|
||||
unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>(
|
||||
x: ArrayView3<'a, E>,
|
||||
mut out: ArrayViewMut3<'a, E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ch: usize,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
|
||||
let (x_height, x_width, _) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
|
||||
for oh in pad_h..out_height.saturating_sub(pad_h) {
|
||||
for ow in pad_w..out_width.saturating_sub(pad_w) {
|
||||
let mut acc = E::MIN.splat::<S>();
|
||||
let out = &mut out[[oh, ow, ch]];
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh * dilation_height - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw * dilation_width - pad_w;
|
||||
// Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`
|
||||
acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });
|
||||
}
|
||||
}
|
||||
// Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.
|
||||
unsafe { vstore_unaligned(out, acc) };
|
||||
}
|
||||
}
|
||||
|
||||
// Border pixels need bounds checks
|
||||
if (pad_h, pad_w) != (0, 0) {
|
||||
let v_borders = (0..pad_h)
|
||||
.chain(out_height.saturating_sub(pad_h)..out_height)
|
||||
.cartesian_product(0..out_width);
|
||||
let h_borders = (0..out_height)
|
||||
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
|
||||
|
||||
for (oh, ow) in v_borders.chain(h_borders) {
|
||||
let mut acc = E::MIN.splat::<S>();
|
||||
let out = &mut out[[oh, ow, ch]];
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh * dilation_height;
|
||||
if ih < pad_h || ih >= x_height + pad_h {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw * dilation_width;
|
||||
if iw < pad_w || iw >= x_width + pad_w {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
// Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`
|
||||
acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });
|
||||
}
|
||||
}
|
||||
// Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.
|
||||
unsafe { vstore_unaligned(out, acc) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn loop_scalar<E: Element + MinMax>(
|
||||
x: ArrayView3<'_, E>,
|
||||
mut out: ArrayViewMut3<'_, E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ch: usize,
|
||||
) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [pad_h, pad_w] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
|
||||
let (x_height, x_width, _) = x.dim();
|
||||
let (out_height, out_width, _) = out.dim();
|
||||
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
let mut acc = E::MIN;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh * dilation_height;
|
||||
if ih < pad_h || ih >= x_height + pad_h {
|
||||
continue;
|
||||
}
|
||||
let ih = ih - pad_h;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw * dilation_width;
|
||||
if iw < pad_w || iw >= x_width + pad_w {
|
||||
continue;
|
||||
}
|
||||
let iw = iw - pad_w;
|
||||
acc = acc.max(x[[ih, iw, ch]]);
|
||||
}
|
||||
}
|
||||
|
||||
out[[oh, ow, ch]] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
pub(crate) mod avgpool;
|
||||
mod base;
|
||||
pub(crate) mod binary;
|
||||
pub(crate) mod binary_elemwise;
|
||||
pub(crate) mod cmp;
|
||||
pub(crate) mod conv;
|
||||
pub(crate) mod maxpool;
|
||||
pub(crate) mod unary;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,234 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use bytemuck::cast;
|
||||
use macerator::{
|
||||
Scalar, Simd, VAbs, VBitNot, VRecip, Vector, vload, vload_unaligned, vstore, vstore_unaligned,
|
||||
};
|
||||
use ndarray::ArrayD;
|
||||
use num_traits::Signed;
|
||||
use seq_macro::seq;
|
||||
|
||||
use crate::{NdArrayElement, SharedArray};
|
||||
|
||||
use super::should_use_simd;
|
||||
|
||||
/// An element-wise unary operation from `T` to `Out` with a vector and a scalar implementation.
pub trait SimdUnop<T: Scalar, Out: Scalar> {
|
||||
/// Applies the operation to a whole SIMD vector.
fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, Out>;
|
||||
/// Applies the operation to a single element.
fn apply(input: T) -> Out;
|
||||
/// Whether the vector form is accelerated on SIMD level `S`; `try_unary_simd` gives the input
/// back to its caller when this is `false`.
fn is_accelerated<S: Simd>() -> bool;
|
||||
}
|
||||
|
||||
pub struct RecipVec;
|
||||
|
||||
impl SimdUnop<f32, f32> for RecipVec {
|
||||
fn apply_vec<S: Simd>(input: Vector<S, f32>) -> Vector<S, f32> {
|
||||
input.recip()
|
||||
}
|
||||
|
||||
fn apply(input: f32) -> f32 {
|
||||
input.recip()
|
||||
}
|
||||
|
||||
fn is_accelerated<S: Simd>() -> bool {
|
||||
<f32 as VRecip>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Absolute value for signed element types, as a [`SimdUnop`].
pub struct VecAbs;
|
||||
|
||||
impl<T: VAbs + Signed> SimdUnop<T, T> for VecAbs {
|
||||
/// Absolute value of every lane.
fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, T> {
|
||||
input.abs()
|
||||
}
|
||||
|
||||
/// Absolute value of a single element.
fn apply(input: T) -> T {
|
||||
input.abs()
|
||||
}
|
||||
|
||||
/// Acceleration is whatever `T`'s `VAbs` implementation reports for `S`.
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VAbs>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Bitwise NOT, as a [`SimdUnop`].
pub struct VecBitNot;
|
||||
|
||||
impl<T: VBitNot> SimdUnop<T, T> for VecBitNot {
|
||||
/// Bitwise NOT of every lane.
fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, T> {
|
||||
!input
|
||||
}
|
||||
|
||||
/// Bitwise NOT of a single element.
fn apply(input: T) -> T {
|
||||
input.not()
|
||||
}
|
||||
|
||||
/// Acceleration is whatever `T`'s `VBitNot` implementation reports for `S`.
fn is_accelerated<S: Simd>() -> bool {
|
||||
<T as VBitNot>::is_accelerated::<S>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether `Op` is accelerated on the SIMD level `S` that `macerator` dispatches to.
/// The `PhantomData` argument only carries the type parameters.
#[macerator::with_simd]
|
||||
fn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: SimdUnop<T, Out>>(
|
||||
_x: PhantomData<(T, Out, Op)>,
|
||||
) -> bool {
|
||||
Op::is_accelerated::<S>()
|
||||
}
|
||||
|
||||
pub fn try_unary_simd<
|
||||
E: NdArrayElement,
|
||||
EOut: NdArrayElement,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdUnop<T, Out>,
|
||||
>(
|
||||
input: SharedArray<E>,
|
||||
) -> Result<SharedArray<EOut>, SharedArray<E>> {
|
||||
if !should_use_simd(input.len())
|
||||
|| input.as_slice_memory_order().is_none()
|
||||
|| !is_accelerated::<T, Out, Op>(PhantomData)
|
||||
{
|
||||
return Err(input);
|
||||
}
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };
|
||||
let out = if size_of::<T>() == size_of::<Out>()
|
||||
&& align_of::<T>() >= align_of::<Out>()
|
||||
&& input.is_unique()
|
||||
{
|
||||
unsafe { unary_scalar_simd_inplace::<T, Out, Op>(input) }
|
||||
} else {
|
||||
unary_scalar_simd_owned::<T, Out, Op>(input)
|
||||
};
|
||||
// Used to assert traits based on the dynamic `DType`.
|
||||
let out = unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(out) };
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Execute operation in line.
|
||||
/// SAFETY:
|
||||
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
|
||||
unsafe fn unary_scalar_simd_inplace<
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdUnop<T, Out>,
|
||||
>(
|
||||
input: SharedArray<T>,
|
||||
) -> SharedArray<Out> {
|
||||
let mut buffer = input.into_owned();
|
||||
let slice = buffer.as_slice_memory_order_mut().unwrap();
|
||||
// This is only called when in and out have the same size, so it's safe
|
||||
unsafe { unary_slice_inplace::<T, Out, Op>(slice, PhantomData) };
|
||||
// Buffer has the same elem size and is filled with the operation output, so this is safe
|
||||
let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buffer) };
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
fn unary_scalar_simd_owned<
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdUnop<T, Out>,
|
||||
>(
|
||||
input: SharedArray<T>,
|
||||
) -> SharedArray<Out> {
|
||||
let mut out = unsafe { ArrayD::uninit(input.shape()).assume_init() };
|
||||
let input = input.as_slice_memory_order().unwrap();
|
||||
let out_slice = out.as_slice_memory_order_mut().unwrap();
|
||||
unary_slice::<T, Out, Op>(input, out_slice, PhantomData);
|
||||
out.into_shared()
|
||||
}
|
||||
|
||||
#[allow(clippy::erasing_op, clippy::identity_op)]
|
||||
#[macerator::with_simd]
|
||||
fn unary_slice<
|
||||
'a,
|
||||
S: Simd,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdUnop<T, Out>,
|
||||
>(
|
||||
input: &'a [T],
|
||||
out: &'a mut [Out],
|
||||
_op: PhantomData<Op>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let lanes = T::lanes::<S>();
|
||||
let mut chunks_input = input.chunks_exact(8 * lanes);
|
||||
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
|
||||
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
|
||||
seq!(N in 0..8 {
|
||||
// Load one full vector from `input`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
|
||||
let s~N = Op::apply_vec::<S>(s~N);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`
|
||||
unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };
|
||||
});
|
||||
}
|
||||
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
|
||||
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
|
||||
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
|
||||
// Load one full vector from `input`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
|
||||
let s0 = Op::apply_vec::<S>(s0);
|
||||
// Store one full vector to `out`.
|
||||
// SAFETY: Guaranteed to be in bounds because `len == lanes`
|
||||
unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };
|
||||
}
|
||||
|
||||
for (input, out) in chunks_input
|
||||
.remainder()
|
||||
.iter()
|
||||
.zip(chunks_out.into_remainder())
|
||||
{
|
||||
*out = Op::apply(*input)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute operation in line.
|
||||
/// SAFETY:
|
||||
/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.
|
||||
#[macerator::with_simd]
|
||||
unsafe fn unary_slice_inplace<
|
||||
'a,
|
||||
S: Simd,
|
||||
T: NdArrayElement + Scalar,
|
||||
Out: NdArrayElement + Scalar,
|
||||
Op: SimdUnop<T, Out>,
|
||||
>(
|
||||
buf: &'a mut [T],
|
||||
_op: PhantomData<(Out, Op)>,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
|
||||
for elem in head.iter_mut().chain(tail) {
|
||||
*elem = cast(Op::apply(*elem));
|
||||
}
|
||||
let mut chunks = main.chunks_exact_mut(8);
|
||||
for elem in chunks.by_ref() {
|
||||
seq!(N in 0..8 {
|
||||
// Load a full vector from the aligned portion of the buffer.
|
||||
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
|
||||
// always a full vector in bounds.
|
||||
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
|
||||
let s~N = Op::apply_vec::<S>(s~N);
|
||||
// Store a full vector at the same position as the input. Cast is safe because `Out` is
|
||||
// size and align compatible
|
||||
unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) };
|
||||
});
|
||||
}
|
||||
|
||||
for elem in chunks.into_remainder() {
|
||||
// Load a full vector from the aligned portion of the buffer.
|
||||
// SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is
|
||||
// always a full vector in bounds.
|
||||
let s0 = unsafe { vload(elem as *const _ as *const T) };
|
||||
|
||||
let s0 = Op::apply_vec::<S>(s0);
|
||||
// Store a full vector at the same position as the input. Cast is safe because `Out` is
|
||||
// size and align compatible
|
||||
unsafe { vstore(elem as *mut _ as *mut Out, s0) };
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,688 @@
|
||||
// Language
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::backend::ExecutionError;
|
||||
use burn_backend::ops::GridSampleOptions;
|
||||
use burn_backend::tensor::FloatTensor;
|
||||
use burn_backend::{TensorMetadata, element::cast::ToElement};
|
||||
|
||||
// Current crate
|
||||
use super::{
|
||||
NdArrayMathOps, NdArrayOps,
|
||||
matmul::{cross, matmul},
|
||||
};
|
||||
use crate::{
|
||||
NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
|
||||
};
|
||||
use crate::{NdArrayDevice, SEED, slice};
|
||||
use crate::{
|
||||
SharedArray,
|
||||
element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
};
|
||||
use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
|
||||
|
||||
// Workspace crates
|
||||
use crate::rand::get_seeded_rng;
|
||||
use burn_backend::{Distribution, FloatDType, Scalar};
|
||||
use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[allow(unused_imports)]
|
||||
use num_traits::Float;
|
||||
|
||||
use libm::erf;
|
||||
|
||||
/// Rounds to the nearest integer, resolving exact halves towards the even neighbour.
///
/// `std` build: forwards to `f64::round_ties_even`.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds to the nearest integer, resolving exact halves towards the even neighbour.
///
/// Non-`std` build: implemented with `floor` and `round` only.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let is_tie = (x - x.floor()) == 0.5;
    if !is_tie {
        return x.round();
    }
    // Halving, rounding and doubling lands on the even integer adjacent to `x`.
    (x * 0.5).round() * 2.0
}
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
fn float_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &NdArrayDevice,
|
||||
) -> FloatTensor<Self> {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
|
||||
let tensor = Self::float_from_data(
|
||||
TensorData::random::<E, _, _>(shape, distribution, &mut rng),
|
||||
device,
|
||||
);
|
||||
*seed = Some(rng);
|
||||
tensor
|
||||
}
|
||||
|
||||
async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
Ok(tensor.into_data())
|
||||
}
|
||||
|
||||
fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
||||
fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn float_empty(
|
||||
shape: Shape,
|
||||
device: &<NdArray<E> as Backend>::Device,
|
||||
dtype: FloatDType,
|
||||
) -> FloatTensor<Self> {
|
||||
Self::float_zeros(shape, device, dtype)
|
||||
}
|
||||
|
||||
fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
|
||||
}
|
||||
|
||||
fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::add_scalar(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
|
||||
}
|
||||
|
||||
fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::sub_scalar(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
|
||||
}
|
||||
|
||||
fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::mul_scalar(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
|
||||
}
|
||||
|
||||
fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::div_scalar(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
|
||||
}
|
||||
|
||||
fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::remainder_scalar(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), matmul)
|
||||
}
|
||||
|
||||
fn float_cross(
|
||||
lhs: FloatTensor<Self>,
|
||||
rhs: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
|
||||
}
|
||||
|
||||
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::recip(array)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::swap_dims(array, dim1, dim2)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::reshape(array, shape)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_gather(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self>,
|
||||
indices: NdArrayTensor,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_int_dtype!(
|
||||
indices,
|
||||
IntElem,
|
||||
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::gather(dim, array, idx_array)
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn float_scatter_add(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self>,
|
||||
indices: NdArrayTensor,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_int_dtype!(
|
||||
indices,
|
||||
IntElem,
|
||||
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
|
||||
execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
|
||||
dim, tensor, idx_array, value
|
||||
))
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn float_select(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_int_dtype!(
|
||||
indices,
|
||||
IntElem,
|
||||
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::select(array, dim, idx_array)
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn float_select_add(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_int_dtype!(
|
||||
indices,
|
||||
IntElem,
|
||||
|idx_array: SharedArray<IntElem>| -> NdArrayTensor {
|
||||
execute_with_float_dtype!((tensor, value), |tensor, value| {
|
||||
NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
|
||||
slice!(tensor, slices)
|
||||
}
|
||||
|
||||
fn float_slice_assign(
|
||||
tensor: FloatTensor<Self>,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((tensor, value), |tensor, value| {
|
||||
NdArrayOps::slice_assign(tensor, slices, value)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_mask_where(
|
||||
tensor: FloatTensor<Self>,
|
||||
mask: NdArrayTensor,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((tensor, value), |tensor, value| {
|
||||
NdArrayOps::mask_where(tensor, mask.bool(), value)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_mask_fill(
|
||||
tensor: FloatTensor<Self>,
|
||||
mask: NdArrayTensor,
|
||||
value: Scalar,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::mask_fill(array, mask.bool(), value.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
|
||||
}
|
||||
|
||||
fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::equal_elem(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
|
||||
}
|
||||
|
||||
fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::greater_elem(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
|
||||
NdArrayMathOps::greater_equal(lhs, rhs)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::greater_equal_elem(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
|
||||
}
|
||||
|
||||
fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::lower_elem(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
|
||||
NdArrayMathOps::lower_equal(lhs, rhs)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
|
||||
execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::lower_equal_elem(array, rhs.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::mean_view(array.view())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::sum_view(array.view())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::mean_dim(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::cumsum(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::cumprod(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::cummin(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::cummax(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::sum_dim(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::argmax_view::<I>(array.view(), dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::argmin_view::<I>(array.view(), dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::prod_view(array.view())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::prod_dim(array, dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::max_view(array.view())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// Use view() for zero-copy on borrowed storage
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::min_view(array.view())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::abs(array)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
|
||||
})
|
||||
}
|
||||
|
||||
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array
|
||||
.mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
|
||||
.into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
|
||||
cat_with_dtype!(tensors, dim, [F64, F32])
|
||||
}
|
||||
|
||||
fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::clamp_min(array, min.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::clamp_max(array, max.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::clamp(array, min.elem(), max.elem())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
|
||||
})
|
||||
}
|
||||
|
||||
fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
|
||||
})
|
||||
}
|
||||
|
||||
fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::permute(array, axes)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::flip(array, axes)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayMathOps::sign_op(array)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::expand(array, shape)
|
||||
})
|
||||
}
|
||||
|
||||
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
cast_to_dtype(array, dtype.into())
|
||||
})
|
||||
}
|
||||
|
||||
fn float_grid_sample_2d(
|
||||
tensor: FloatTensor<Self>,
|
||||
grid: FloatTensor<Self>,
|
||||
options: GridSampleOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
|
||||
tensor, grid, options
|
||||
))
|
||||
}
|
||||
|
||||
fn float_unfold(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
|
||||
NdArrayOps::unfold(array, dim, size, step)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
use crate::{
|
||||
FloatNdArrayElement, NdArray, NdArrayTensor, SharedArray,
|
||||
element::{IntNdArrayElement, QuantElement},
|
||||
};
|
||||
use burn_backend::ops::TransactionOps;
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> TransactionOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
/// Macro for running a function in parallel.
///
/// Calls the closure expression passed as the single argument from inside a `rayon::scope`
/// and evaluates to the closure's result.
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! run_par {
    ($body:expr) => {{
        // NOTE(review): presumably imported so that rayon's iterator traits are in scope for
        // code inside `$body` (e.g. `iter_par!`) -- confirm against callers.
        use rayon::prelude::*;

        #[allow(clippy::redundant_closure_call)]
        rayon::scope(|_| $body())
    }};
}
|
||||
|
||||
/// Sequential fallback of `run_par!`, used when the `multi-threads` feature is disabled.
///
/// Calls the closure expression passed as the single argument directly on the current thread
/// and evaluates to its result; no parallelism is involved in this configuration.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! run_par {
    ($body:expr) => {{ $body() }};
}
|
||||
|
||||
/// Sequential fallback of `iter_par!`, used when the `multi-threads` feature is disabled.
///
/// Evaluates to the iterator expression unchanged, so iteration happens on the current thread.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
    ($source:expr) => {{ $source }};
}
|
||||
|
||||
/// Macro for iterating in parallel.
///
/// Calls `.into_par_iter()` on the given expression; the method must be in scope where the
/// macro is expanded.
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
    ($source:expr) => {{ $source.into_par_iter() }};
}
|
||||
|
||||
/// Macro for iterating over a slice in parallel.
///
/// Calls `.into_par_iter()` on the given expression; the method must be in scope where the
/// macro is expanded.
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! iter_slice_par {
    ($items:expr) => {{ $items.into_par_iter() }};
}
|
||||
|
||||
/// Sequential fallback of `iter_slice_par!`, used when the `multi-threads` feature is disabled.
///
/// Calls `.iter()` on the given expression, producing an ordinary by-reference iterator that
/// runs on the current thread.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! iter_slice_par {
    ($items:expr) => {{ $items.iter() }};
}
|
||||
|
||||
/// Macro for iterating over a range in parallel.
///
/// Builds the half-open range `$from..$to` and calls `.into_par_iter()` on it; the method
/// must be in scope where the macro is expanded.
#[cfg(feature = "multi-threads")]
#[macro_export(local_inner_macros)]
macro_rules! iter_range_par {
    ($from:expr, $to:expr) => {{ ($from..$to).into_par_iter() }};
}
|
||||
|
||||
/// Sequential fallback of `iter_range_par!`, used when the `multi-threads` feature is disabled.
///
/// Evaluates to the plain half-open range `$from..$to`, which iterates on the current thread.
#[cfg(not(feature = "multi-threads"))]
#[macro_export(local_inner_macros)]
macro_rules! iter_range_par {
    ($from:expr, $to:expr) => {{ ($from..$to) }};
}
|
||||
@@ -0,0 +1,36 @@
|
||||
//! Random number generation utilities for burn-ndarray
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use rand::rngs::SmallRng;
|
||||
#[cfg(feature = "std")]
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
/// Type alias for the RNG used by burn-ndarray
|
||||
#[cfg(feature = "std")]
|
||||
pub type NdArrayRng = StdRng;
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub type NdArrayRng = SmallRng;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
use rand::SeedableRng;
|
||||
|
||||
/// Get a seeded random number generator (`std` builds).
///
/// Delegates to `burn_std::rand::get_seeded_rng()`; the original documentation states that
/// this implementation seeds from OS entropy.
#[cfg(feature = "std")]
pub fn get_seeded_rng() -> NdArrayRng {
    burn_std::rand::get_seeded_rng()
}
|
||||
|
||||
/// Get a seeded random number generator
|
||||
///
|
||||
/// For std builds, uses OS entropy.
|
||||
/// For no_std builds, uses a compile-time random seed.
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub fn get_seeded_rng() -> NdArrayRng {
|
||||
// Use compile-time random seed for no_std
|
||||
const SEED: u64 = const_random::const_random!(u64);
|
||||
SmallRng::seed_from_u64(SEED)
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use core::cell::UnsafeCell;
|
||||
|
||||
/// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439).
///
/// Wraps an exclusive reference so that it can be retrieved again through a shared `&self`,
/// including from several threads at once. All aliasing guarantees are delegated to the
/// callers of [`UnsafeSharedRef::get`].
pub(crate) struct UnsafeSharedRef<'a, T> {
    cell: UnsafeCell<&'a mut T>,
}

// SAFETY: the wrapper performs no synchronisation of its own. Sharing it across threads is
// only sound because `get` is `unsafe` and its contract makes callers responsible for never
// creating overlapping accesses to the wrapped value.
// NOTE(review): there is no `T: Send`/`T: Sync` bound, so that requirement also rests on the
// callers -- confirm this is intended.
unsafe impl<T> Sync for UnsafeSharedRef<'_, T> {}

impl<'a, T> UnsafeSharedRef<'a, T> {
    /// Wraps the exclusive reference `data`.
    pub fn new(data: &'a mut T) -> Self {
        Self {
            cell: UnsafeCell::new(data),
        }
    }

    /// Returns a copy of the wrapped `&'a mut T`.
    ///
    /// # Safety
    ///
    /// Every call hands out another mutable reference to the same value. The caller must
    /// ensure that references obtained from this method are never used for overlapping
    /// accesses (for example, each thread only touches a disjoint part of the data), and
    /// that no data race occurs on the wrapped value.
    pub unsafe fn get(&self) -> &'a mut T {
        // SAFETY: `self.cell` always holds the initialised reference stored by `new`, so the
        // pointer returned by `UnsafeCell::get` is valid to read. Duplicating the `&mut T`
        // is only sound under the aliasing contract documented above.
        unsafe { core::ptr::read(self.cell.get()) }
    }
}
|
||||
@@ -0,0 +1,514 @@
|
||||
//! Copy-on-write storage for zero-copy tensor loading.
|
||||
//!
|
||||
//! This module provides `NdArrayStorage<E>`, which enables true zero-copy loading
|
||||
//! from burnpack files. When data is borrowed from external memory (like mmap'd files
|
||||
//! or static data), it remains zero-copy until a mutating operation is performed,
|
||||
//! at which point it's copied (copy-on-write semantics).
|
||||
//!
|
||||
//! This integrates with ndarray's existing COW patterns - operations that check
|
||||
//! `is_unique()` will see borrowed data as non-unique, triggering the allocation path.
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::Element;
|
||||
use burn_std::Bytes;
|
||||
use core::mem;
|
||||
use ndarray::{ArcArray, ArrayView, IxDyn};
|
||||
|
||||
/// Storage that supports both owned data and borrowed (zero-copy) data.
|
||||
///
|
||||
/// # Copy-on-Write Semantics
|
||||
///
|
||||
/// - **Borrowed**: Data from external source (burnpack, mmap, static).
|
||||
/// Reports `is_unique() == false` to trigger copy on mutation.
|
||||
/// - **Owned**: Standard `ArcArray` with built-in COW via Arc refcount.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// // Zero-copy load
|
||||
/// let storage = NdArrayStorage::from_borrowed(bytes, shape);
|
||||
/// storage.is_unique(); // false - will copy on mutation
|
||||
///
|
||||
/// // Read operations use view() - zero-copy
|
||||
/// let view = storage.view();
|
||||
///
|
||||
/// // Mutation converts to owned
|
||||
/// let owned = storage.into_owned(); // Copies here
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub enum NdArrayStorage<E: Element> {
|
||||
/// Borrowed from external source (e.g., burnpack zero-copy load).
|
||||
/// Keeps `Bytes` alive to ensure the referenced memory is valid.
|
||||
Borrowed {
|
||||
/// Source bytes - keeps external memory alive via reference counting
|
||||
bytes: Bytes,
|
||||
/// Shape of the tensor
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
|
||||
/// Standard owned storage with ArcArray COW semantics.
|
||||
Owned(ArcArray<E, IxDyn>),
|
||||
}
|
||||
|
||||
impl<E: Element> Clone for NdArrayStorage<E> {
    /// Clones the storage while preserving its variant.
    ///
    /// Owned storage clones the `ArcArray` handle. Borrowed storage clones the `Bytes` and
    /// the shape; whether cloning `Bytes` shares or copies memory is decided by `Bytes`.
    fn clone(&self) -> Self {
        match self {
            Self::Owned(array) => Self::Owned(array.clone()),
            Self::Borrowed { bytes, shape } => {
                let (bytes, shape) = (bytes.clone(), shape.clone());
                Self::Borrowed { bytes, shape }
            }
        }
    }
}
|
||||
|
||||
impl<E: Element> NdArrayStorage<E> {
|
||||
/// Create borrowed storage from external bytes.
|
||||
///
|
||||
/// Returns the bytes and shape back on failure (misaligned or too small),
|
||||
/// enabling zero-copy even for native allocations by avoiding defensive cloning.
|
||||
///
|
||||
/// # Requirements
|
||||
///
|
||||
/// The caller must ensure that:
|
||||
/// - The `Bytes` contain valid data for the element type `E`
|
||||
/// - The data is contiguous in row-major (C) order matching the provided shape
|
||||
///
|
||||
/// These requirements are upheld when loading from `TensorData` (burnpack, etc.)
|
||||
/// which always stores data contiguously in row-major order.
|
||||
pub fn from_borrowed(bytes: Bytes, shape: Vec<usize>) -> Result<Self, (Bytes, Vec<usize>)> {
|
||||
// Validate alignment
|
||||
let ptr = bytes.as_ptr();
|
||||
if !(ptr as usize).is_multiple_of(mem::align_of::<E>()) {
|
||||
return Err((bytes, shape));
|
||||
}
|
||||
|
||||
// Validate size (using checked arithmetic to prevent overflow)
|
||||
let num_elements = match shape
|
||||
.iter()
|
||||
.try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
|
||||
{
|
||||
Some(n) => n,
|
||||
None => return Err((bytes, shape)),
|
||||
};
|
||||
let expected_size = match num_elements.checked_mul(mem::size_of::<E>()) {
|
||||
Some(s) => s,
|
||||
None => return Err((bytes, shape)),
|
||||
};
|
||||
if bytes.len() < expected_size {
|
||||
return Err((bytes, shape));
|
||||
}
|
||||
|
||||
Ok(Self::Borrowed { bytes, shape })
|
||||
}
|
||||
|
||||
/// Create owned storage from an ArcArray.
|
||||
#[inline]
|
||||
pub fn from_owned(array: ArcArray<E, IxDyn>) -> Self {
|
||||
Self::Owned(array)
|
||||
}
|
||||
|
||||
/// Returns whether this storage is uniquely owned and can be mutated in-place.
|
||||
///
|
||||
/// - **Borrowed**: Always returns `false` to trigger copy-on-write.
|
||||
/// - **Owned**: Delegates to `ArcArray::is_unique()`.
|
||||
///
|
||||
/// This integrates with existing SIMD code patterns like:
|
||||
/// ```ignore
|
||||
/// if tensor.is_unique() {
|
||||
/// // mutate in place
|
||||
/// } else {
|
||||
/// // allocate new
|
||||
/// }
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn is_unique(&self) -> bool {
|
||||
match self {
|
||||
Self::Borrowed { .. } => false, // Force copy path
|
||||
Self::Owned(arr) => arr.is_unique(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a read-only view of the data.
|
||||
///
|
||||
/// This is zero-copy for both borrowed and owned variants.
|
||||
#[inline]
|
||||
pub fn view(&self) -> ArrayView<'_, E, IxDyn> {
|
||||
match self {
|
||||
Self::Borrowed { bytes, shape } => {
|
||||
let ptr = bytes.as_ptr() as *const E;
|
||||
let dim = IxDyn(shape);
|
||||
// SAFETY:
|
||||
// - `bytes` is kept alive for the lifetime of `self`
|
||||
// - Alignment was validated in `from_borrowed`
|
||||
// - Size was validated in `from_borrowed`
|
||||
unsafe { ArrayView::from_shape_ptr(dim, ptr) }
|
||||
}
|
||||
Self::Owned(arr) => arr.view(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to owned ArcArray.
|
||||
///
|
||||
/// - **Borrowed**: Copies the data into a new ArcArray.
|
||||
/// - **Owned + unique**: Returns the array without copying.
|
||||
/// - **Owned + shared**: Clones the data.
|
||||
pub fn into_owned(self) -> ArcArray<E, IxDyn> {
|
||||
match self {
|
||||
Self::Borrowed { bytes, shape } => {
|
||||
let ptr = bytes.as_ptr() as *const E;
|
||||
let dim = IxDyn(&shape);
|
||||
// SAFETY: Same as view() - bytes is valid for this scope
|
||||
let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) };
|
||||
view.to_owned().into_shared()
|
||||
}
|
||||
Self::Owned(arr) => arr,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to shared ArcArray, suitable for returning from operations.
|
||||
///
|
||||
/// This is equivalent to `into_owned()` but named for clarity.
|
||||
#[inline]
|
||||
pub fn into_shared(self) -> ArcArray<E, IxDyn> {
|
||||
self.into_owned()
|
||||
}
|
||||
|
||||
/// Get the shape of the tensor.
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
match self {
|
||||
Self::Borrowed { shape, .. } => shape,
|
||||
Self::Owned(arr) => arr.shape(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of dimensions.
|
||||
#[inline]
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape().len()
|
||||
}
|
||||
|
||||
/// Get the total number of elements.
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.shape().iter().product()
|
||||
}
|
||||
|
||||
/// Check if the tensor is empty.
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Returns `true` if this is borrowed (zero-copy) storage.
|
||||
#[inline]
|
||||
pub fn is_borrowed(&self) -> bool {
|
||||
matches!(self, Self::Borrowed { .. })
|
||||
}
|
||||
|
||||
/// Returns `true` if this is owned storage.
|
||||
#[inline]
|
||||
pub fn is_owned(&self) -> bool {
|
||||
matches!(self, Self::Owned(_))
|
||||
}
|
||||
|
||||
/// Ensure owned and return mutable reference to the ArcArray.
|
||||
///
|
||||
/// Converts borrowed to owned if necessary.
|
||||
pub fn ensure_owned(&mut self) -> &mut ArcArray<E, IxDyn> {
|
||||
if let Self::Borrowed { bytes, shape } = self {
|
||||
let ptr = bytes.as_ptr() as *const E;
|
||||
let dim = IxDyn(shape);
|
||||
// SAFETY: Same as view()
|
||||
let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) };
|
||||
*self = Self::Owned(view.to_owned().into_shared());
|
||||
}
|
||||
match self {
|
||||
Self::Owned(arr) => arr,
|
||||
Self::Borrowed { .. } => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from ArcArray to NdArrayStorage.
|
||||
impl<E: Element> From<ArcArray<E, IxDyn>> for NdArrayStorage<E> {
|
||||
fn from(array: ArcArray<E, IxDyn>) -> Self {
|
||||
Self::Owned(array)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use burn_std::Bytes;

    /// Four `f32` values (1.0 to 4.0) wrapped in natively allocated `Bytes`.
    fn sample_bytes() -> Bytes {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        Bytes::from_elems(data)
    }

    /// Wraps `bytes` as borrowed 2x2 `f32` storage, panicking if validation fails.
    fn borrowed_2x2(bytes: Bytes) -> NdArrayStorage<f32> {
        NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create")
    }

    /// A freshly allocated 2x2 owned array filled with ones.
    fn owned_2x2() -> NdArrayStorage<f32> {
        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
        NdArrayStorage::from_owned(array)
    }

    #[test]
    fn test_borrowed_is_not_unique() {
        let storage = borrowed_2x2(sample_bytes());

        assert!(!storage.is_unique());
        assert!(storage.is_borrowed());
    }

    #[test]
    fn test_owned_unique_when_single_ref() {
        let storage = owned_2x2();

        assert!(storage.is_unique());
        assert!(storage.is_owned());
    }

    #[test]
    fn test_owned_not_unique_when_cloned() {
        let storage = owned_2x2();
        let _clone = storage.clone();

        assert!(!storage.is_unique());
    }

    #[test]
    fn test_view_zero_copy() {
        let storage = borrowed_2x2(sample_bytes());

        let view = storage.view();
        assert_eq!(view[[0, 0]], 1.0);
        assert_eq!(view[[1, 1]], 4.0);
    }

    #[test]
    fn test_into_owned_copies_borrowed() {
        let storage = borrowed_2x2(sample_bytes());

        let owned = storage.into_owned();
        assert_eq!(owned[[0, 0]], 1.0);
        assert_eq!(owned[[1, 1]], 4.0);
    }

    #[test]
    fn test_from_borrowed_validates_alignment() {
        use burn_std::AllocationProperty;

        // Test 1: Properly aligned data should succeed
        let aligned_bytes = sample_bytes();

        // Verify test setup - should be 4-byte aligned for f32
        assert_eq!(
            (aligned_bytes.as_ptr() as usize) % core::mem::align_of::<f32>(),
            0,
            "Test setup: f32 data should be properly aligned"
        );

        let result = NdArrayStorage::<f32>::from_borrowed(aligned_bytes, vec![2, 2]);
        assert!(
            result.is_ok(),
            "from_borrowed should succeed for properly aligned data"
        );

        // Test 2: Misaligned data should fail
        // Create a buffer large enough to find a misaligned offset
        // (static data placement varies by platform, so we find an offset dynamically)
        let buffer: &[u8] = &[0u8; 32];
        let shared = bytes::Bytes::from_static(buffer);
        let base = shared.as_ptr() as usize;
        let align = core::mem::align_of::<f32>();

        // Find an offset in 1..align that produces misalignment (at least one must exist)
        let misalign_offset = (1..align)
            .find(|&off| !(base + off).is_multiple_of(align))
            .expect("Should find a misaligned offset");

        let sliced = shared.slice(misalign_offset..(misalign_offset + 16));
        let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other);

        // Verify test setup - should NOT be 4-byte aligned
        assert_ne!(
            (misaligned_bytes.as_ptr() as usize) % align,
            0,
            "Test setup: sliced data should be misaligned for f32"
        );

        let result = NdArrayStorage::<f32>::from_borrowed(misaligned_bytes, vec![4]);
        assert!(
            result.is_err(),
            "from_borrowed should return Err for misaligned data"
        );
    }

    #[test]
    fn test_insufficient_size_returns_err() {
        // Only two f32 values: 8 bytes, too small for the requested shape
        let data: Vec<f32> = vec![1.0, 2.0];
        let bytes = Bytes::from_elems(data);

        // Storage for 4 elements would need 16 bytes
        let result = NdArrayStorage::<f32>::from_borrowed(bytes, vec![4]);
        assert!(
            result.is_err(),
            "from_borrowed should return Err when bytes are too small"
        );
    }

    // ==========================================================================
    // Zero-copy hardening tests
    // These tests verify the zero-copy guarantee is maintained. If any of these
    // fail, it indicates a regression in zero-copy functionality.
    // ==========================================================================

    #[test]
    fn test_zero_copy_native_allocation() {
        // CRITICAL: Verify that native allocations (Bytes::from_elems) are zero-copy
        // on initial load. The view() must return a pointer to the SAME memory.
        //
        // Note: Native allocations copy on clone (this is expected), but the initial
        // load is still zero-copy, avoiding an extra copy in the common case where
        // the tensor is used without cloning.
        let bytes = sample_bytes();
        let original_ptr = bytes.as_ptr();

        let storage = borrowed_2x2(bytes);

        // Initial load must be zero-copy
        let view = storage.view();
        let view_ptr = view.as_ptr() as *const u8;

        assert_eq!(
            original_ptr, view_ptr,
            "ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes"
        );

        // Verify data integrity
        assert_eq!(view[[0, 0]], 1.0);
        assert_eq!(view[[0, 1]], 2.0);
        assert_eq!(view[[1, 0]], 3.0);
        assert_eq!(view[[1, 1]], 4.0);
    }

    #[test]
    fn test_zero_copy_shared_bytes_pointer_identity() {
        // CRITICAL: Test with SharedBytesAllocationController for true zero-copy.
        // This simulates the actual burnpack/mmap loading path.
        use burn_std::AllocationProperty;

        // Create static-like data using bytes::Bytes
        let data: &[u8] = &[
            0, 0, 128, 63, // 1.0f32 in little-endian
            0, 0, 0, 64, // 2.0f32
            0, 0, 64, 64, // 3.0f32
            0, 0, 128, 64, // 4.0f32
        ];
        let shared = bytes::Bytes::from_static(data);
        let original_ptr = shared.as_ptr();

        // Create Bytes with SharedBytesAllocationController
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);

        let storage = borrowed_2x2(bytes);

        // Verify pointer identity
        let view_ptr = storage.view().as_ptr() as *const u8;
        assert_eq!(
            original_ptr, view_ptr,
            "ZERO-COPY REGRESSION: SharedBytes view must point to original static data"
        );

        // Clone should also share the same memory
        let cloned = storage.clone();
        let cloned_ptr = cloned.view().as_ptr() as *const u8;
        assert_eq!(
            original_ptr, cloned_ptr,
            "ZERO-COPY REGRESSION: SharedBytes clone must share memory"
        );
    }

    #[test]
    fn test_clone_borrowed_stays_borrowed() {
        // Verify that cloning borrowed storage produces another borrowed storage.
        // Note: The underlying Bytes may or may not share memory depending on
        // the allocation controller (native allocations copy, file-backed may share).
        let storage = borrowed_2x2(sample_bytes());
        let cloned = storage.clone();

        // Both should still be borrowed (the storage type is preserved)
        assert!(
            storage.is_borrowed(),
            "ZERO-COPY REGRESSION: original should remain borrowed after clone"
        );
        assert!(
            cloned.is_borrowed(),
            "ZERO-COPY REGRESSION: clone should be borrowed type"
        );

        // Both should report not unique (important for COW behavior)
        assert!(
            !storage.is_unique(),
            "ZERO-COPY REGRESSION: original should not be unique after clone"
        );
        assert!(
            !cloned.is_unique(),
            "ZERO-COPY REGRESSION: clone should not be unique"
        );

        // Data should be identical
        assert_eq!(storage.view(), cloned.view(), "Clone should have same data");
    }

    #[test]
    fn test_zero_copy_triggers_copy_on_mutation() {
        // Verify that into_owned() on borrowed data creates a NEW allocation
        // (this is the "copy" in copy-on-write)
        let bytes = sample_bytes();
        let original_ptr = bytes.as_ptr();

        let storage = borrowed_2x2(bytes);

        assert!(storage.is_borrowed(), "should start as borrowed");

        let owned = storage.into_owned();
        let owned_ptr = owned.as_ptr() as *const u8;

        assert_ne!(
            original_ptr, owned_ptr,
            "into_owned() on borrowed data MUST allocate new memory (copy-on-write)"
        );
    }

    #[test]
    fn test_borrowed_reports_not_unique() {
        // CRITICAL: Borrowed storage must report is_unique() == false
        // This is what triggers copy-on-write in mutation operations
        let storage = borrowed_2x2(sample_bytes());

        assert!(
            !storage.is_unique(),
            "ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \
             to trigger copy-on-write. If this is true, mutations will corrupt shared data!"
        );
    }
}
|
||||
@@ -0,0 +1,864 @@
|
||||
use core::mem;
|
||||
|
||||
use burn_backend::{
|
||||
DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
|
||||
quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
|
||||
};
|
||||
|
||||
use crate::NdArrayStorage;
|
||||
use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
|
||||
use alloc::vec::Vec;
|
||||
use ndarray::{ArcArray, ArrayD, IxDyn};
|
||||
|
||||
/// Concrete storage type for ndarray (owned with COW semantics via Arc)
|
||||
pub type SharedArray<E> = ArcArray<E, IxDyn>;
|
||||
|
||||
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
|
||||
///
|
||||
/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.
|
||||
/// When data is borrowed from external sources (like burnpack files),
|
||||
/// it remains zero-copy until a mutating operation is performed.
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum NdArrayTensor {
|
||||
F64(NdArrayStorage<f64>),
|
||||
F32(NdArrayStorage<f32>),
|
||||
I64(NdArrayStorage<i64>),
|
||||
I32(NdArrayStorage<i32>),
|
||||
I16(NdArrayStorage<i16>),
|
||||
I8(NdArrayStorage<i8>),
|
||||
U64(NdArrayStorage<u64>),
|
||||
U32(NdArrayStorage<u32>),
|
||||
U16(NdArrayStorage<u16>),
|
||||
U8(NdArrayStorage<u8>),
|
||||
Bool(NdArrayStorage<bool>),
|
||||
}
|
||||
|
||||
impl NdArrayTensor {
|
||||
/// Extract bool array, converting to owned if necessary.
|
||||
pub(crate) fn bool(self) -> SharedArray<bool> {
|
||||
match self {
|
||||
NdArrayTensor::Bool(storage) => storage.into_shared(),
|
||||
_ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this tensor uses borrowed (zero-copy) storage.
|
||||
#[inline]
|
||||
pub fn is_borrowed(&self) -> bool {
|
||||
macro_rules! check {
|
||||
($($variant:ident),*) => {
|
||||
match self {
|
||||
$(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
|
||||
}
|
||||
};
|
||||
}
|
||||
check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E1>>,
|
||||
{
|
||||
fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
|
||||
array.mapv(|a| a.elem()).into_shared()
|
||||
}
|
||||
|
||||
if E1::dtype() == dtype {
|
||||
return array.into();
|
||||
}
|
||||
|
||||
match dtype {
|
||||
DType::F64 => cast::<E1, f64>(array).into(),
|
||||
DType::F32 => cast::<E1, f32>(array).into(),
|
||||
DType::Flex32 => cast::<E1, f32>(array).into(),
|
||||
DType::I64 => cast::<E1, i64>(array).into(),
|
||||
DType::I32 => cast::<E1, i32>(array).into(),
|
||||
DType::I16 => cast::<E1, i16>(array).into(),
|
||||
DType::I8 => cast::<E1, i8>(array).into(),
|
||||
DType::U64 => cast::<E1, u64>(array).into(),
|
||||
DType::U32 => cast::<E1, u32>(array).into(),
|
||||
DType::U16 => cast::<E1, u16>(array).into(),
|
||||
DType::U8 => cast::<E1, u8>(array).into(),
|
||||
DType::Bool => cast::<E1, bool>(array).into(),
|
||||
dtype => panic!("Unsupported dtype: {dtype:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_from {
|
||||
($($ty: ty => $dtype: ident),*) => {
|
||||
// From SharedArray (owned) -> NdArrayTensor
|
||||
$(impl From<SharedArray<$ty>> for NdArrayTensor {
|
||||
fn from(value: SharedArray<$ty>) -> NdArrayTensor {
|
||||
NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
|
||||
}
|
||||
})*
|
||||
|
||||
// From NdArrayStorage -> NdArrayTensor
|
||||
$(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
|
||||
fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
|
||||
NdArrayTensor::$dtype(value)
|
||||
}
|
||||
})*
|
||||
};
|
||||
}
|
||||
|
||||
impl_from!(
|
||||
f64 => F64, f32 => F32,
|
||||
i64 => I64, i32 => I32, i16 => I16, i8 => I8,
|
||||
u64 => U64, u32 => U32, u16 => U16, u8 => U8,
|
||||
bool => Bool
|
||||
);
|
||||
|
||||
/// Macro to execute an operation on a given element type.
///
/// Matches on the `NdArrayTensor` variant, converts the storage to a `SharedArray` with
/// `into_shared()`, calls the operation on it and converts the result back with `.into()`.
/// Inside the operation, the identifier given as the element name (`E` when omitted) is a
/// type alias for the concrete element type of the matched variant.
///
/// Forms, for a single tensor or a `(lhs, rhs)` pair:
/// - `(tensor, op)`: element alias named `E`, all dtypes
/// - `(tensor, Elem, op)`: explicit alias name, all dtypes
/// - `(tensor, Elem, op, [Variant => type, ...])`: explicit list of supported dtypes
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch. Unary
/// operations panic when the tensor's dtype is not in the list of supported dtypes.
#[macro_export]
macro_rules! execute_with_dtype {
    // Binary op: explicit element alias and explicit dtype list
    (($left:expr, $right:expr),$elem:ident, $operation:expr, [$($variant: ident => $elem_ty: ty),*]) => {{
        // Captured up front because the match below consumes both tensors.
        let left_dtype = burn_backend::TensorMetadata::dtype(&$left);
        let right_dtype = burn_backend::TensorMetadata::dtype(&$right);
        match ($left, $right) {
            $(
                ($crate::NdArrayTensor::$variant(left), $crate::NdArrayTensor::$variant(right)) => {
                    #[allow(unused)]
                    type $elem = $elem_ty;
                    // Operations work on SharedArray, so both storages are converted first
                    $operation(left.into_shared(), right.into_shared()).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                left_dtype, right_dtype
            ),
        }
    }};
    // Binary op: type automatically inferred by the compiler
    (($left:expr, $right:expr), $operation:expr) => {{
        $crate::execute_with_dtype!(($left, $right), E, $operation)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($left:expr, $right:expr), $elem:ident, $operation:expr) => {{
        $crate::execute_with_dtype!(($left, $right), $elem, $operation, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    // Unary op: explicit element alias and explicit dtype list
    ($input:expr, $elem:ident, $operation:expr, [$($variant: ident => $elem_ty: ty),*]) => {{
        match $input {
            $(
                $crate::NdArrayTensor::$variant(inner) => {
                    #[allow(unused)]
                    type $elem = $elem_ty;
                    // Operations work on SharedArray, so the storage is converted first
                    $operation(inner.into_shared()).into()
                }
            )*
            // Unreachable when the dtype list covers every variant.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($input:expr, $operation:expr) => {{
        $crate::execute_with_dtype!($input, E, $operation)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($input:expr, $elem:ident, $operation:expr) => {{
        $crate::execute_with_dtype!($input, $elem, $operation, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
|
||||
|
||||
/// Macro to execute an operation on a given element type.
/// Only handles float types (`F64` and `F32`).
///
/// Every form forwards to `execute_with_dtype!` with the float dtype list.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($left:expr, $right:expr), $operation:expr) => {{
        $crate::execute_with_float_dtype!(($left, $right), E, $operation)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($left:expr, $right:expr), $elem:ident, $operation:expr) => {{
        $crate::execute_with_dtype!(($left, $right), $elem, $operation, [
            F64 => f64, F32 => f32
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($input:expr, $operation:expr) => {{
        $crate::execute_with_float_dtype!($input, E, $operation)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($input:expr, $elem:ident, $operation:expr) => {{
        $crate::execute_with_dtype!($input, $elem, $operation, [
            F64 => f64, F32 => f32
        ])
    }};
}
|
||||
|
||||
/// Macro to execute an operation on a given element type.
/// Only handles int types (`I64`, `I32`, `I16`, `I8`, `U64`, `U32`, `U16`, `U8`).
///
/// Every form forwards to `execute_with_dtype!` with the integer dtype list.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations whose operands
/// have different integer data types will panic with a data type mismatch. The same panic
/// occurs when both operands share a dtype that is not an integer type. Unary operations
/// panic when the tensor is not an integer tensor.
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
|
||||
|
||||
/// Macro to execute an operation on a given element type.
/// Only handles numeric types.
///
/// Every arm forwards to `execute_with_dtype!` with the float variants (`F64`, `F32`)
/// followed by the signed and unsigned integer variants (`I64`, `I32`, `I16`, `I8`,
/// `U64`, `U32`, `U16`, `U8`). The arms that do not take an element identifier forward
/// `E` in its place.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// numeric data types (different float precisions, different integer widths, or a float
/// mixed with an integer) will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
|
||||
|
||||
/// Macro to execute a cat operation on a given set of element types.
///
/// Uses zero-copy views from storage for concatenation.
///
/// The `$tensors` expression is evaluated exactly once. The variant of its first element
/// selects the match arm, and every other element must be the same variant.
///
/// # Panics
/// - If indexing the first element of `$tensors` fails (e.g. an empty `Vec` or slice).
/// - If any tensor has a different data type than the first one; there is no automatic
///   type cast at this time.
/// - If the data type of the first tensor is not in the provided variant list.
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {{
        // Bind once so the caller's expression is not re-evaluated at every use below.
        let all_tensors = &$tensors;
        match &all_tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let views = all_tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(storage) = t {
                            // Use storage.view() for zero-copy access
                            storage.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", all_tensors[0].dtype(), t.dtype())
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&views, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", all_tensors[0].dtype())
        }
    }};
}
|
||||
|
||||
impl TensorMetadata for NdArrayTensor {
|
||||
fn dtype(&self) -> DType {
|
||||
match self {
|
||||
NdArrayTensor::F64(_) => DType::F64,
|
||||
NdArrayTensor::F32(_) => DType::F32,
|
||||
NdArrayTensor::I64(_) => DType::I64,
|
||||
NdArrayTensor::I32(_) => DType::I32,
|
||||
NdArrayTensor::I16(_) => DType::I16,
|
||||
NdArrayTensor::I8(_) => DType::I8,
|
||||
NdArrayTensor::U64(_) => DType::U64,
|
||||
NdArrayTensor::U32(_) => DType::U32,
|
||||
NdArrayTensor::U16(_) => DType::U16,
|
||||
NdArrayTensor::U8(_) => DType::U8,
|
||||
NdArrayTensor::Bool(_) => DType::Bool,
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> Shape {
|
||||
// Use storage's shape method (works for both borrowed and owned)
|
||||
macro_rules! get_shape {
|
||||
($($variant:ident),*) => {
|
||||
match self {
|
||||
$(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
|
||||
}
|
||||
};
|
||||
}
|
||||
get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.shape().num_dims()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait ShapeOps {
|
||||
fn num_dims(self) -> usize;
|
||||
fn num_elements(self) -> usize;
|
||||
fn dims<const N: usize>(self) -> [usize; N];
|
||||
fn into_shape(self) -> Shape;
|
||||
}
|
||||
|
||||
impl ShapeOps for &[usize] {
|
||||
fn num_dims(self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn num_elements(self) -> usize {
|
||||
self.iter().product()
|
||||
}
|
||||
|
||||
fn dims<const N: usize>(self) -> [usize; N] {
|
||||
self.try_into().unwrap()
|
||||
}
|
||||
|
||||
fn into_shape(self) -> Shape {
|
||||
Shape::from(self)
|
||||
}
|
||||
}
|
||||
|
||||
mod utils {
|
||||
use burn_std::tensor::is_contiguous;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl NdArrayTensor {
|
||||
pub(crate) fn into_data(self) -> TensorData {
|
||||
let shape = self.shape();
|
||||
let contiguous = self.is_contiguous();
|
||||
|
||||
fn inner<E: Element>(
|
||||
shape: Shape,
|
||||
is_contiguous: bool,
|
||||
array: ArcArray<E, IxDyn>,
|
||||
) -> TensorData {
|
||||
let vec = if is_contiguous {
|
||||
match array.try_into_owned_nocopy() {
|
||||
Ok(owned) => {
|
||||
let (mut vec, offset) = owned.into_raw_vec_and_offset();
|
||||
if let Some(offset) = offset {
|
||||
vec.drain(..offset);
|
||||
}
|
||||
if vec.len() > shape.num_elements() {
|
||||
vec.drain(shape.num_elements()..vec.len());
|
||||
}
|
||||
vec
|
||||
}
|
||||
Err(array) => array.into_iter().collect(),
|
||||
}
|
||||
} else {
|
||||
array.into_iter().collect()
|
||||
};
|
||||
|
||||
TensorData::new(vec, shape)
|
||||
}
|
||||
|
||||
// Convert storage to owned array before extracting data
|
||||
execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
|
||||
}
|
||||
|
||||
pub(crate) fn is_contiguous(&self) -> bool {
|
||||
// For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)
|
||||
// For owned data, we check the strides
|
||||
macro_rules! check_contiguous {
|
||||
($($variant:ident),*) => {
|
||||
match self {
|
||||
$(NdArrayTensor::$variant(storage) => {
|
||||
match storage {
|
||||
NdArrayStorage::Borrowed { .. } => {
|
||||
// Borrowed storage requires contiguous row-major data
|
||||
// (see NdArrayStorage::from_borrowed documentation)
|
||||
true
|
||||
}
|
||||
NdArrayStorage::Owned(array) => {
|
||||
let shape = array.shape();
|
||||
let mut strides = Vec::with_capacity(array.strides().len());
|
||||
for &stride in array.strides() {
|
||||
if stride <= 0 {
|
||||
return false;
|
||||
}
|
||||
strides.push(stride as usize);
|
||||
}
|
||||
is_contiguous(shape, &strides)
|
||||
}
|
||||
}
|
||||
})*
|
||||
}
|
||||
};
|
||||
}
|
||||
check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a slice of usize to a typed dimension.
///
/// Reads `$dims[0]` through `$dims[$n - 1]` and wraps them in a `Dim<[usize; $n]>`.
/// `Dim` is resolved at the call site.
///
/// # Panics
/// If indexing `$dims` at any position below `$n` panics (e.g. a `Vec` or slice that
/// holds fewer than `$n` entries).
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        // Build the array in one step instead of zero-filling and overwriting by index.
        let dims: [usize; $n] = core::array::from_fn(|i| $dims[i]);
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
|
||||
|
||||
/// Reshapes an array into a tensor.
///
/// Two forms are accepted:
/// - `ty T, n N, shape S, array A` reshapes `A` to the first `N` dimensions of `S` and
///   returns a shared, dynamically-dimensioned array. The `A` expression is evaluated once.
/// - `ty T, shape S, array A, d D` dispatches on the runtime rank `D` (1 through 6) to
///   the form above.
///
/// # Panics
/// - If the array cannot be reshaped to the requested dimensions.
/// - If `D` is not in `1..=6`.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape, justdim);
        // Borrow once so the caller's expression is not evaluated for every use below.
        let source = &$array;
        let reshaped = match source.to_shape(dim) {
            Ok(val) => val,
            Err(err) => {
                core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
            }
        };
        let array = if source.is_standard_layout() {
            reshaped.into_shared()
        } else {
            // Input was not in standard layout: normalise the result before sharing it.
            reshaped.as_standard_layout().into_shared()
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
|
||||
|
||||
/// Slice a tensor
///
/// Two forms are accepted:
/// - `slice!(tensor, slices)` handles every `NdArrayTensor` variant.
/// - `slice!(tensor, slices, V1, V2, ...)` handles only the listed variants.
///
/// For the matching variant, the storage's view is passed to `NdArrayOps::slice` and the
/// result is converted with `.into()`. `NdArrayTensor` and `NdArrayOps` are resolved at
/// the call site.
#[macro_export]
macro_rules! slice {
    ($tensor:expr, $slices:expr) => {
        // `$crate::` keeps the recursive call resolvable wherever the macro is invoked
        // from, matching the other exported macros in this file.
        $crate::slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
    };
    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
        match $tensor {
            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
        }
    };
}
|
||||
|
||||
impl NdArrayTensor {
|
||||
/// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
|
||||
///
|
||||
/// This method attempts zero-copy loading when possible. If the data has properly
|
||||
/// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,
|
||||
/// it falls back to copying the data.
|
||||
///
|
||||
/// Zero-copy loading works when:
|
||||
/// - The data's bytes are properly aligned for the element type
|
||||
/// - The bytes can be borrowed (e.g., from mmap'd file or static data)
|
||||
pub fn from_data(data: TensorData) -> NdArrayTensor {
|
||||
// Try borrowed storage first, fall back to owned if not possible
|
||||
match Self::try_from_data_borrowed(data) {
|
||||
Ok(tensor) => tensor,
|
||||
Err(data) => Self::from_data_owned(data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to create a tensor with borrowed storage (zero-copy).
|
||||
///
|
||||
/// Takes ownership of TensorData and returns it back on failure.
|
||||
/// No cloning occurs - bytes are moved into storage or returned on failure.
|
||||
///
|
||||
/// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).
|
||||
fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
|
||||
let TensorData {
|
||||
bytes,
|
||||
shape,
|
||||
dtype,
|
||||
} = data;
|
||||
|
||||
macro_rules! try_borrow {
|
||||
($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
|
||||
match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
|
||||
Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
|
||||
Err((bytes, shape)) => (bytes, shape),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Try to create borrowed storage; get bytes back on failure
|
||||
let (bytes, shape) = match dtype {
|
||||
DType::F64 => try_borrow!(f64, F64, bytes, shape),
|
||||
DType::F32 => try_borrow!(f32, F32, bytes, shape),
|
||||
DType::I64 => try_borrow!(i64, I64, bytes, shape),
|
||||
DType::I32 => try_borrow!(i32, I32, bytes, shape),
|
||||
DType::I16 => try_borrow!(i16, I16, bytes, shape),
|
||||
DType::I8 => try_borrow!(i8, I8, bytes, shape),
|
||||
DType::U64 => try_borrow!(u64, U64, bytes, shape),
|
||||
DType::U32 => try_borrow!(u32, U32, bytes, shape),
|
||||
DType::U16 => try_borrow!(u16, U16, bytes, shape),
|
||||
DType::U8 => try_borrow!(u8, U8, bytes, shape),
|
||||
DType::Bool => try_borrow!(bool, Bool, bytes, shape),
|
||||
_ => (bytes, shape), // QFloat not supported for zero-copy
|
||||
};
|
||||
|
||||
Err(TensorData {
|
||||
bytes,
|
||||
shape,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a tensor with owned storage.
|
||||
///
|
||||
/// This may or may not copy data depending on whether the underlying bytes
|
||||
/// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,
|
||||
/// no copy occurs; otherwise data is copied to a new allocation.
|
||||
fn from_data_owned(mut data: TensorData) -> NdArrayTensor {
|
||||
let shape = mem::take(&mut data.shape);
|
||||
|
||||
macro_rules! execute {
|
||||
($data: expr, [$($dtype: ident => $ty: ty),*]) => {
|
||||
match $data.dtype {
|
||||
$(DType::$dtype => {
|
||||
match data.into_vec::<$ty>() {
|
||||
// Safety: TensorData checks shape validity on creation
|
||||
Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
|
||||
Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
|
||||
}.into()
|
||||
},)*
|
||||
other => unimplemented!("Unsupported dtype {other:?}"),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
execute!(data, [
|
||||
F64 => f64, F32 => f32,
|
||||
I64 => i64, I32 => i32, I16 => i16, I8 => i8,
|
||||
U64 => u64, U32 => u32, U16 => u16, U8 => u8,
|
||||
Bool => bool
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
/// A quantized tensor for the ndarray backend.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NdArrayQTensor {
|
||||
/// The quantized tensor.
|
||||
pub qtensor: NdArrayTensor,
|
||||
/// The quantization scheme.
|
||||
pub scheme: QuantScheme,
|
||||
/// The quantization parameters.
|
||||
pub qparams: Vec<QParams<f32>>,
|
||||
}
|
||||
|
||||
impl NdArrayQTensor {
|
||||
/// Returns the quantization strategy, including quantization parameters, for the given tensor.
|
||||
pub fn strategy(&self) -> QuantizationStrategy {
|
||||
match self.scheme {
|
||||
QuantScheme {
|
||||
level: QuantLevel::Tensor,
|
||||
mode: QuantMode::Symmetric,
|
||||
value:
|
||||
QuantValue::Q8F
|
||||
| QuantValue::Q8S
|
||||
| QuantValue::E4M3
|
||||
| QuantValue::E5M2
|
||||
| QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::E2M1
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S,
|
||||
..
|
||||
} => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
|
||||
self.qparams[0].scales,
|
||||
self.scheme.value,
|
||||
)),
|
||||
QuantScheme {
|
||||
level: QuantLevel::Block(block_size),
|
||||
mode: QuantMode::Symmetric,
|
||||
value:
|
||||
QuantValue::Q8F
|
||||
| QuantValue::Q8S
|
||||
| QuantValue::E4M3
|
||||
| QuantValue::E5M2
|
||||
| QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::E2M1
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S,
|
||||
..
|
||||
} => QuantizationStrategy::PerBlockSymmetric(
|
||||
self.qparams
|
||||
.iter()
|
||||
.map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
|
||||
.collect(),
|
||||
block_size,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QTensorPrimitive for NdArrayQTensor {
|
||||
fn scheme(&self) -> &QuantScheme {
|
||||
&self.scheme
|
||||
}
|
||||
|
||||
fn default_scheme() -> QuantScheme {
|
||||
QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorMetadata for NdArrayQTensor {
|
||||
fn dtype(&self) -> DType {
|
||||
DType::QFloat(self.scheme)
|
||||
}
|
||||
|
||||
fn shape(&self) -> Shape {
|
||||
self.qtensor.shape()
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.shape().num_dims()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::NdArray;
    use alloc::vec;

    use super::*;
    use burn_backend::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };
    use burn_std::rand::get_seeded_rng;

    /// Round-trips random `f32` data of the given shape through `from_data` and
    /// `into_data`, asserting that nothing changes on the way.
    fn assert_random_round_trip(shape: Shape) {
        let data_expected = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    /// Builds a 2x2 `f32` tensor holding `[1, 2, 3, 4]` from `Bytes`-backed data.
    fn tensor_2x2_from_bytes() -> NdArrayTensor {
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        NdArrayTensor::from_data(tensor_data)
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_random_round_trip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_random_round_trip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_random_round_trip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_random_round_trip(Shape::new([2, 3, 4, 2]));
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let qparams = QuantizationParametersPrimitive {
            scales: B::float_from_data(TensorData::from([scale]), &device),
        };
        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                scale,
                QuantValue::Q8S
            ))
        );
    }

    // ==========================================================================
    // Zero-copy integration tests
    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.
    // ==========================================================================

    #[test]
    fn zero_copy_creates_borrowed_storage() {
        // Verify that from_data creates borrowed storage when possible.
        // Note: For native allocations, Bytes::clone() copies data internally,
        // but the storage type (Borrowed) is preserved, which is important for
        // the is_unique() behavior that triggers copy-on-write.
        let tensor = tensor_2x2_from_bytes();

        match &tensor {
            NdArrayTensor::F32(storage) => {
                assert!(
                    storage.is_borrowed(),
                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
                    for properly aligned TensorData with Bytes"
                );
                assert!(
                    !storage.is_unique(),
                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
                );
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_data_integrity() {
        // Verify data is correctly accessible through borrowed storage
        let tensor = tensor_2x2_from_bytes();

        match &tensor {
            NdArrayTensor::F32(storage) => {
                let view = storage.view();
                assert_eq!(view[[0, 0]], 1.0);
                assert_eq!(view[[0, 1]], 2.0);
                assert_eq!(view[[1, 0]], 3.0);
                assert_eq!(view[[1, 1]], 4.0);
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_fallback_when_bytes_owned() {
        // When TensorData owns bytes exclusively, it may use the copy path
        // This is expected behavior - verify it still works correctly
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = NdArrayTensor::from_data(data.clone());
        let result = tensor.into_data();

        assert_eq!(data, result, "Data should round-trip correctly");
    }
}
|
||||
Reference in New Issue
Block a user