feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,88 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Generic backend that can be compiled just-in-time to any shader language target"
documentation = "https://docs.rs/burn-cubecl"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu"]
license.workspace = true
name = "burn-cubecl"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"autotune",
"std",
"fusion",
"cubecl/default",
"burn-fusion?/default",
"burn-cubecl-fusion?/default",
]
std = [
"cubecl/std",
"burn-backend/std",
"burn-fusion?/std",
"burn-cubecl-fusion?/std",
]
doc = ["default"]
memory-checks = ["burn-fusion?/memory-checks"]
tracing = [
"dep:tracing",
"cubecl/tracing",
"burn-std/tracing",
"burn-backend/tracing",
"burn-fusion?/tracing",
"burn-cubecl-fusion?/tracing",
]
autotune = ["burn-cubecl-fusion?/autotune"]
autotune-checks = [
"autotune",
"cubecl/autotune-checks",
"burn-cubecl-fusion?/autotune-checks",
]
fusion = ["burn-fusion", "burn-cubecl-fusion"]
fusion-experimental = ["fusion"]
template = []
[dependencies]
burn-cubecl-fusion = { path = "../burn-cubecl-fusion", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false }
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false, features = [
"cubecl",
] }
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false, features = [
"cubecl",
] }
cubecl = { workspace = true, features = ["stdlib"] }
cubek = { workspace = true, features = [
"attention",
"matmul",
"convolution",
"reduce",
"random",
"quantization",
] }
tracing = { workspace = true, features = ["attributes"], optional = true }
derive-new = { workspace = true }
log = { workspace = true }
# Async
futures-lite = { workspace = true, features = ["std"] }
# Template
serde = { workspace = true }
text_placeholder = { workspace = true, features = ["struct_context"] }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,3 @@
# Burn CubeCL Backend
Generic backend that can be compiled just-in-time (JIT) to any shader language target.

View File

@@ -0,0 +1,196 @@
use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
use burn_backend::{Backend, DTypeUsage, DTypeUsageSet, DeviceOps, ExecutionError, TensorData};
use burn_std::DType;
use cubecl::{
features::{MmaConfig, TypeUsage},
server::ComputeServer,
};
use std::marker::PhantomData;
#[cfg(not(feature = "fusion"))]
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
#[cfg(not(feature = "fusion"))]
use burn_ir::{BackendIr, TensorHandle};
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
#[derive(new)]
pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
_runtime: PhantomData<R>,
_float_elem: PhantomData<F>,
_int_elem: PhantomData<I>,
_bool_elem: PhantomData<BT>,
}
impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
R::Server: ComputeServer,
R::Device: DeviceOps,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
type Device = R::Device;
type FloatElem = F;
type IntElem = I;
type BoolElem = BT;
type FloatTensorPrimitive = CubeTensor<R>;
type IntTensorPrimitive = CubeTensor<R>;
type BoolTensorPrimitive = CubeTensor<R>;
type QuantizedTensorPrimitive = CubeTensor<R>;
fn name(device: &Self::Device) -> String {
let client = R::client(device);
format!("cubecl<{}>", R::name(&client))
}
fn seed(_device: &Self::Device, seed: u64) {
cubek::random::seed(seed);
}
fn ad_enabled(_device: &Self::Device) -> bool {
false
}
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
let client = R::client(device);
futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {
reason: format!("{err}"),
})
}
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
device: &Self::Device,
input: Input,
func: Func,
) -> Output {
let client = R::client(device);
client.memory_persistent_allocation(input, func)
}
fn memory_cleanup(device: &Self::Device) {
let client = R::client(device);
client.memory_cleanup();
}
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
where
Iter: Iterator<Item = &'a mut TensorData>,
{
let client = R::client(device);
client.staging(data.map(|td| &mut td.bytes), false);
}
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
let client = R::client(device);
let type_usage = client.properties().type_usage(dtype.into());
// Same as `TypeUsage::all_scalar()`, but we make the usage explicit here
type_usage.is_superset(
TypeUsage::Buffer
| TypeUsage::Conversion
| TypeUsage::Arithmetic
| TypeUsage::DotProduct,
)
}
fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet {
let client = R::client(device);
let props = client.properties();
let storage = dtype.into();
let usage = props.type_usage(storage);
let mut out = DTypeUsageSet::new();
if usage.is_superset(TypeUsage::Buffer | TypeUsage::Conversion) {
out |= DTypeUsage::Storage;
}
if usage.contains(TypeUsage::Arithmetic) {
out |= DTypeUsage::Arithmetic;
}
let has_mma = |cfg: &MmaConfig| {
cfg.a_type == storage || cfg.b_type == storage || cfg.cd_type == storage
};
if props.features.cmma.iter().any(has_mma) || props.features.mma.iter().any(has_mma) {
out |= DTypeUsage::Accelerated;
}
out
}
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
for CubeBackend<R, F, I, BT>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("CubeCLBackend")
}
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
for CubeBackend<R, F, I, BT>
{
fn clone(&self) -> Self {
Self::new()
}
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
for CubeBackend<R, F, I, BT>
{
fn default() -> Self {
Self::new()
}
}
impl<R: cubecl::Runtime> CubeRuntime for R
where
R::Device: DeviceOps,
{
type CubeDevice = R::Device;
type CubeServer = R::Server;
}
#[cfg(not(feature = "fusion"))]
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
for CubeBackend<R, F, I, BT>
{
type Handle = CubeTensor<R>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
handle.handle
}
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
handle.handle
}
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
handle.handle
}
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
handle.handle
}
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
tensor
}
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
tensor
}
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
tensor
}
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
tensor
}
}

View File

@@ -0,0 +1,94 @@
use burn_backend::{Element, bf16, f16};
use cubecl::{
CubeElement as CubeElem, flex32,
prelude::{Float, Int, Numeric},
};
use cubek::{
matmul::definition::{MatmulPrecision, MatrixPrecision},
reduce::ReducePrecision,
};
/// The base element trait for the jit backend.
pub trait CubeElement: Element + CubeElem + PartialEq + Numeric {}
/// Element that can be used for matrix multiplication. Includes ints and floats.
pub trait MatmulElement:
CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>
{
}
/// The float element type for the jit backend.
pub trait FloatElement: MatmulElement + Float {}
/// The int element type for the jit backend.
pub trait IntElement:
MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>
{
}
/// The element type for booleans for the jit backend.
pub trait BoolElement: CubeElement + Int {
/// The true value for the boolean element.
fn true_val() -> Self {
Self::from_int(1)
}
/// The false value for the boolean element.
fn false_val() -> Self {
Self::from_int(0)
}
/// New bool element from Rust bool.
fn new_bool(val: bool) -> Self {
match val {
true => Self::true_val(),
false => Self::false_val(),
}
}
}
impl CubeElement for u64 {}
impl CubeElement for u32 {}
impl CubeElement for u16 {}
impl CubeElement for u8 {}
impl CubeElement for i64 {}
impl CubeElement for i32 {}
impl CubeElement for i16 {}
impl CubeElement for i8 {}
impl CubeElement for f64 {}
impl CubeElement for f32 {}
impl CubeElement for flex32 {}
impl CubeElement for f16 {}
impl CubeElement for bf16 {}
impl FloatElement for f64 {}
impl FloatElement for f32 {}
impl FloatElement for flex32 {}
impl FloatElement for bf16 {}
impl FloatElement for f16 {}
impl IntElement for i64 {}
impl IntElement for i32 {}
impl IntElement for i16 {}
impl IntElement for i8 {}
impl IntElement for u64 {}
impl IntElement for u32 {}
impl IntElement for u16 {}
impl IntElement for u8 {}
impl BoolElement for u8 {}
impl BoolElement for u32 {}
impl MatmulElement for f64 {}
impl MatmulElement for f32 {}
impl MatmulElement for flex32 {}
impl MatmulElement for bf16 {}
impl MatmulElement for f16 {}
impl MatmulElement for i64 {}
impl MatmulElement for i32 {}
impl MatmulElement for i16 {}
impl MatmulElement for i8 {}
impl MatmulElement for u64 {}
impl MatmulElement for u32 {}
impl MatmulElement for u16 {}
impl MatmulElement for u8 {}

View File

@@ -0,0 +1,205 @@
use crate::BoolElement;
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use burn_backend::{DType, Shape};
use burn_cubecl_fusion::optim::reduce::ReduceSettings;
use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser;
use burn_cubecl_fusion::{
CubeFusionHandle, FallbackOperation,
optim::{
CubeOptimization, CubeOptimizationState,
elemwise::{ElementWiseFuser, ElemwiseOptimization},
matmul::{MatmulFuser, MatmulOptimization},
reduce::{ReduceFuser, ReduceOptimization},
reduce_broadcasted::ReduceBroadcastedOptimization,
},
};
use burn_fusion::{
FusionBackend, FusionRuntime,
stream::{Operation, OrderedExecution},
};
use burn_ir::{BackendIr, TensorHandle};
use burn_std::Metadata;
use core::marker::PhantomData;
use std::sync::Arc;
impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
where
R: CubeRuntime,
BT: BoolElement,
{
fn execute(
&mut self,
context: &mut burn_fusion::stream::Context<
'_,
<FusionCubeRuntime<R, BT> as FusionRuntime>::FusionHandle,
>,
execution: &OrderedExecution<FusionCubeRuntime<R, BT>>,
) {
match self {
Self::ElementWise(op) => op.execute::<BT>(context),
Self::Matmul(op) => op.execute::<BT>(context, |index| {
let operation = execution.operation_within_optimization(index);
Box::new(FallbackOperationWrapper::new(operation))
}),
Self::Reduce(op) => op.execute::<BT>(context, |index| {
let operation = execution.operation_within_optimization(index);
Box::new(FallbackOperationWrapper::new(operation))
}),
Self::ReduceBroadcasted(op) => op.execute::<BT>(context, |index| {
let operation = execution.operation_within_optimization(index);
Box::new(FallbackOperationWrapper::new(operation))
}),
}
}
fn to_state(&self) -> CubeOptimizationState {
self.to_opt_state()
}
fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
match state {
CubeOptimizationState::ElementWise(state) => {
Self::ElementWise(ElemwiseOptimization::from_state(device, state))
}
CubeOptimizationState::Matmul(state) => {
Self::Matmul(MatmulOptimization::from_state(device, state))
}
CubeOptimizationState::Reduce(state) => {
Self::Reduce(ReduceOptimization::from_state(device, state))
}
CubeOptimizationState::ReduceBroadcasted(state) => {
Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state))
}
}
}
}
struct FallbackOperationWrapper<O: Clone> {
operation: O,
}
impl<O: Clone> FallbackOperationWrapper<O> {
fn new(op: O) -> Self {
Self { operation: op }
}
}
impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
{
fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
self.operation.as_ref().execute(context.handles);
}
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
for CubeBackend<R, F, I, BT>
{
type Handle = CubeFusionHandle<R>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
tensor.into()
}
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
tensor.into()
}
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
tensor.into()
}
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
tensor.into()
}
}
impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
type OptimizationState = CubeOptimizationState;
type Optimization = CubeOptimization<R>;
type FusionHandle = CubeFusionHandle<R>;
type FusionDevice = R::CubeDevice;
type BoolRepr = BT;
fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
vec![
Box::new(ElementWiseFuser::new(
device.clone(),
BT::as_type_native_unchecked().into(),
)),
Box::new(MatmulFuser::new(
device.clone(),
BT::as_type_native_unchecked().into(),
)),
Box::new(ReduceFuser::new(
device.clone(),
BT::as_type_native_unchecked().into(),
ReduceSettings::Always,
)),
Box::new(ReduceBroadcastedFuser::new(
device.clone(),
BT::as_type_native_unchecked().into(),
)),
]
}
}
/// Fusion runtime for JIT runtimes.
#[derive(Debug)]
pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
_b: PhantomData<R>,
_bool: PhantomData<BT>,
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
for CubeBackend<R, F, I, BT>
{
type FusionRuntime = FusionCubeRuntime<R, BT>;
type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
kernel::cast(tensor, dtype).into()
}
}
fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
CubeTensor {
client: handle.client,
handle: handle.handle,
device: handle.device,
meta: Box::new(Metadata::new(shape, handle.strides)),
dtype: handle.dtype,
qparams: handle.qparams,
}
}
impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
fn from(value: CubeTensor<R>) -> Self {
Self {
client: value.client,
handle: value.handle,
device: value.device,
strides: value.meta.strides,
dtype: value.dtype,
qparams: value.qparams,
}
}
}

View File

@@ -0,0 +1,150 @@
use crate::{
CubeBackend, CubeRuntime, kernel::attention::attention_autotune,
ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{
DType, Shape,
ops::{AttentionModuleOptions, attention::attention_fallback},
};
use cubek::attention::definition::{
AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError,
};
use cubek::attention::launch;
#[derive(Debug)]
/// Strategy used to select which attention implementation to run.
pub enum AttentionStrategy {
/// Flash Attention using accelerated inner matmuls.
FlashBlackboxAccelerated,
/// Flash Attention using unit inner matmuls.
FlashUnit,
/// Fallback implementation using multiple separate kernels.
Fallback,
/// Automatically benchmark and select the best strategy at runtime.
#[cfg(feature = "autotune")]
Autotune,
}
impl Default for AttentionStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return AttentionStrategy::Autotune;
// if autotune is disabled, default to fallback to make sure it runs
#[cfg(not(feature = "autotune"))]
AttentionStrategy::Fallback
}
}
#[allow(clippy::too_many_arguments)]
/// Launch an attention kernel with given strategy
pub fn attention<R: CubeRuntime>(
query: CubeTensor<R>,
key: CubeTensor<R>,
value: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
attn_bias: Option<CubeTensor<R>>,
options: AttentionModuleOptions,
strategy: &AttentionStrategy,
out: Option<CubeTensor<R>>,
) -> Result<CubeTensor<R>, AttentionSetupError> {
let mut out = out.unwrap_or_else(|| init_attention_output(&query, &value));
match strategy {
AttentionStrategy::FlashBlackboxAccelerated => flash_attention(
query,
key,
value,
mask,
attn_bias,
options,
out,
launch::Strategy::BlackboxAccelerated(
cubek::attention::launch::BlueprintStrategy::Inferred(()),
),
),
AttentionStrategy::FlashUnit => flash_attention(
query,
key,
value,
mask,
attn_bias,
options,
out,
launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
),
AttentionStrategy::Fallback => {
out = attention_fallback::<CubeBackend<R, f32, i32, u8>>(
query, key, value, mask, attn_bias, options,
);
Ok(out)
}
#[cfg(feature = "autotune")]
AttentionStrategy::Autotune => {
attention_autotune(query, key, value, mask, attn_bias, options, out)
}
}
}
#[allow(clippy::too_many_arguments)]
/// Launch a flash attention kernel
pub fn flash_attention<R: CubeRuntime>(
query: CubeTensor<R>,
key: CubeTensor<R>,
value: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
_attn_bias: Option<CubeTensor<R>>,
options: AttentionModuleOptions,
out: CubeTensor<R>,
strategy: launch::Strategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
let client = &query.client;
let dtypes = AttentionGlobalTypes {
query: query.dtype.into(),
key: key.dtype.into(),
value: value.dtype.into(),
mask: mask.as_ref().map(|m| m.dtype).unwrap_or(DType::U8).into(),
out: out.dtype.into(),
};
cubek::attention::launch::launch_ref::<R>(
strategy,
client,
&query.as_handle_ref(),
&key.as_handle_ref(),
&value.as_handle_ref(),
&mask.as_ref().map(|mask| mask.as_handle_ref()),
&out.as_handle_ref(),
&dtypes,
AttentionOptions {
causal: options.is_causal,
accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar(
cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32),
)),
},
)?;
Ok(out)
}
pub(crate) fn init_attention_output<R: CubeRuntime>(
query: &CubeTensor<R>,
value: &CubeTensor<R>,
) -> CubeTensor<R> {
let num_batches = query.meta.shape[0];
let num_heads = query.meta.shape[1];
let seq_q = query.meta.shape[2];
let val_dim = value.meta.shape[3];
let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);
empty_device_dtype::<R>(
query.client.clone(),
query.device.clone(),
out_shape,
query.dtype,
)
}

View File

@@ -0,0 +1,5 @@
mod base;
mod tune;
pub use base::*;
pub use tune::*;

View File

@@ -0,0 +1,166 @@
use crate::{
CubeRuntime, CubeTuneId,
kernel::attention::{AttentionStrategy, attention},
tensor::CubeTensor,
};
use burn_backend::ops::AttentionModuleOptions;
use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};
use cubek::attention::{definition::AttentionSetupError, launch::AttentionAutotuneKey};
/// Executes autotune on attention operations
pub fn attention_autotune<R: CubeRuntime>(
query: CubeTensor<R>,
key: CubeTensor<R>,
value: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
attn_bias: Option<CubeTensor<R>>,
options: AttentionModuleOptions,
out: CubeTensor<R>,
) -> Result<CubeTensor<R>, AttentionSetupError> {
let client = query.client.clone();
static TUNER: LocalTuner<AttentionAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
const PRIORITY_MAX: i8 = 3;
const PRIORITY_MIN: i8 = 0;
let flash_attention =
TuneGroup::<AttentionAutotuneKey>::new("flash_attention", |_key| PRIORITY_MAX);
let fallback = TuneGroup::<AttentionAutotuneKey>::new("fallback", |_key| PRIORITY_MIN);
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);
// First entry should always work, since it is considered the fallback.
set = set.with(
Tunable::new(
"fallback",
|query, key, value, mask, attn_bias, out, options| {
attention::<R>(
query,
key,
value,
mask,
attn_bias,
options,
&AttentionStrategy::Fallback,
Some(out),
)
.map_err(|err| std::format!("{err:?}"))
},
)
.group(&fallback, |_key| PRIORITY_MAX),
);
set = set.with(
Tunable::new(
"blackbox_accelerated",
|query, key, value, mask, attn_bias, out, options| {
attention::<R>(
query,
key,
value,
mask,
attn_bias,
options,
&AttentionStrategy::FlashBlackboxAccelerated,
Some(out),
)
.map_err(|err| std::format!("{err:?}"))
},
)
.group(&flash_attention, |_key| PRIORITY_MAX),
);
set = set.with(
Tunable::new(
"unit",
|query, key, value, mask, attn_bias, out, options| {
attention::<R>(
query,
key,
value,
mask,
attn_bias,
options,
&AttentionStrategy::FlashUnit,
Some(out),
)
.map_err(|err| std::format!("{err:?}"))
},
)
.group(&flash_attention, |_key| PRIORITY_MIN),
);
set
});
TUNER.execute(
&CubeTuneId::new(&client, &query.device),
&client,
tunables,
(query, key, value, mask, attn_bias, out.clone(), options),
);
Ok(out)
}
fn create_key<R: CubeRuntime>(
query: &CubeTensor<R>,
key: &CubeTensor<R>,
value: &CubeTensor<R>,
mask: &Option<CubeTensor<R>>,
_attn_bias: &Option<CubeTensor<R>>,
out: &CubeTensor<R>,
_options: &AttentionModuleOptions,
) -> AttentionAutotuneKey {
let total_batches = query.meta.shape[0] * query.meta.shape[1];
let seq_q = query.meta.shape[2];
let head_dim = query.meta.shape[3];
let seq_kv = value.meta.shape[2];
let val_dim = value.meta.shape[3];
AttentionAutotuneKey::generate(
query.dtype.into(),
key.dtype.into(),
value.dtype.into(),
out.dtype.into(),
total_batches,
seq_q,
head_dim,
seq_kv,
val_dim,
mask.is_some(),
)
}
#[allow(clippy::type_complexity)]
#[allow(clippy::too_many_arguments)]
fn input_gen<R: CubeRuntime>(
_key: &AttentionAutotuneKey,
query: &CubeTensor<R>,
key: &CubeTensor<R>,
value: &CubeTensor<R>,
mask: &Option<CubeTensor<R>>,
attn_bias: &Option<CubeTensor<R>>,
out: &CubeTensor<R>,
options: &AttentionModuleOptions,
) -> (
CubeTensor<R>,
CubeTensor<R>,
CubeTensor<R>,
Option<CubeTensor<R>>,
Option<CubeTensor<R>>,
CubeTensor<R>,
AttentionModuleOptions,
) {
(
query.clone(),
key.clone(),
value.clone(),
mask.clone(),
attn_bias.clone(),
out.copy(),
*options,
)
}

View File

@@ -0,0 +1,302 @@
use crate::{
CubeRuntime,
kernel::utils::{
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::{TensorMetadata, bf16, f16};
use cubecl::{
calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView,
};
pub(crate) trait BinaryOpFamily: Send + Sync + 'static {
type BinaryOp<C: Numeric>: BinaryOp<C>;
}
#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
/// Execute a binary operation.
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}
pub(crate) struct AddOp;
pub(crate) struct SubOp;
pub(crate) struct MulOp;
pub(crate) struct DivOp;
pub(crate) struct RemainderOp;
pub(crate) struct AndOp;
pub(crate) struct OrOp;
pub(crate) struct PowOp;
impl BinaryOpFamily for AddOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for SubOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for MulOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for DivOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for RemainderOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for PowOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for AndOp {
type BinaryOp<C: Numeric> = Self;
}
impl BinaryOpFamily for OrOp {
type BinaryOp<C: Numeric> = Self;
}
#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs + rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for SubOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs - rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for MulOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs * rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for DivOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs / rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for RemainderOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::rem(lhs, rhs)
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for PowOp {
#[allow(unused)]
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
intrinsic!(|scope| {
let elem = N::as_type(scope).elem_type();
if let cubecl::ir::ElemType::Float(kind) = elem {
match kind {
cubecl::ir::FloatKind::F16 => {
let lhs = <Line<f16> as Cast>::__expand_cast_from(scope, lhs);
let rhs = <Line<f16> as Cast>::__expand_cast_from(scope, rhs);
let out = Line::__expand_powf(scope, lhs, rhs);
return <Line<N> as Cast>::__expand_cast_from(scope, out);
}
cubecl::ir::FloatKind::BF16 => {
let lhs = <Line<bf16> as Cast>::__expand_cast_from(scope, lhs);
let rhs = <Line<bf16> as Cast>::__expand_cast_from(scope, rhs);
let out = Line::__expand_powf(scope, lhs, rhs);
return <Line<N> as Cast>::__expand_cast_from(scope, out);
}
cubecl::ir::FloatKind::F64 => {
let lhs = <Line<f64> as Cast>::__expand_cast_from(scope, lhs);
let rhs = <Line<f64> as Cast>::__expand_cast_from(scope, rhs);
let out = Line::__expand_powf(scope, lhs, rhs);
return <Line<N> as Cast>::__expand_cast_from(scope, out);
}
_ => {}
}
};
let lhs = <Line<f32> as Cast>::__expand_cast_from(scope, lhs);
let rhs = <Line<f32> as Cast>::__expand_cast_from(scope, rhs);
let out = Line::__expand_powf(scope, lhs, rhs);
return <Line<N> as Cast>::__expand_cast_from(scope, out);
})
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for AndOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::cast_from(Line::<bool>::cast_from(lhs).and(Line::<bool>::cast_from(rhs)))
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for OrOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::cast_from(Line::<bool>::cast_from(lhs).or(Line::<bool>::cast_from(rhs)))
}
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
input: &LinearView<Line<C>>,
scalar: InputScalar,
output: &mut LinearView<Line<C>, ReadWrite>,
#[define(C)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] =
O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar.get::<C>()));
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
lhs: &LinearView<Line<C>>,
rhs: &LinearView<Line<C>>,
out: &mut LinearView<Line<C>, ReadWrite>,
#[define(C)] _dtype: StorageType,
) {
if !out.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
out[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
}
pub(crate) fn launch_binop<R: CubeRuntime, O: BinaryOpFamily>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
) -> CubeTensor<R> {
let line_size_lhs = max_line_size(&lhs);
let line_size_rhs = max_line_size(&rhs);
let line_size = Ord::min(line_size_lhs, line_size_rhs);
let shape_out = broadcast_shape(&[&lhs, &rhs]);
let dtype = lhs.dtype;
let client = lhs.client.clone();
let num_elems = shape_out.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&lhs.client, working_units);
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
unsafe {
if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view(&lhs, line_size),
linear_view_ref(&rhs, &lhs, line_size),
linear_view_alias(&lhs, line_size, 0),
dtype.into(),
)
.expect("Kernel to never fail");
lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view_ref(&lhs, &rhs, line_size),
linear_view(&rhs, line_size),
linear_view_alias(&rhs, line_size, 1),
dtype.into(),
)
.expect("Kernel to never fail");
rhs
} else {
let output =
empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype);
kernel_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs, output),
linear_view_ref(&lhs, &output, line_size),
linear_view_ref(&rhs, &output, line_size),
linear_view(&output, line_size),
dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}
pub(crate) fn launch_scalar_binop<R: CubeRuntime, O: BinaryOpFamily>(
tensor: CubeTensor<R>,
scalar: InputScalar,
) -> CubeTensor<R> {
// Vectorization is only enabled when the last dimension is contiguous.
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.num_elements();
let dtype = tensor.dtype;
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
if tensor.can_mut() && tensor.is_nonoverlapping() {
kernel_scalar_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor),
linear_view(&tensor, line_size),
scalar,
linear_view_alias(&tensor, line_size, 0),
dtype.into(),
)
.expect("Kernel to never fail");
tensor
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
dtype,
);
kernel_scalar_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
scalar,
linear_view(&output, line_size),
dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}

View File

@@ -0,0 +1,119 @@
use crate::{
CubeRuntime,
kernel::utils::{
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
pub(crate) trait BinaryOpFloatFamily: Send + Sync + 'static {
type BinaryOp<C: Float>: BinaryOpFloat<C>;
}
#[cube]
pub(crate) trait BinaryOpFloat<C: Float>: 'static + Send + Sync {
/// Execute a binary operation.
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}
pub(crate) struct ArcTan2Op;
impl BinaryOpFloatFamily for ArcTan2Op {
type BinaryOp<C: Float> = Self;
}
#[cube]
impl<N: Float> BinaryOpFloat<N> for ArcTan2Op {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::atan2(lhs, rhs)
}
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_binop<C: Float, O: BinaryOpFloatFamily>(
lhs: &LinearView<Line<C>>,
rhs: &LinearView<Line<C>>,
out: &mut LinearView<Line<C>, ReadWrite>,
#[define(C)] _dtype: StorageType,
) {
if !out.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
out[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
}
pub(crate) fn launch_binop_float<R: CubeRuntime, O: BinaryOpFloatFamily>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
) -> CubeTensor<R> {
let line_size_lhs = max_line_size(&lhs);
let line_size_rhs = max_line_size(&rhs);
let line_size = Ord::min(line_size_lhs, line_size_rhs);
let shape_out = broadcast_shape(&[&lhs, &rhs]);
let dtype = lhs.dtype;
let client = lhs.client.clone();
let num_elems = shape_out.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&lhs.client, working_units);
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
unsafe {
if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view(&lhs, line_size),
linear_view_ref(&rhs, &lhs, line_size),
linear_view_alias(&lhs, line_size, 0),
dtype.into(),
)
.expect("Kernel to never fail");
lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view_ref(&lhs, &rhs, line_size),
linear_view(&rhs, line_size),
linear_view_alias(&rhs, line_size, 1),
dtype.into(),
)
.expect("Kernel to never fail");
rhs
} else {
let output =
empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype);
kernel_binop::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs, output),
linear_view_ref(&lhs, &output, line_size),
linear_view_ref(&rhs, &output, line_size),
linear_view(&output, line_size),
dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}
/// Calculate the four-quadrant inverse tangent of `lhs / rhs`.
pub fn atan2<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
launch_binop_float::<R, ArcTan2Op>(lhs, rhs)
}

View File

@@ -0,0 +1,229 @@
use crate::{
CubeRuntime,
kernel::utils::{
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static {
type BinaryOp<C: Int>: BinaryOpInt<C>;
}
#[cube]
pub(crate) trait BinaryOpInt<C: Int>: 'static + Send + Sync {
/// Execute a binary operation.
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}
pub(crate) struct BitwiseAndOp;
pub(crate) struct BitwiseOrOp;
pub(crate) struct BitwiseXorOp;
pub(crate) struct BitwiseShrOp;
pub(crate) struct BitwiseShlOp;
impl BinaryOpIntFamily for BitwiseAndOp {
type BinaryOp<C: Int> = Self;
}
impl BinaryOpIntFamily for BitwiseOrOp {
type BinaryOp<C: Int> = Self;
}
impl BinaryOpIntFamily for BitwiseXorOp {
type BinaryOp<C: Int> = Self;
}
impl BinaryOpIntFamily for BitwiseShrOp {
type BinaryOp<C: Int> = Self;
}
impl BinaryOpIntFamily for BitwiseShlOp {
type BinaryOp<C: Int> = Self;
}
#[cube]
impl<N: Int> BinaryOpInt<N> for BitwiseAndOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs & rhs
}
}
#[cube]
impl<N: Int> BinaryOpInt<N> for BitwiseOrOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs | rhs
}
}
#[cube]
impl<N: Int> BinaryOpInt<N> for BitwiseXorOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs ^ rhs
}
}
#[cube]
impl<N: Int> BinaryOpInt<N> for BitwiseShrOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs >> rhs
}
}
#[cube]
impl<N: Int> BinaryOpInt<N> for BitwiseShlOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs << rhs
}
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_scalar_binop_int<C: Int, O: BinaryOpIntFamily>(
input: &LinearView<Line<C>>,
scalar: InputScalar,
output: &mut LinearView<Line<C>, ReadWrite>,
#[define(C)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] =
O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar.get::<C>()));
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_binop_int<C: Int, O: BinaryOpIntFamily>(
lhs: &LinearView<Line<C>>,
rhs: &LinearView<Line<C>>,
out: &mut LinearView<Line<C>, ReadWrite>,
#[define(C)] _dtype: StorageType,
) {
if !out.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
out[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
}
pub(crate) fn launch_binop_int<R: CubeRuntime, O: BinaryOpIntFamily>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
) -> CubeTensor<R> {
let line_size_lhs = max_line_size(&lhs);
let line_size_rhs = max_line_size(&rhs);
let line_size = Ord::min(line_size_lhs, line_size_rhs);
let shape_out = broadcast_shape(&[&lhs, &rhs]);
let client = lhs.client.clone();
let num_elems = shape_out.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&lhs.client, working_units);
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
unsafe {
if lhs.can_mut_broadcast(&rhs) {
kernel_binop_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view(&lhs, line_size),
linear_view_ref(&rhs, &lhs, line_size),
linear_view_alias(&lhs, line_size, 0),
lhs.dtype.into(),
)
.expect("Kernel to never fail");
lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view_ref(&lhs, &rhs, line_size),
linear_view(&rhs, line_size),
linear_view_alias(&rhs, line_size, 1),
lhs.dtype.into(),
)
.expect("Kernel to never fail");
rhs
} else {
let output =
empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, lhs.dtype);
kernel_binop_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs, output),
linear_view_ref(&lhs, &output, line_size),
linear_view_ref(&rhs, &output, line_size),
linear_view(&output, line_size),
lhs.dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}
pub(crate) fn launch_scalar_binop_int<R: CubeRuntime, O: BinaryOpIntFamily>(
tensor: CubeTensor<R>,
scalar: InputScalar,
) -> CubeTensor<R> {
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.shape.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
if tensor.can_mut() && tensor.is_nonoverlapping() {
kernel_scalar_binop_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor),
linear_view(&tensor, line_size),
scalar,
linear_view_alias(&tensor, line_size, 0),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
tensor
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
tensor.dtype,
);
kernel_scalar_binop_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
scalar,
linear_view(&output, line_size),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}

View File

@@ -0,0 +1,70 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
#[cube(launch, address_type = "dynamic")]
pub(crate) fn cast_element<I: Numeric, O: Numeric>(
input: &LinearView<Line<I>>,
output: &mut LinearView<Line<O>, ReadWrite>,
#[define(I, O)] _dtypes: [StorageType; 2],
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = Line::cast_from(input[ABSOLUTE_POS]);
}
/// Cast a tensor to the given element type.
///
/// Note: When input element is semantically a boolean, prefer bool_cast function.
pub fn cast<R: CubeRuntime>(input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {
let dtype_output = match dtype {
DType::Flex32 => DType::F32,
_ => dtype,
};
let dtype_input = match input.dtype {
DType::Flex32 => DType::F32,
_ => input.dtype,
};
if dtype_input == dtype_output {
return input;
}
let client = input.client.clone();
let line_size = max_line_size(&input);
let num_elems: usize = input.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&client, working_units);
let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
let output = empty_device_dtype(
client.clone(),
input.device.clone(),
input.shape(),
dtype, // We take the same dtype as passed as input (Flex32 not F32)
);
cast_element::launch(
&client,
cube_count,
cube_dim,
address_type!(input, output),
linear_view(&input, line_size),
linear_view(&output, line_size),
[dtype_input.into(), dtype_output.into()],
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,55 @@
use crate::{
CubeElement, CubeRuntime,
kernel::utils::{address_type, linear_view},
ops::{max_line_size, numeric::empty_device},
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::{
CubeDim, calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn bool_cast_kernel<B: Int, T: Numeric>(
input: &LinearView<Line<B>>,
output: &mut LinearView<Line<T>, ReadWrite>,
#[define(B)] _input_ty: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = Line::cast_from(input[ABSOLUTE_POS] & Line::cast_from(1u32));
}
/// Cast a bool tensor to the given element type.
///
/// This alternative to cast is necessary because bool are represented as u32 or u8
/// where any non-zero value means true. Depending how it was created
/// it may hold an uncanny bit combination. Naively casting it would not
/// necessarily yield 0 or 1.
pub fn bool_cast<R: CubeRuntime, EO: CubeElement>(tensor: CubeTensor<R>) -> CubeTensor<R> {
let output =
empty_device::<R, EO>(tensor.client.clone(), tensor.device.clone(), tensor.shape());
let line_size = max_line_size(&tensor);
let num_elems = tensor.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
bool_cast_kernel::launch_unchecked::<EO, R>(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
linear_view(&output, line_size),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
}
output
}

View File

@@ -0,0 +1,5 @@
mod base;
mod bool_cast;
pub use base::*;
pub use bool_cast::*;

View File

@@ -0,0 +1,42 @@
use cubecl::prelude::*;
use crate::{
CubeRuntime,
kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric},
tensor::CubeTensor,
};
#[derive(CubeLaunch, CubeType)]
struct Options {
min_value: InputScalar,
max_value: InputScalar,
}
pub(crate) fn clamp<R: CubeRuntime>(
input: CubeTensor<R>,
min_value: InputScalar,
max_value: InputScalar,
) -> CubeTensor<R> {
struct ClampOp;
#[cube]
impl<N: Numeric> NumericUnaryOp<N> for ClampOp {
type Options = Options;
fn execute(input: Line<N>, options: &Self::Options) -> Line<N> {
let line_size = input.size();
cubecl::prelude::clamp(
input,
Line::empty(line_size).fill(options.min_value.get::<N>()),
Line::empty(line_size).fill(options.max_value.get::<N>()),
)
}
}
impl NumericUnaryOpFamily for ClampOp {
type Options = Options;
type Unary<N: Numeric> = Self;
}
launch_unary_numeric::<R, ClampOp, _>(input, |_| OptionsLaunch::new(min_value, max_value))
}

View File

@@ -0,0 +1,432 @@
use crate::{
CubeRuntime,
kernel::utils::{
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
#[cube]
pub(crate) trait ComparisonOpFamily: 'static + Send + Sync {
type Operation<N: Numeric>: ComparisonOp<N>;
}
#[cube]
pub(crate) trait ComparisonOp<C: Numeric>: 'static + Send + Sync {
/// Execute a comparison operation.
fn execute(lhs: Line<C>, rhs: Line<C>) -> bool;
}
struct EqualOp;
struct GreaterEqualOp;
struct LowerEqualOp;
struct GreaterOp;
struct LowerOp;
impl ComparisonOpFamily for EqualOp {
type Operation<N: Numeric> = Self;
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for EqualOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
lhs == rhs
}
}
impl ComparisonOpFamily for GreaterEqualOp {
type Operation<N: Numeric> = Self;
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for GreaterEqualOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
lhs >= rhs
}
}
impl ComparisonOpFamily for LowerEqualOp {
type Operation<N: Numeric> = Self;
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for LowerEqualOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
lhs <= rhs
}
}
impl ComparisonOpFamily for GreaterOp {
type Operation<N: Numeric> = Self;
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for GreaterOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
lhs > rhs
}
}
impl ComparisonOpFamily for LowerOp {
type Operation<N: Numeric> = Self;
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for LowerOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
lhs < rhs
}
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_scalar_cmp<N: Numeric, Bool: Numeric, O: ComparisonOpFamily>(
input: &LinearView<Line<N>>,
scalar: InputScalar,
output: &mut LinearView<Line<Bool>, ReadWrite>,
#[define(N, Bool)] _dtypes: [StorageType; 2],
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = Line::cast_from(O::Operation::<N>::execute(
input[ABSOLUTE_POS],
Line::new(scalar.get::<N>()),
));
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_cmp<N: Numeric, Bool: Numeric, O: ComparisonOpFamily>(
lhs: &LinearView<Line<N>>,
rhs: &LinearView<Line<N>>,
out: &mut LinearView<Line<Bool>, ReadWrite>,
#[define(N, Bool)] _dtype: [StorageType; 2],
) {
if !out.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
out[ABSOLUTE_POS] = Line::cast_from(O::Operation::<N>::execute(
lhs[ABSOLUTE_POS],
rhs[ABSOLUTE_POS],
));
}
pub(crate) fn launch_cmp<R: CubeRuntime, O: ComparisonOpFamily>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
let line_size_lhs = max_line_size(&lhs);
let line_size_rhs = max_line_size(&rhs);
let line_size = Ord::min(line_size_lhs, line_size_rhs);
let shape_out = broadcast_shape(&[&lhs, &rhs]);
let client = lhs.client.clone();
let num_elems = shape_out.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&lhs.client, working_units);
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
let dtypes = [lhs.dtype.into(), dtype_bool.into()];
let same_tensor_type = dtypes[0] == dtypes[1];
if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
unsafe {
kernel_cmp::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view(&lhs, line_size),
linear_view_ref(&rhs, &lhs, line_size),
linear_view_alias(&lhs, line_size, 0),
dtypes,
)
.expect("Kernel to never fail");
}
CubeTensor::new(lhs.client, lhs.handle, *lhs.meta, lhs.device, dtype_bool)
} else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {
unsafe {
kernel_cmp::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs),
linear_view_ref(&lhs, &rhs, line_size),
linear_view(&rhs, line_size),
linear_view_alias(&rhs, line_size, 1),
dtypes,
)
.expect("Kernel to never fail");
};
CubeTensor::new(rhs.client, rhs.handle, *rhs.meta, rhs.device, dtype_bool)
} else {
let output = empty_device_dtype(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
dtype_bool,
);
unsafe {
kernel_cmp::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(lhs, rhs, output),
linear_view_ref(&lhs, &output, line_size),
linear_view_ref(&rhs, &output, line_size),
linear_view(&output, line_size),
dtypes,
)
.expect("Kernel to never fail");
};
output
}
}
pub(crate) fn launch_scalar_cmp<R: CubeRuntime, O: ComparisonOpFamily>(
tensor: CubeTensor<R>,
scalar: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
let dtypes = [tensor.dtype.into(), dtype_bool.into()];
let same_tensor_type = dtypes[0] == dtypes[1];
if same_tensor_type && tensor.can_mut() && tensor.is_nonoverlapping() {
unsafe {
kernel_scalar_cmp::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor),
linear_view(&tensor, line_size),
scalar,
linear_view_alias(&tensor, line_size, 0),
dtypes,
)
.expect("Kernel to never fail");
}
CubeTensor::new(
tensor.client,
tensor.handle,
*tensor.meta,
tensor.device,
dtype_bool,
)
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
dtype_bool,
);
unsafe {
kernel_scalar_cmp::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
scalar,
linear_view(&output, line_size),
dtypes,
)
.expect("Kernel to never fail");
}
output
}
}
pub fn equal<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_cmp::<R, EqualOp>(lhs, rhs, dtype_bool)
}
pub fn greater<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_cmp::<R, GreaterOp>(lhs, rhs, dtype_bool)
}
pub fn greater_equal<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_cmp::<R, GreaterEqualOp>(lhs, rhs, dtype_bool)
}
pub fn lower<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_cmp::<R, LowerOp>(lhs, rhs, dtype_bool)
}
pub fn lower_equal<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_cmp::<R, LowerEqualOp>(lhs, rhs, dtype_bool)
}
pub fn equal_elem<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_scalar_cmp::<R, EqualOp>(lhs, rhs, dtype_bool)
}
pub fn greater_elem<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_scalar_cmp::<R, GreaterOp>(lhs, rhs, dtype_bool)
}
pub fn lower_elem<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_scalar_cmp::<R, LowerOp>(lhs, rhs, dtype_bool)
}
pub fn greater_equal_elem<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_scalar_cmp::<R, GreaterEqualOp>(lhs, rhs, dtype_bool)
}
pub fn lower_equal_elem<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
launch_scalar_cmp::<R, LowerEqualOp>(lhs, rhs, dtype_bool)
}
// Unary comparison / predicate / relational ops
#[cube]
pub(crate) trait PredicateOp<F: Float>: 'static + Send + Sync {
/// Execute a predicate operation.
fn execute(input: Line<F>) -> bool;
}
pub(crate) trait PredicateOpFamily: 'static + Send + Sync {
type Operation<F: Float>: PredicateOp<F>;
}
struct IsNanOp;
struct IsInfOp;
impl PredicateOpFamily for IsNanOp {
type Operation<F: Float> = Self;
}
#[cube]
impl<F: Float> PredicateOp<F> for IsNanOp {
fn execute(input: Line<F>) -> bool {
Line::is_nan(input)
}
}
impl PredicateOpFamily for IsInfOp {
type Operation<F: Float> = Self;
}
#[cube]
impl<F: Float> PredicateOp<F> for IsInfOp {
fn execute(input: Line<F>) -> bool {
Line::is_inf(input)
}
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn kernel_predicate<F: Float, Bool: Numeric, O: PredicateOpFamily>(
input: &LinearView<Line<F>>,
output: &mut LinearView<Line<Bool>, ReadWrite>,
#[define(F, Bool)] _dtypes: [StorageType; 2],
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = Line::cast_from(O::Operation::<F>::execute(input[ABSOLUTE_POS]));
}
pub(crate) fn launch_predicate<R: CubeRuntime, O: PredicateOpFamily>(
tensor: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.num_elements();
let dtypes = [tensor.dtype.into(), dtype_bool.into()];
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
dtype_bool,
);
unsafe {
kernel_predicate::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view_ref(&tensor, &output, line_size),
linear_view(&output, line_size),
dtypes,
)
.expect("Kernel to never fail");
}
output
}
pub fn is_nan<R: CubeRuntime>(tensor: CubeTensor<R>, dtype_bool: DType) -> CubeTensor<R> {
launch_predicate::<R, IsNanOp>(tensor, dtype_bool)
}
pub fn is_inf<R: CubeRuntime>(tensor: CubeTensor<R>, dtype_bool: DType) -> CubeTensor<R> {
launch_predicate::<R, IsInfOp>(tensor, dtype_bool)
}

View File

@@ -0,0 +1,124 @@
use burn_backend::{DType, QTensorPrimitive, TensorMetadata};
use cubecl::quant::scheme::{QuantStore, QuantValue};
use cubecl::server::AllocationKind;
use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};
/// Make a jit tensor contiguous.
pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
if tensor.is_contiguous() {
return tensor;
}
if tensor.qparams.is_some() {
return into_contiguous_quantized(tensor, AllocationKind::Contiguous);
}
let output = cubecl::std::tensor::into_contiguous_ref(
&tensor.client,
&tensor.as_handle_ref(),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
CubeTensor::new(
tensor.client,
output.handle,
*output.metadata,
tensor.device,
tensor.dtype,
)
}
/// Make a jit tensor contiguous with an aligned last stride. Tensor is considered already contiguous
/// if runtime can read it as is. This is equivalent in practice.
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensor))
)]
pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) {
return tensor;
}
if tensor.qparams.is_some() {
return into_contiguous_quantized(tensor, AllocationKind::Optimized);
}
let output = cubecl::std::tensor::into_contiguous_pitched_ref(
&tensor.client,
&tensor.as_handle_ref(),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
CubeTensor::new(
tensor.client,
output.handle,
*output.metadata,
tensor.device,
tensor.dtype,
)
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensor))
)]
fn into_contiguous_quantized<R: CubeRuntime>(
tensor: CubeTensor<R>,
kind: AllocationKind,
) -> CubeTensor<R> {
let scheme = tensor.scheme();
let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, kind);
let (values, scales) = tensor.quantized_handles().unwrap();
let (out_values, out_scales) = output.quantized_handles().unwrap();
match scheme.store {
QuantStore::PackedU32(packed_dim) => {
cubecl::std::tensor::into_contiguous_packed_ref(
&values.client,
&values.as_handle_ref(),
&out_values.as_handle_ref(),
packed_dim,
tensor.meta.shape(),
scheme.num_quants(),
DType::U32.into(),
)
.expect("Kernel to never fail");
}
// e2m1 is special because it has a native packed representation, `e2m1x2`.
// It's internally stored as `u8` with a packing factor of 2.
QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
cubecl::std::tensor::into_contiguous_packed_ref(
&values.client,
&values.as_handle_ref(),
&out_values.as_handle_ref(),
packed_dim,
tensor.meta.shape(),
scheme.num_quants(),
DType::U8.into(),
)
.expect("Kernel to never fail");
}
_ => {
cubecl::std::tensor::copy_into(
&values.client,
&values.as_handle_ref(),
&out_values.as_handle_ref(),
values.dtype.into(),
)
.expect("Kernel to never fail");
}
}
cubecl::std::tensor::copy_into(
&scales.client,
&scales.as_handle_ref(),
&out_scales.as_handle_ref(),
scales.dtype.into(),
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,142 @@
use burn_backend::{
TensorMetadata,
ops::{ConvOptions, ConvTransposeOptions},
};
use burn_std::Shape;
use cubek::convolution::components::ConvSetupError;
use crate::{
CubeRuntime,
kernel::conv::{conv_transpose2d, conv_transpose3d},
ops::{permute_nchw_to_nhwc, permute_nhwc_to_nchw, reshape},
tensor::CubeTensor,
};
pub(crate) fn conv_data_backward_fallback<R: CubeRuntime, const N_DIM: usize>(
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
in_shape: Shape,
options: ConvOptions<N_DIM>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let dim_c = out_grad.rank();
let kernel_size = &weights.meta.shape()[1..dim_c];
let in_shape = &in_shape[1..dim_c];
let out_shape = &out_grad.meta.shape()[1..dim_c];
let mut padding_out = [0; N_DIM];
for i in 0..N_DIM {
padding_out[i] = calculate_padding_out(
kernel_size[i],
options.stride[i],
options.padding[i],
options.dilation[i],
in_shape[i],
out_shape[i],
);
}
// We don't yet have NHWC kernels for conv_transpose so need to do this.
// Should eventually use NHWC kernels instead
let out_grad = permute_nhwc_to_nchw(out_grad);
let weights = permute_nhwc_to_nchw(weights);
let in_grad = match N_DIM {
1 => conv_transpose1d_from_conv_transpose2d(
out_grad,
weights,
ConvTransposeOptions::new(
[options.stride[0]],
[options.padding[0]],
[padding_out[0]],
[options.dilation[0]],
options.groups,
),
),
2 => conv_transpose2d(
out_grad,
weights,
None,
ConvTransposeOptions::new(
[options.stride[0], options.stride[1]],
[options.padding[0], options.padding[1]],
[padding_out[0], padding_out[1]],
[options.dilation[0], options.dilation[1]],
options.groups,
),
Default::default(),
),
3 => Ok(conv_transpose3d(
out_grad,
weights,
None,
ConvTransposeOptions::new(
[options.stride[0], options.stride[1], options.stride[2]],
[options.padding[0], options.padding[1], options.padding[2]],
[padding_out[0], padding_out[1], padding_out[2]],
[
options.dilation[0],
options.dilation[1],
options.dilation[2],
],
options.groups,
),
)
.unwrap()),
_ => unimplemented!("Invalid dimensionality"),
}?;
Ok(permute_nchw_to_nhwc(in_grad))
}
fn calculate_padding_out(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
size_in: usize,
size_out: usize,
) -> usize {
if stride <= 1 {
return 0;
}
let out = 1
+ ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64).ceil()
as usize;
i64::max(0, out as i64 - size_out as i64) as usize
}
fn conv_transpose1d_from_conv_transpose2d<R: CubeRuntime>(
x: CubeTensor<R>,
weight: CubeTensor<R>,
options: ConvTransposeOptions<1>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let [channels_in, channels_out, kernel_size] = weight.shape().dims();
let [batch_size, _channels_in, length_in] = x.shape().dims();
let weight = reshape(
weight,
Shape::new([channels_in, channels_out, kernel_size, 1]),
);
let x = reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let tensor = conv_transpose2d(
x,
weight,
None,
ConvTransposeOptions::new(
[options.stride[0], 1],
[options.padding[0], 0],
[options.padding_out[0], 0],
[options.dilation[0], 1],
options.groups,
),
Default::default(),
)?;
let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims();
Ok(reshape(
tensor,
Shape::from([batch_size, channels_out, height_out]),
))
}

View File

@@ -0,0 +1,132 @@
use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::{
convolution::{
AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_data,
components::ConvSetupError,
},
matmul::{
definition::{MatmulElems, MatmulGlobalElems},
launch::MatmulInputHandleRef,
},
};
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
pub fn dgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
input_shape: Shape,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
let read_strategy = match tile_kind {
AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,
AcceleratedTileKind::Mma => ReadingStrategy::Strided,
};
launch_backwards_data::<R, N>(
&Strategy::Simple {
read_strategy,
tile_kind,
},
out_grad,
weights,
input_shape,
options,
)
}
pub fn dgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
input_shape: Shape,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
let read_strategy = match tile_kind {
AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,
AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,
};
launch_backwards_data::<R, N>(
&Strategy::Simple {
read_strategy,
tile_kind,
},
out_grad,
weights,
input_shape,
options,
)
}
pub fn dgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
input_shape: Shape,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
launch_backwards_data::<R, N>(
&Strategy::Simple {
read_strategy: ReadingStrategy::Tma,
tile_kind,
},
out_grad,
weights,
input_shape,
options,
)
}
/// Perform a convolution backwards data pass using the implicit GEMM (im2col) algorithm, using
/// cubecl tiling matmul components.
///
/// * `input` - The input feature map
/// * `out_grad` - The output gradients
/// * `weight_shape` - The shape of the weights/weight gradients
/// * `options` - The options to use for the convolution
pub fn launch_backwards_data<R: CubeRuntime, const N: usize>(
strategy: &Strategy,
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
input_shape: Shape,
options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
return Err(ConvSetupError::Groups(options.groups));
}
let out_dtype = out_grad.dtype;
let in_grad = empty_device_dtype(
out_grad.client.clone(),
out_grad.device.clone(),
input_shape,
out_dtype,
);
let client = out_grad.client.clone();
let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
lhs: out_grad.dtype.into(),
rhs: weights.dtype.into(),
out: out_dtype.into(),
});
let out_grad = MatmulInputHandleRef::new(out_grad.as_handle_ref(), out_grad.dtype.into());
let weights = MatmulInputHandleRef::new(weights.as_handle_ref(), weights.dtype.into());
backward_data::launch_ref::<R, N>(
strategy,
&client,
&out_grad,
&weights,
&in_grad.as_handle_ref(),
ConvolutionArgs {
stride: options.stride,
padding: options.padding,
dilation: options.dilation,
},
dtypes,
)?;
Ok(in_grad)
}

View File

@@ -0,0 +1,2 @@
pub mod launch;
pub use launch::*;

View File

@@ -0,0 +1,8 @@
pub mod fallback;
pub mod implicit_gemm;
#[cfg(feature = "autotune")]
pub mod tune;
#[cfg(feature = "autotune")]
pub(crate) use tune::*;

View File

@@ -0,0 +1,172 @@
use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubecl::{
ir::StorageType,
tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},
};
use cubek::convolution::AcceleratedTileKind;
use crate::{
CubeAutotuneKey, CubeRuntime, CubeTuneId,
kernel::conv::{
ConvAutotuneKey,
backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
},
tensor::CubeTensor,
};
/// Executes autotune on conv2d operations
pub fn dgrad_autotune<R: CubeRuntime, const N: usize>(
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
input_shape: Shape,
options: ConvOptions<N>,
) -> CubeTensor<R> {
let client = out_grad.client.clone();
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
// Note: TMA isn't currently implemented properly, and will always error.
// It's kept here so it gets automatically enabled as soon as cubek updates.
// No CMMA for TMA because swizzling will be mandatory for good performance on dgrad.
let tunables = TUNER.init(|| {
TunableSet::new(create_key::<R, N>, create_wgrad_input::<R, N>)
.with(Tunable::new(
"wgrad_fallback",
conv_data_backward_fallback::<R, N>,
))
.with(Tunable::new(
"simple_sync_cmma",
|input, grad, shape, options| {
dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_sync_mma",
|input, grad, shape, options| {
dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma)
},
))
.with(Tunable::new(
"simple_async_cmma",
|input, grad, shape, options| {
dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_async_mma",
|input, grad, shape, options| {
dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma)
},
))
.with(Tunable::new(
"simple_tma_mma",
|input, grad, shape, options| {
dgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma)
},
))
});
TUNER.execute(
&CubeTuneId::new(&out_grad.client, &out_grad.device),
&client,
tunables,
(out_grad, weights, input_shape, options),
)
}
pub fn create_wgrad_input<R: CubeRuntime, const N: usize>(
_key: &CubeAutotuneKey,
out_grad: &CubeTensor<R>,
weights: &CubeTensor<R>,
input_shape: &Shape,
options: &ConvOptions<N>,
) -> (CubeTensor<R>, CubeTensor<R>, Shape, ConvOptions<N>) {
(
out_grad.clone(),
weights.clone(),
input_shape.clone(),
options.clone(),
)
}
fn create_key<R: CubeRuntime, const N: usize>(
out_grad: &CubeTensor<R>,
weights: &CubeTensor<R>,
input_shape: &Shape,
options: &ConvOptions<N>,
) -> CubeAutotuneKey {
let dtype = out_grad.dtype;
let rank = out_grad.meta.num_dims();
let dim_c = rank - 1;
let batch_size = out_grad.meta.shape()[0];
let in_channels = input_shape[dim_c];
let out_channels = out_grad.meta.shape()[dim_c];
let kernel_size = weights.meta.shape()[1..dim_c].to_vec();
let in_shape = input_shape[1..dim_c]
.iter()
.map(|shape| anchor(*shape, None, None, None))
.collect();
let ConvOptions {
stride,
padding,
dilation,
groups,
} = options.clone();
let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 {
stride_align(out_grad.meta.strides(), out_grad.dtype.into())
} else {
0
};
let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align);
let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 {
stride_align(weights.meta.strides(), weights.dtype.into())
} else {
0
};
let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);
CubeAutotuneKey::Conv(ConvAutotuneKey::new(
kernel_size,
stride.to_vec(),
padding.to_vec(),
dilation.to_vec(),
groups,
in_channels,
out_channels,
in_shape,
batch_size,
false,
dtype,
lhs_shape_align,
lhs_stride_align,
rhs_shape_align,
rhs_stride_align,
))
}
/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's
/// repeat number, so it's the largest align that can have performance impacts.
const MAX_STRIDE_FACTOR: u32 = 10;
/// Defines the non-contiguous stride alignment in terms of powers of two
fn stride_align(strides: &[usize], elem: StorageType) -> u8 {
let max = MAX_STRIDE_FACTOR;
let dim_c = strides.len() - 1;
let factor = strides[..dim_c]
.iter()
.map(|it| (*it * elem.size_bits()) / 8)
.map(|it| it.trailing_zeros())
.min()
.unwrap_or(max);
factor.min(max) as u8
}
/// Defines the potential vectorization.
fn pow2_factor(axis: usize) -> u8 {
axis.trailing_zeros().min(4) as u8
}

View File

@@ -0,0 +1,112 @@
use burn_backend::{TensorMetadata, ops::ConvOptions};
use burn_std::{Shape, Slice};
use cubek::convolution::components::ConvSetupError;
use crate::{
CubeRuntime,
kernel::{conv::base::conv_forward_nhwc, slice, slice_assign},
ops::{numeric::empty_device_dtype, swap_dims},
tensor::CubeTensor,
};
/// Calculate the convolution backward pass with regard to the weight gradients.
pub fn conv_weight_backward_fallback<R: CubeRuntime, const N_DIM: usize>(
input: CubeTensor<R>,
output_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N_DIM>,
) -> Result<CubeTensor<R>, ConvSetupError> {
match options.groups == 1 {
true => conv_weight_grad_no_groups::<R, N_DIM>(input, output_grad, weight_shape, options),
false => conv_weight_grad_groups::<R, N_DIM>(input, output_grad, weight_shape, options),
}
}
fn conv_weight_grad_no_groups<R: CubeRuntime, const N_DIM: usize>(
input: CubeTensor<R>,
output_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N_DIM>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let dim_c = input.rank() - 1;
let input_swapped = swap_dims(input, 0, dim_c);
let out_grad_swapped = swap_dims(output_grad, 0, dim_c);
let weight_grad_swapped = conv_forward_nhwc(
input_swapped,
out_grad_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
Default::default(),
)?;
let mut weight_grad = swap_dims(weight_grad_swapped, 0, dim_c);
if weight_grad.shape() != weight_shape {
let ranges = weight_shape.iter().map(|&s| 0..s).collect::<Vec<_>>();
weight_grad = slice(weight_grad, &ranges);
}
Ok(weight_grad)
}
#[allow(clippy::single_range_in_vec_init, reason = "False positive")]
fn conv_weight_grad_groups<R: CubeRuntime, const N_DIM: usize>(
input: CubeTensor<R>,
output_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N_DIM>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let mut weight_grad = empty_device_dtype(
input.client.clone(),
input.device.clone(),
weight_shape.clone(),
input.dtype,
);
let dim_c = input.rank() - 1;
let channels_out = weight_shape[0];
let increment_co = channels_out / options.groups;
let input_swapped = swap_dims(input, 0, dim_c);
let output_grad_swapped = swap_dims(output_grad, 0, dim_c);
let kernel_size = &weight_shape[1..dim_c];
let kernel_size_slice = kernel_size.iter().map(|&s| 0..s).collect::<Vec<_>>();
let increment_ci = weight_grad.meta.shape()[dim_c];
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
let end_idx_ci = (g + 1) * increment_ci;
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let input = slice(input_swapped.clone(), &[start_idx_ci..end_idx_ci]);
let grad = slice(output_grad_swapped.clone(), &[start_idx_co..end_idx_co]);
let weight_grad_tmp = conv_forward_nhwc(
input,
grad,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
Default::default(),
)?;
let mut weight_grad_tmp = swap_dims(weight_grad_tmp, 0, dim_c);
let kernel_size_tmp = &weight_grad_tmp.meta.shape()[1..dim_c];
if kernel_size != kernel_size_tmp {
let mut slices = vec![0..increment_co];
slices.extend(kernel_size_slice.clone());
slices.push(0..increment_ci);
weight_grad_tmp = slice(weight_grad_tmp, &slices);
}
let mut slices = vec![start_idx_co..end_idx_co];
slices.extend(kernel_size_slice.clone());
slices.push(0..increment_ci);
let slices = slices.into_iter().map(Slice::from).collect::<Vec<_>>();
weight_grad = slice_assign(weight_grad, &slices, weight_grad_tmp);
}
Ok(weight_grad)
}

View File

@@ -0,0 +1,132 @@
use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::{
convolution::{
AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_weight,
components::ConvSetupError,
},
matmul::{
definition::{MatmulElems, MatmulGlobalElems},
launch::MatmulInputHandleRef,
},
};
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
pub(crate) fn wgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
let read_strategy = match tile_kind {
AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,
AcceleratedTileKind::Mma => ReadingStrategy::Strided,
};
launch_backwards_weight::<R, N>(
&Strategy::Simple {
read_strategy,
tile_kind,
},
input,
out_grad,
weight_shape,
options,
)
}
pub(crate) fn wgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
let read_strategy = match tile_kind {
AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,
AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,
};
launch_backwards_weight::<R, N>(
&Strategy::Simple {
read_strategy,
tile_kind,
},
input,
out_grad,
weight_shape,
options,
)
}
pub(crate) fn wgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
launch_backwards_weight::<R, N>(
&Strategy::Simple {
read_strategy: ReadingStrategy::Tma,
tile_kind,
},
input,
out_grad,
weight_shape,
options,
)
}
/// Perform a convolution backwards weight pass using the implicit GEMM (im2col) algorithm, using
/// cubecl tiling matmul components.
///
/// * `input` - The input feature map
/// * `out_grad` - The output gradients
/// * `weight_shape` - The shape of the weights/weight gradients
/// * `options` - The options to use for the convolution
pub fn launch_backwards_weight<R: CubeRuntime, const N: usize>(
strategy: &Strategy,
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
if options.groups != 1 {
return Err(ConvSetupError::Groups(options.groups));
}
let out_dtype = out_grad.dtype;
let weight_grad = empty_device_dtype(
input.client.clone(),
input.device.clone(),
weight_shape,
out_dtype,
);
let client = input.client.clone();
let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
lhs: input.dtype.into(),
rhs: out_grad.dtype.into(),
out: out_dtype.into(),
});
let input = MatmulInputHandleRef::new(input.as_handle_ref(), input.dtype.into());
let out_grad = MatmulInputHandleRef::new(out_grad.as_handle_ref(), out_grad.dtype.into());
backward_weight::launch_ref::<R, N>(
strategy,
&client,
&input,
&out_grad,
&weight_grad.as_handle_ref(),
ConvolutionArgs {
stride: options.stride,
padding: options.padding,
dilation: options.dilation,
},
dtypes,
)?;
Ok(weight_grad)
}

View File

@@ -0,0 +1,2 @@
pub mod launch;
pub use launch::*;

View File

@@ -0,0 +1,8 @@
pub mod fallback;
pub mod implicit_gemm;
#[cfg(feature = "autotune")]
pub mod tune;
#[cfg(feature = "autotune")]
pub(crate) use tune::*;

View File

@@ -0,0 +1,175 @@
use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubecl::{
ir::StorageType,
tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},
};
use cubek::convolution::AcceleratedTileKind;
use crate::{
CubeAutotuneKey, CubeRuntime, CubeTuneId,
kernel::conv::{
ConvAutotuneKey,
backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
},
tensor::CubeTensor,
};
/// Executes autotune on the weight gradients pass for convolution
pub fn wgrad_autotune<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N>,
) -> CubeTensor<R> {
let client = input.client.clone();
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
TunableSet::new(create_key::<R, N>, create_wgrad_input::<R, N>)
.with(Tunable::new(
"wgrad_fallback",
conv_weight_backward_fallback::<R, N>,
))
.with(Tunable::new(
"simple_sync_cmma",
|input, grad, shape, options| {
wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_sync_mma",
|input, grad, shape, options| {
wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma)
},
))
.with(Tunable::new(
"simple_async_cmma",
|input, grad, shape, options| {
wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_async_mma",
|input, grad, shape, options| {
wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma)
},
))
.with(Tunable::new(
"simple_tma_cmma",
|input, grad, shape, options| {
wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_tma_mma",
|input, grad, shape, options| {
wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma)
},
))
});
TUNER.execute(
&CubeTuneId::new(&input.client, &input.device),
&client,
tunables,
(input, out_grad, weight_shape, options),
)
}
pub fn create_wgrad_input<R: CubeRuntime, const N: usize>(
_key: &CubeAutotuneKey,
input: &CubeTensor<R>,
out_grad: &CubeTensor<R>,
weight_shape: &Shape,
options: &ConvOptions<N>,
) -> (CubeTensor<R>, CubeTensor<R>, Shape, ConvOptions<N>) {
(
input.clone(),
out_grad.clone(),
weight_shape.clone(),
options.clone(),
)
}
fn create_key<R: CubeRuntime, const N: usize>(
input: &CubeTensor<R>,
out_grad: &CubeTensor<R>,
weight_shape: &Shape,
options: &ConvOptions<N>,
) -> CubeAutotuneKey {
let dtype = input.dtype;
let rank = input.meta.num_dims();
let dim_c = rank - 1;
let batch_size = input.meta.shape()[0];
let in_channels = input.meta.shape()[dim_c];
let out_channels = weight_shape[0];
let kernel_size = weight_shape[1..dim_c].to_vec();
let in_shape = input.meta.shape()[1..dim_c]
.iter()
.map(|shape| anchor(*shape, None, None, None))
.collect();
let ConvOptions {
stride,
padding,
dilation,
groups,
} = options.clone();
let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 {
stride_align(out_grad.meta.strides(), out_grad.dtype.into())
} else {
0
};
let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align);
let rhs_stride_align = if input.meta.strides()[dim_c] == 1 {
stride_align(input.meta.strides(), input.dtype.into())
} else {
0
};
let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);
CubeAutotuneKey::Conv(ConvAutotuneKey::new(
kernel_size,
stride.to_vec(),
padding.to_vec(),
dilation.to_vec(),
groups,
in_channels,
out_channels,
in_shape,
batch_size,
false,
dtype,
lhs_shape_align,
lhs_stride_align,
rhs_shape_align,
rhs_stride_align,
))
}
/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's
/// repeat number, so it's the largest align that can have performance impacts.
const MAX_STRIDE_FACTOR: u32 = 10;
/// Defines the non-contiguous stride alignment in terms of powers of two
fn stride_align(strides: &[usize], elem: StorageType) -> u8 {
let max = MAX_STRIDE_FACTOR;
let dim_c = strides.len() - 1;
let factor = strides[..dim_c]
.iter()
.map(|it| (*it * elem.size_bits()) / 8)
.map(|it| it.trailing_zeros())
.min()
.unwrap_or(max);
factor.min(max) as u8
}
/// Defines the potential vectorization.
fn pow2_factor(axis: usize) -> u8 {
axis.trailing_zeros().min(4) as u8
}

View File

@@ -0,0 +1,189 @@
use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::convolution::{AcceleratedTileKind, components::ConvSetupError};
#[cfg(feature = "autotune")]
use crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune};
use crate::{
CubeRuntime,
kernel::conv::{
backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
forward::implicit_gemm::conv_gemm_simple_sync,
},
ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use super::conv_direct;
#[cfg(feature = "autotune")]
use super::forward::conv_autotune;
/// The strategy to be used when launching a convolution kernel.
pub enum ConvStrategy {
/// A simple direct convolution.
Direct,
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
Autotune,
/// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
/// has constraints on tensor shape.
ImplicitGemm,
}
impl Default for ConvStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return ConvStrategy::Autotune;
// if autotune is disabled, default to the more memory-conservative algorithm
#[cfg(not(feature = "autotune"))]
ConvStrategy::Direct
}
}
/// Performs an N-dimensional convolution with the given strategy
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
pub fn conv_forward<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
let input = permute_nchw_to_nhwc(input);
let weight = permute_nchw_to_nhwc(weight);
let out = conv_forward_nhwc(input, weight, bias, options, strategy)?;
Ok(permute_nhwc_to_nchw(out))
}
/// Performs an N-dimensional convolution with the given strategy on NHWC inputs/outputs
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
pub fn conv_forward_nhwc<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
match strategy {
ConvStrategy::Direct => conv_direct::<R, N>(input, weight, bias, options),
#[cfg(feature = "autotune")]
ConvStrategy::Autotune => Ok(conv_autotune::<R, N>(input, weight, bias, options)),
ConvStrategy::ImplicitGemm => {
if options.groups != 1 {
conv_direct::<R, N>(input, weight, bias, options)
} else {
conv_gemm_simple_sync::<R, N>(
input,
weight,
bias,
options,
AcceleratedTileKind::Cmma,
)
}
}
}
}
/// Performs an N-dimensional convolution backwards pass with regard to weight, with the given strategy
///
/// * `input` - The input feature map
/// * `out_grad` - The output gradients
/// * `weight_shape` - The shape of the weights/weight gradients
/// * `options` - The options used for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
pub fn conv_weight_backward<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
weight_shape: Shape,
options: ConvOptions<N>,
strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
let input = permute_nchw_to_nhwc(input);
let out_grad = permute_nchw_to_nhwc(out_grad);
let weight_shape = permute_nchw_to_nhwc_shape(weight_shape);
let weight_grad = match strategy {
ConvStrategy::Direct => {
conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
}
#[cfg(feature = "autotune")]
ConvStrategy::Autotune => Ok(wgrad_autotune::<R, N>(
input,
out_grad,
weight_shape,
options,
)),
ConvStrategy::ImplicitGemm => {
if options.groups != 1 {
conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
} else {
wgrad_gemm_simple_sync::<R, N>(
input,
out_grad,
weight_shape,
options,
AcceleratedTileKind::Cmma,
)
}
}
}?;
Ok(permute_nhwc_to_nchw(weight_grad))
}
/// Performs an N-dimensional convolution backwards data pass with the given strategy
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `in_shape` - The shape of the input to the layer
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
pub fn conv_data_backward<R: CubeRuntime, const N: usize>(
out_grad: CubeTensor<R>,
weights: CubeTensor<R>,
in_shape: Shape,
options: ConvOptions<N>,
strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
let out_grad = permute_nchw_to_nhwc(out_grad);
let weights = permute_nchw_to_nhwc(weights);
let in_shape = permute_nchw_to_nhwc_shape(in_shape);
let weight_grad = match strategy {
ConvStrategy::Direct => {
conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
}
#[cfg(feature = "autotune")]
ConvStrategy::Autotune => dgrad_autotune::<R, N>(out_grad, weights, in_shape, options),
ConvStrategy::ImplicitGemm => {
if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
} else {
dgrad_gemm_simple_sync::<R, N>(
out_grad,
weights,
in_shape,
options,
AcceleratedTileKind::Cmma,
)?
}
}
};
Ok(permute_nhwc_to_nchw(weight_grad))
}

View File

@@ -0,0 +1,54 @@
use crate::{CubeRuntime, tensor::CubeTensor};
use burn_backend::ops::ConvTransposeOptions;
use cubek::convolution::components::ConvSetupError;
#[cfg(feature = "autotune")]
use super::conv_transpose2d_autotune;
use super::{conv_transpose2d_col2im, conv_transpose2d_direct};
/// The strategy to be used when launching a conv_transpose kernel.
pub enum ConvTranspose2dStrategy {
/// A simple direct convolution.
Direct,
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
Autotune,
/// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
Gemm,
}
impl Default for ConvTranspose2dStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return ConvTranspose2dStrategy::Autotune;
// if autotune is disabled, default to the more memory-conservative algorithm
#[cfg(not(feature = "autotune"))]
ConvTranspose2dStrategy::Direct
}
}
/// Performs a 2D convolution with the given strategy
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
pub fn conv_transpose2d<R: CubeRuntime>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvTransposeOptions<2>,
strategy: ConvTranspose2dStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
match strategy {
ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options),
#[cfg(feature = "autotune")]
ConvTranspose2dStrategy::Autotune => {
Ok(conv_transpose2d_autotune(input, weight, bias, options))
}
ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options),
}
}

View File

@@ -0,0 +1,302 @@
use crate::{
CubeRuntime,
kernel::{
conv::batches_per_run,
into_contiguous_aligned,
matmul::{MatmulStrategy, matmul},
slice,
utils::{address_type, decompose_linear, linear_view, shape_divmod},
},
ops::{numeric::empty_device_dtype, reshape, swap_dims},
tensor::CubeTensor,
};
use burn_backend::{
Shape,
ops::{ConvTransposeOptions, conv::calculate_conv_transpose_output_size},
};
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
use cubek::convolution::components::ConvSetupError;
/// Perform a 2D convolution transposition using the GEMM (col2im) algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
pub fn conv_transpose2d_col2im<R: CubeRuntime>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvTransposeOptions<2>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();
let [batch_size, _, input_h, input_w] = input.meta.shape().dims();
let groups = options.groups;
let input_ch_per_group = input_channels / groups;
let ConvTransposeOptions {
padding: [padding_h, padding_w],
padding_out: [padding_out_h, padding_out_w],
dilation: [dilation_h, dilation_w],
stride: [stride_h, stride_w],
..
} = options.clone();
let im_h = calculate_conv_transpose_output_size(
kernel_h,
stride_h,
padding_h,
padding_out_h,
dilation_h,
input_h,
);
let im_w = calculate_conv_transpose_output_size(
kernel_w,
stride_w,
padding_w,
padding_out_w,
dilation_w,
input_w,
);
let im_channels = im_ch_per_group * groups;
let batches_per_run = batches_per_run(
batch_size,
input_h * input_w,
input.client.properties().hardware.plane_size_max as usize,
)?;
let col_shape_0 = im_ch_per_group * kernel_h * kernel_w;
let weight = reshape(
weight.clone(),
Shape::new([groups, input_ch_per_group, col_shape_0]),
);
let weight = into_contiguous_aligned(swap_dims(weight, 1, 2));
if batches_per_run != batch_size {
let runs = batch_size / batches_per_run;
let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]);
let image = empty_device_dtype(
input.client.clone(),
input.device.clone(),
im_shape,
input.dtype,
);
let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]);
let input = reshape(input, input_shape);
let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]);
for run in 0..runs {
let input = index(input.clone(), run);
let input = reshape(input, input_shape_run.clone());
let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);
let image_slice = index(image.clone(), run);
let image_slice = reshape(image_slice, im_shape);
execute(
input,
weight.clone(),
bias.clone(),
image_slice,
options.clone(),
kernel_h,
kernel_w,
)?;
}
Ok(reshape(
image,
Shape::new([batch_size, im_channels, im_h, im_w]),
))
} else {
let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);
let image = empty_device_dtype(
input.client.clone(),
input.device.clone(),
im_shape,
input.dtype,
);
execute(
input,
weight,
bias,
image.clone(),
options,
kernel_h,
kernel_w,
)?;
Ok(image)
}
}
pub(crate) fn index<R: CubeRuntime>(tensor: CubeTensor<R>, i: usize) -> CubeTensor<R> {
#[allow(clippy::single_range_in_vec_init)]
let mut indices = vec![i..i + 1];
for dim in tensor.meta.shape()[1..].iter() {
indices.push(0..*dim);
}
let mut tensor = slice(tensor, &indices);
tensor.meta.remove(0);
tensor
}
#[allow(clippy::too_many_arguments)]
fn execute<R: CubeRuntime>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
image: CubeTensor<R>,
options: ConvTransposeOptions<2>,
kernel_h: usize,
kernel_w: usize,
) -> Result<(), ConvSetupError> {
let [batch_size, _, input_h, input_w] = input.meta.shape().dims();
let [groups, col_shape_0, input_ch_per_group] = weight.meta.shape().dims();
let col_shape_1 = batch_size * input_h * input_w;
let input = swap_dims(input, 0, 1);
let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);
let input = reshape(input, input_shape);
let dtype = input.dtype;
let columns = matmul(weight, input, None, MatmulStrategy::default(), dtype)?;
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
col2im(
columns, bias, image, kernel_h, kernel_w, input_h, input_w, options,
)?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn col2im<R: CubeRuntime>(
columns: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
out: CubeTensor<R>,
kernel_h: usize,
kernel_w: usize,
out_h: usize,
out_w: usize,
options: ConvTransposeOptions<2>,
) -> Result<(), LaunchError> {
let dtype = columns.dtype;
let columns = into_contiguous_aligned(columns);
let bias = bias.map(into_contiguous_aligned);
let num_elems = out.meta.num_elements();
let cube_dim = CubeDim::new(&columns.client, num_elems);
let cube_count = calculate_cube_count_elemwise(&columns.client, num_elems, cube_dim);
unsafe {
col2im_kernel::launch_unchecked(
&columns.client,
cube_count,
cube_dim,
address_type!(columns, bias, out),
columns.as_tensor_arg(1),
bias.as_ref().map(|bias| bias.as_tensor_arg(1)).into(),
linear_view(&out, 1),
shape_divmod(&out),
Col2ImArgsLaunch::new(
ScalarArg::new(out_h),
ScalarArg::new(out_w),
ScalarArg::new(kernel_h),
ScalarArg::new(kernel_w),
ScalarArg::new(options.padding[0]),
ScalarArg::new(options.padding[1]),
ScalarArg::new(options.dilation[0]),
ScalarArg::new(options.dilation[1]),
ScalarArg::new(options.stride[0]),
ScalarArg::new(options.stride[1]),
),
dtype.into(),
)
}
}
#[derive(CubeLaunch, CubeType)]
struct Col2ImArgs {
out_h: usize,
out_w: usize,
kernel_h: usize,
kernel_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
stride_h: usize,
stride_w: usize,
}
#[cube(launch_unchecked, address_type = "dynamic")]
fn col2im_kernel<E: Numeric>(
columns: &Tensor<E>,
bias: &Option<Tensor<E>>,
image: &mut LinearView<E, ReadWrite>,
image_shape: Sequence<FastDivmod<usize>>,
args: &Col2ImArgs,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= image.shape() {
terminate!();
}
let (_, pos) = decompose_linear(ABSOLUTE_POS, &image_shape);
let [batch, ch_im, im_y, im_x] = *pos else {
unreachable!()
};
let im_x = im_x + args.pad_w;
let im_y = im_y + args.pad_h;
let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1;
let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1;
let mut val = E::from_int(0);
let x_col_start = if im_x >= kernel_extent_w {
(im_x - kernel_extent_w) / args.stride_w + 1
} else {
0usize.runtime()
};
let x_col_end = clamp_max(im_x / args.stride_w + 1, args.out_w);
let y_col_start = if im_y >= kernel_extent_h {
(im_y - kernel_extent_h) / args.stride_h + 1
} else {
0usize.runtime()
};
let y_col_end = clamp_max(im_y / args.stride_h + 1, args.out_h);
for col_y in y_col_start..y_col_end {
let kernel_y = im_y - col_y * args.stride_h;
for col_x in x_col_start..x_col_end {
let kernel_x = im_x - col_x * args.stride_w;
if kernel_y.is_multiple_of(args.dilation_h) && kernel_x.is_multiple_of(args.dilation_w)
{
let kernel_y = kernel_y / args.dilation_h;
let kernel_x = kernel_x / args.dilation_w;
let col_k =
ch_im * args.kernel_h * args.kernel_w + kernel_y * args.kernel_w + kernel_x;
let col_n = batch * args.out_h * args.out_w + col_y * args.out_w + col_x;
let col_pos = col_k * columns.stride(0) + col_n * columns.stride(1);
val += columns[col_pos];
}
}
}
match bias {
Some(bias) => image[ABSOLUTE_POS] = val + bias[ch_im],
None => image[ABSOLUTE_POS] = val,
}
}

View File

@@ -0,0 +1,15 @@
mod base;
mod col2im;
mod transpose_direct;
#[cfg(feature = "autotune")]
mod tune;
pub use base::*;
pub use col2im::*;
pub use transpose_direct::*;
#[cfg(feature = "autotune")]
pub use tune::*;

View File

@@ -0,0 +1,185 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, decompose_linear, linear_view, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use burn_backend::{Shape, ops::ConvTransposeOptions};
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
use cubek::convolution::components::ConvSetupError;
#[derive(CubeLaunch, CubeType)]
struct ConvArgs {
conv_stride_0: usize,
conv_stride_1: usize,
dilation_0: usize,
dilation_1: usize,
padding_0: usize,
padding_1: usize,
groups: usize,
}
#[cube(launch, address_type = "dynamic")]
fn conv_transpose2d_direct_kernel<E: Numeric>(
input: &Tensor<E>,
weight: &Tensor<E>,
bias: &Option<Tensor<E>>,
output: &mut LinearView<E, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
args: ConvArgs,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= output.shape() {
terminate!();
}
let in_c_per_group = weight.shape(0) / args.groups;
let out_c_per_group = weight.shape(1);
let kernel_h = weight.shape(2);
let kernel_w = weight.shape(3);
let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape);
let [batch, oc_out, out_y, out_x] = *pos else {
unreachable!()
};
let k = oc_out / out_c_per_group;
let group = k % args.groups;
let out_c = oc_out - out_c_per_group * group;
let in_c_start = group * in_c_per_group;
let in_c_end = in_c_start + in_c_per_group;
let stride_0_i = args.conv_stride_0 as i32;
let stride_1_i = args.conv_stride_1 as i32;
let kms_h = (kernel_h * args.dilation_0) as i32 - stride_0_i;
let kms_w = (kernel_w * args.dilation_1) as i32 - stride_1_i;
let y_start = ((out_y + args.padding_0) as i32 - kms_h) / stride_0_i;
let x_start = ((out_x + args.padding_1) as i32 - kms_w) / stride_1_i;
let y_end = clamp(kms_h + y_start + 1, 0, input.shape(2) as i32) as usize;
let x_end = clamp(kms_w + x_start + 1, 0, input.shape(3) as i32) as usize;
let y_start = clamp_min(y_start, 0) as usize;
let x_start = clamp_min(x_start, 0) as usize;
let idx_input_batch = batch * input.stride(0);
let idx_weight_oc = out_c * weight.stride(1);
let bias: Option<E> = bias.map(|bias| bias[oc_out]);
let mut sum = bias.unwrap_or_default();
let numerator_h_base = out_y + args.padding_0;
let numerator_w_base = out_x + args.padding_1;
for in_c in in_c_start..in_c_end {
let idx_input_ic = in_c * input.stride(1);
let idx_weight_ic = in_c * weight.stride(0);
for in_y in y_start..y_end {
let numerator_tmp = in_y * args.conv_stride_0;
let numerator_h = numerator_h_base - numerator_tmp;
if numerator_h_base >= numerator_tmp && numerator_h.is_multiple_of(args.dilation_0) {
let kernel_y = numerator_h / args.dilation_0;
let idx_input_y = in_y * input.stride(2);
let idx_weight_ky = kernel_y * weight.stride(2);
for in_x in x_start..x_end {
let numerator_tmp = in_x * args.conv_stride_1;
let numerator_w = numerator_w_base - numerator_tmp;
if numerator_w_base >= numerator_tmp
&& numerator_w.is_multiple_of(args.dilation_1)
{
let kernel_x = numerator_w / args.dilation_1;
let idx_input_x = in_x * input.stride(3);
let idx_weight_kx = kernel_x * weight.stride(3);
let index_input =
idx_input_batch + idx_input_ic + idx_input_y + idx_input_x;
let index_weight =
idx_weight_ic + idx_weight_oc + idx_weight_ky + idx_weight_kx;
let value = input[index_input];
let weight = weight[index_weight];
sum += value * weight;
}
}
}
}
}
output[ABSOLUTE_POS] = sum;
}
/// Perform a 2D convolution transposition using the direct algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
pub fn conv_transpose2d_direct<R: CubeRuntime>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvTransposeOptions<2>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let [batch_size, _, in_height, in_width] = input.meta.shape().dims();
let [_, out_channels, kernel_0, kernel_1] = weight.meta.shape().dims();
let out_0 = (in_height - 1) * options.stride[0]
+ options.dilation[0] * (kernel_0 - 1)
+ options.padding_out[0]
- 2 * options.padding[0]
+ 1;
let out_1 = (in_width - 1) * options.stride[1]
+ options.dilation[1] * (kernel_1 - 1)
+ options.padding_out[1]
- 2 * options.padding[1]
+ 1;
let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
shape_out.clone(),
input.dtype,
);
let num_elems = output.meta.num_elements();
let cube_dim = CubeDim::new(&input.client, num_elems);
let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim);
conv_transpose2d_direct_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, weight, bias, output),
input.as_tensor_arg(1),
weight.as_tensor_arg(1),
bias.as_ref().map(|bias| bias.as_tensor_arg(1)).into(),
linear_view(&output, 1),
shape_divmod(&output),
ConvArgsLaunch::new(
ScalarArg::new(options.stride[0]),
ScalarArg::new(options.stride[1]),
ScalarArg::new(options.dilation[0]),
ScalarArg::new(options.dilation[1]),
ScalarArg::new(options.padding[0]),
ScalarArg::new(options.padding[1]),
ScalarArg::new(options.groups),
),
input.dtype.into(),
)?;
Ok(output)
}

View File

@@ -0,0 +1,91 @@
use burn_backend::ops::ConvTransposeOptions;
use cubecl::tune::{LocalTuner, Tunable, TunableSet, local_tuner};
use crate::{
CubeAutotuneKey, CubeRuntime, CubeTuneId,
kernel::conv::{ConvTranspose2dAutotuneKey, conv_transpose2d_col2im, conv_transpose2d_direct},
tensor::CubeTensor,
};
/// Executes autotune on conv2d operations
pub fn conv_transpose2d_autotune<R: CubeRuntime>(
input: CubeTensor<R>,
weights: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvTransposeOptions<2>,
) -> CubeTensor<R> {
let client = input.client.clone();
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
let tune_set = TUNER.init(|| {
TunableSet::new(create_key::<R>, create_transpose2d_input::<R>)
.with(Tunable::new(
"conv_transpose2d_direct",
conv_transpose2d_direct::<R>,
))
.with(Tunable::new(
"conv_transpose2d_col2im",
conv_transpose2d_col2im::<R>,
))
});
TUNER.execute(
&CubeTuneId::new(&input.client, &input.device),
&client,
tune_set,
(input, weights, bias, options),
)
}
pub fn create_transpose2d_input<R: CubeRuntime>(
_key: &CubeAutotuneKey,
input: &CubeTensor<R>,
weights: &CubeTensor<R>,
bias: &Option<CubeTensor<R>>,
options: &ConvTransposeOptions<2>,
) -> (
CubeTensor<R>,
CubeTensor<R>,
Option<CubeTensor<R>>,
ConvTransposeOptions<2>,
) {
(
input.clone(),
weights.clone(),
bias.clone(),
options.clone(),
)
}
fn create_key<R: CubeRuntime>(
input: &CubeTensor<R>,
weights: &CubeTensor<R>,
bias: &Option<CubeTensor<R>>,
options: &ConvTransposeOptions<2>,
) -> CubeAutotuneKey {
let [batch_size, in_channels, height, width] = input.meta.shape().dims();
let [out_channels, _, kernel_h, kernel_w] = weights.meta.shape().dims();
let ConvTransposeOptions {
stride,
padding,
dilation,
groups,
padding_out,
} = options.clone();
CubeAutotuneKey::ConvTranspose(ConvTranspose2dAutotuneKey::new(
[kernel_h, kernel_w],
stride,
padding,
padding_out,
dilation,
groups,
in_channels,
out_channels,
height,
width,
batch_size,
bias.is_some(),
input.dtype,
))
}

View File

@@ -0,0 +1,222 @@
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
use crate::{
CubeRuntime,
kernel::utils::{address_type, decompose_linear, linear_view, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use burn_backend::{Shape, ops::ConvTransposeOptions};
#[derive(CubeLaunch, CubeType)]
struct ConvArgs {
conv_stride_0: usize,
conv_stride_1: usize,
conv_stride_2: usize,
dilation_0: usize,
dilation_1: usize,
dilation_2: usize,
padding_0: usize,
padding_1: usize,
padding_2: usize,
groups: usize,
}
#[cube(launch, address_type = "dynamic")]
fn conv_transpose3d_kernel<E: Numeric>(
input: &Tensor<E>,
weight: &Tensor<E>,
bias: &Option<Tensor<E>>,
output: &mut LinearView<E, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
args: ConvArgs,
#[define(E)] _dtype: StorageType,
) {
let in_channels = weight.shape(0);
let out_c_per_group = weight.shape(1);
let kernel_size_0 = weight.shape(2);
let kernel_size_1 = weight.shape(3);
let kernel_size_2 = weight.shape(4);
let stride_0_i = args.conv_stride_0 as i32;
let stride_1_i = args.conv_stride_1 as i32;
let stride_2_i = args.conv_stride_2 as i32;
let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape);
let [batch, out_c_out, out_z, out_y, out_x] = *pos else {
unreachable!()
};
let groups = args.groups;
let in_c_per_group = in_channels / groups;
let k = out_c_out / out_c_per_group;
let group = k % groups;
let out_channel = out_c_out - out_c_per_group * group;
let in_c_start = group * in_c_per_group;
let in_c_end = in_c_start + in_c_per_group;
let kernel_d = (kernel_size_0 * args.dilation_0 - args.conv_stride_0) as i32;
let kernel_h = (kernel_size_1 * args.dilation_1 - args.conv_stride_1) as i32;
let kernel_w = (kernel_size_2 * args.dilation_2 - args.conv_stride_2) as i32;
let z_start = ((out_z + args.padding_0) as i32 - kernel_d) / stride_0_i;
let y_start = ((out_y + args.padding_1) as i32 - kernel_h) / stride_1_i;
let x_start = ((out_x + args.padding_2) as i32 - kernel_w) / stride_2_i;
let z_end = clamp(kernel_d + z_start + 1, 0, input.shape(2) as i32) as usize;
let y_end = clamp(kernel_h + y_start + 1, 0, input.shape(3) as i32) as usize;
let x_end = clamp(kernel_w + x_start + 1, 0, input.shape(4) as i32) as usize;
let z_start = clamp_min(z_start, 0) as usize;
let y_start = clamp_min(y_start, 0) as usize;
let x_start = clamp_min(x_start, 0) as usize;
let index_input_batch = batch * input.stride(0);
let index_weight_out_c = out_channel * weight.stride(1);
let bias: Option<E> = bias.map(|bias| bias[out_c_out]);
let mut sum = bias.unwrap_or_default();
let numerator_d_base = out_z + args.padding_0;
let numerator_h_base = out_y + args.padding_1;
let numerator_w_base = out_x + args.padding_2;
for in_c in in_c_start..in_c_end {
let index_input_in_c = in_c * input.stride(1);
let index_weight_in_c = in_c * weight.stride(0);
for in_z in z_start..z_end {
let numerator_tmp = in_z * args.conv_stride_0;
let numerator_d = numerator_d_base - numerator_tmp;
if numerator_d_base >= numerator_tmp && numerator_d.is_multiple_of(args.dilation_0) {
let kernel_z = numerator_d / args.dilation_0;
let index_input_z = in_z * input.stride(2);
let index_weight_kz = kernel_z * weight.stride(2);
for in_y in y_start..y_end {
let numerator_tmp = in_y * args.conv_stride_1;
let numerator_h = numerator_h_base - numerator_tmp;
if numerator_h_base >= numerator_tmp
&& numerator_h.is_multiple_of(args.dilation_1)
{
let kernel_y = numerator_h / args.dilation_1;
let index_input_y = in_y * input.stride(3);
let index_weight_ky = kernel_y * weight.stride(3);
for in_x in x_start..x_end {
let numerator_tmp = in_x * args.conv_stride_2;
let numerator_w = numerator_w_base - numerator_tmp;
if numerator_w_base >= numerator_tmp
&& numerator_w.is_multiple_of(args.dilation_2)
{
let kernel_x = numerator_w / args.dilation_2;
let index_input_x = in_x * input.stride(4);
let index_weight_kx = kernel_x * weight.stride(4);
let index_input = index_input_batch
+ index_input_in_c
+ index_input_z
+ index_input_y
+ index_input_x;
let index_weight = index_weight_in_c
+ index_weight_out_c
+ index_weight_kz
+ index_weight_ky
+ index_weight_kx;
let value = input[index_input];
let weight = weight[index_weight];
sum += value * weight;
}
}
}
}
}
}
}
output[ABSOLUTE_POS] = sum;
}
pub(crate) fn conv_transpose3d<R: CubeRuntime>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvTransposeOptions<3>,
) -> Result<CubeTensor<R>, LaunchError> {
let [batch_size, _, in_depth, in_height, in_width] = input.meta.shape().dims();
let [_, out_channels, kernel_0, kernel_1, kernel_2] = weight.meta.shape().dims();
let out_0 = (in_depth - 1) * options.stride[0]
+ options.dilation[0] * (kernel_0 - 1)
+ options.padding_out[0]
- 2 * options.padding[0]
+ 1;
let out_1 = (in_height - 1) * options.stride[1]
+ options.dilation[1] * (kernel_1 - 1)
+ options.padding_out[1]
- 2 * options.padding[1]
+ 1;
let out_2 = (in_width - 1) * options.stride[2]
+ options.dilation[2] * (kernel_2 - 1)
+ options.padding_out[2]
- 2 * options.padding[2]
+ 1;
let shape_out = Shape::new([
batch_size,
out_channels * options.groups,
out_0,
out_1,
out_2,
]);
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
shape_out.clone(),
input.dtype,
);
let num_elems = output.meta.num_elements();
let cube_dim = CubeDim::new(&input.client, num_elems);
let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim);
conv_transpose3d_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, weight, bias, output),
input.as_tensor_arg(1),
weight.as_tensor_arg(1),
bias.as_ref().map(|bias| bias.as_tensor_arg(1)).into(),
linear_view(&output, 1),
shape_divmod(&output),
ConvArgsLaunch::new(
ScalarArg::new(options.stride[0]),
ScalarArg::new(options.stride[1]),
ScalarArg::new(options.stride[2]),
ScalarArg::new(options.dilation[0]),
ScalarArg::new(options.dilation[1]),
ScalarArg::new(options.dilation[2]),
ScalarArg::new(options.padding[0]),
ScalarArg::new(options.padding[1]),
ScalarArg::new(options.padding[2]),
ScalarArg::new(options.groups),
),
input.dtype.into(),
)?;
Ok(output)
}

View File

@@ -0,0 +1,314 @@
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, FastDivmodArgs},
};
use cubek::convolution::components::ConvSetupError;
use burn_backend::{
Shape,
ops::{DeformConvOptions, conv::calculate_conv_output_size},
};
use crate::{
CubeRuntime,
kernel::{
AddOp, into_contiguous_aligned, launch_binop,
matmul::{MatmulStrategy, matmul},
utils::address_type,
},
ops::{numeric::zeros_client, reshape, swap_dims},
tensor::CubeTensor,
};
#[derive(CubeLaunch, CubeType)]
struct DeformConv2dArgs {
conv_stride_h: usize,
conv_stride_w: usize,
dilation_h: usize,
dilation_w: usize,
padding_h: InputScalar,
padding_w: InputScalar,
offset_groups: usize,
kernel_height: usize,
kernel_width: usize,
out_h: usize,
out_w: usize,
}
#[cube(launch, address_type = "dynamic")]
fn deform_im2col_kernel<F: Float>(
input: &Tensor<F>,
offset: &Tensor<F>,
mask: &Option<Tensor<F>>,
columns: &mut Tensor<F>,
pos_shape: Sequence<FastDivmod<usize>>,
args: &DeformConv2dArgs,
#[comptime] kernel_h_unroll: Option<usize>,
#[comptime] kernel_w_unroll: Option<usize>,
#[define(F)] _dtype: StorageType,
) {
// position shape: [in_channels, batch_size, out_h, out_w]
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height);
let unroll_h = kernel_h_unroll.is_some();
let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width);
let unroll_w = kernel_w_unroll.is_some();
let out_h = args.out_h;
let out_w = args.out_w;
let in_channels = input.shape(1);
let height = input.shape(2);
let width = input.shape(3);
let col_stride_0 = columns.stride(0);
let (rem, out_x) = pos_shape[3].div_mod(ABSOLUTE_POS);
let (rem, out_y) = pos_shape[2].div_mod(rem);
let (in_channel, batch) = pos_shape[1].div_mod(rem);
if in_channel >= in_channels {
terminate!()
}
let out_k_base = in_channel * kernel_height * kernel_width;
let out_n = batch * out_h * out_w + out_y * out_w + out_x;
let channels_per_offset_group = in_channels / args.offset_groups;
let group_index = in_channel / channels_per_offset_group;
let mut col_base_idx = out_k_base * columns.stride(0) + out_n * columns.stride(1);
let input_base_idx = batch * input.stride(0) + in_channel * input.stride(1);
let offset_base_idx = batch * offset.stride(0)
+ group_index * kernel_height * kernel_width * 2 * offset.stride(1);
let mask_base_idx = mask.as_ref().map(|mask| {
batch * mask.stride(0) + group_index * kernel_height * kernel_width * mask.stride(1)
});
#[unroll(unroll_h)]
for kernel_y in 0..kernel_height {
#[unroll(unroll_w)]
for kernel_x in 0..kernel_width {
let mask_index = kernel_y * kernel_width + kernel_x;
let offset_index = mask_index * 2;
let offset_y = offset[offset_base_idx
+ offset_index * offset.stride(1)
+ out_y * offset.stride(2)
+ out_x * offset.stride(3)];
let offset_x = offset[offset_base_idx
+ (offset_index + 1) * offset.stride(1)
+ out_y * offset.stride(2)
+ out_x * offset.stride(3)];
let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h)
- args.padding_h.get::<F>()
+ offset_y;
let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w)
- args.padding_w.get::<F>()
+ offset_x;
let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx);
let value = match mask.zip::<usize>(mask_base_idx) {
Some((mask, base_idx)) => {
let mask_value = mask[base_idx
+ mask_index * mask.stride(1)
+ out_y * mask.stride(2)
+ out_x * mask.stride(3)];
mask_value * interpolated
}
None => interpolated,
};
columns[col_base_idx] = value;
col_base_idx += col_stride_0;
}
}
}
#[cube]
pub(crate) fn bilinear_interpolate<F: Float>(
input: &Tensor<F>,
height: usize,
width: usize,
y: F,
x: F,
offset: usize,
) -> F {
// To simplify code
let y = f32::cast_from(y);
let x = f32::cast_from(x);
let stride_y = input.stride(2);
let stride_x = input.stride(3);
let mut result = F::new(0.0);
if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
let y_low = y.floor();
let x_low = x.floor();
let y_high = (y_low + 1.) as usize;
let x_high = (x_low + 1.) as usize;
let zero = F::new(0.0);
let v1: F = if y_low >= 0. && x_low >= 0. {
input[offset + y_low as usize * stride_y + x_low as usize * stride_x]
} else {
zero
};
let v2: F = if y_low >= 0. && x_high < width {
input[offset + y_low as usize * stride_y + x_high * stride_x]
} else {
zero
};
let v3: F = if y_high < height && x_low >= 0. {
input[offset + y_high * stride_y + x_low as usize * stride_x]
} else {
zero
};
let v4: F = if y_high < height && x_high < width {
input[offset + y_high * stride_y + x_high * stride_x]
} else {
zero
};
let l_y = y - y_low;
let l_x = x - x_low;
let h_y = 1.0 - l_y;
let h_x = 1.0 - l_x;
let w1 = F::cast_from(h_y * h_x);
let w2 = F::cast_from(h_y * l_x);
let w3 = F::cast_from(l_y * h_x);
let w4 = F::cast_from(l_y * l_x);
result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
result
}
pub(crate) fn deform_im2col<R: CubeRuntime>(
input: CubeTensor<R>,
offset: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
options: DeformConvOptions<2>,
out_dims: (usize, usize),
kernel_dims: (usize, usize),
) -> Result<CubeTensor<R>, LaunchError> {
let client = input.client.clone();
let device = input.device.clone();
let dtype = input.dtype;
let [batch_size, in_channels, _, _] = input.meta.shape().dims();
let (out_height, out_width) = out_dims;
let (kernel_height, kernel_width) = kernel_dims;
let shape_out = Shape::new([
in_channels * kernel_height * kernel_width,
batch_size * out_height * out_width,
]);
let pos_shape = [in_channels, batch_size, out_height, out_width]
.into_iter()
.map(|s| FastDivmodArgs::new(&client, s))
.collect();
let output = zeros_client(client.clone(), device.clone(), shape_out.clone(), dtype);
let num_kernels = in_channels * batch_size * out_height * out_width;
let cube_dim = CubeDim::new(&input.client, num_kernels);
let cube_count = calculate_cube_count_elemwise(&input.client, num_kernels, cube_dim);
deform_im2col_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, offset, mask, output),
input.as_tensor_arg(1),
offset.as_tensor_arg(1),
mask.as_ref().map(|mask| mask.as_tensor_arg(1)).into(),
output.as_handle_ref().as_tensor_arg(1),
pos_shape,
DeformConv2dArgsLaunch::new(
ScalarArg::new(options.stride[0]),
ScalarArg::new(options.stride[1]),
ScalarArg::new(options.dilation[0]),
ScalarArg::new(options.dilation[1]),
{
let val = options.padding[0] as f32;
InputScalar::new(val, dtype)
},
{
let val = options.padding[1] as f32;
InputScalar::new(val, dtype)
},
ScalarArg::new(options.offset_groups),
ScalarArg::new(kernel_height),
ScalarArg::new(kernel_width),
ScalarArg::new(out_height),
ScalarArg::new(out_width),
),
Some(kernel_height),
Some(kernel_width),
dtype.into(),
)?;
Ok(output)
}
pub(crate) fn deform_conv2d<R: CubeRuntime>(
input: CubeTensor<R>,
offset: CubeTensor<R>,
weight: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
bias: Option<CubeTensor<R>>,
options: DeformConvOptions<2>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let input = into_contiguous_aligned(input);
let offset = into_contiguous_aligned(offset);
let weight = into_contiguous_aligned(weight);
let mask = mask.map(|it| into_contiguous_aligned(it));
let bias = bias.map(|it| into_contiguous_aligned(it));
let [batch_size, _, in_height, in_width] = input.meta.shape().dims();
let [out_channels, _, kernel_h, kernel_w] = weight.meta.shape().dims();
let groups = options.weight_groups;
let out_h = calculate_conv_output_size(
kernel_h,
options.stride[0],
options.padding[0],
options.dilation[0],
in_height,
);
let out_w = calculate_conv_output_size(
kernel_w,
options.stride[1],
options.padding[1],
options.dilation[1],
in_width,
);
let out_dims = (out_h, out_w);
let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w))?;
let [col_size_0, col_size_1] = columns.meta.shape().dims();
let col_size_0 = col_size_0 / groups;
let out_c_per_group = out_channels / groups;
let dtype = weight.dtype;
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0]));
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
let out = matmul(weight, columns, None, MatmulStrategy::default(), dtype)?;
let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
let out = swap_dims(out, 0, 1);
if let Some(bias) = bias {
let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));
Ok(launch_binop::<R, AddOp>(out, bias))
} else {
Ok(out)
}
}

View File

@@ -0,0 +1,720 @@
use super::{bilinear_interpolate, deform_im2col, index};
use crate::{
CubeRuntime,
kernel::{
cast, into_contiguous_aligned,
matmul::{MatmulStrategy, matmul},
reduce::reduce_dim,
slice_assign,
utils::{address_type, decompose_linear, linear_view},
},
ops::{
numeric::{empty_device_dtype, zeros_client},
reshape, swap_dims,
},
tensor::CubeTensor,
};
use burn_backend::{DType, Shape, TensorMetadata, ops::DeformConvOptions};
use cubecl::{
CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube,
features::TypeUsage,
ir::FloatKind,
prelude::*,
std::{FastDivmod, FastDivmodArgs, tensor::layout::linear::LinearView},
};
use cubek::{
convolution::components::ConvSetupError,
reduce::components::instructions::ReduceOperationConfig,
};
use std::marker::PhantomData;
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
#[allow(
clippy::single_range_in_vec_init,
clippy::type_complexity,
clippy::too_many_arguments
)]
pub(crate) fn deform_conv2d_backward<R: CubeRuntime>(
input: CubeTensor<R>,
offset: CubeTensor<R>,
weight: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
bias: Option<CubeTensor<R>>,
out_grad: CubeTensor<R>,
options: DeformConvOptions<2>,
) -> Result<
(
CubeTensor<R>,
CubeTensor<R>,
CubeTensor<R>,
Option<CubeTensor<R>>,
Option<CubeTensor<R>>,
),
ConvSetupError,
> {
let [_, _, out_h, out_w] = out_grad.meta.shape().dims();
let [_, _, kernel_h, kernel_w] = weight.meta.shape().dims();
let gradient_bias = bias.map(|bias| {
let grad = reduce_dim(
out_grad.clone(),
None,
0,
Default::default(),
ReduceOperationConfig::Sum,
)
.unwrap();
let grad = reduce_dim(
grad,
None,
2,
Default::default(),
ReduceOperationConfig::Sum,
)
.unwrap();
let grad = reduce_dim(
grad,
None,
3,
Default::default(),
ReduceOperationConfig::Sum,
)
.unwrap();
reshape(grad, bias.meta.shape)
});
let input = into_contiguous_aligned(input);
let offset = into_contiguous_aligned(offset);
let weight = into_contiguous_aligned(weight);
let mask = mask.map(|it| into_contiguous_aligned(it));
let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
input.clone(),
weight.clone(),
offset.clone(),
mask.clone(),
out_grad.clone(),
&options,
(kernel_h, kernel_w),
)?;
let weight_grad = compute_weight_grad(
input,
offset,
mask,
out_grad,
options,
(kernel_h, kernel_w),
(out_h, out_w),
)?;
Ok((
input_gradient,
offset_gradient,
weight_grad,
mask_gradient,
gradient_bias,
))
}
fn compute_weight_grad<R: CubeRuntime>(
input: CubeTensor<R>,
offset: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
out_grad: CubeTensor<R>,
options: DeformConvOptions<2>,
kernel_dims: (usize, usize),
out_dims: (usize, usize),
) -> Result<CubeTensor<R>, ConvSetupError> {
let [_, in_channels, _, _] = input.meta.shape().dims();
let [_, out_channels, _, _] = out_grad.meta.shape().dims();
let (kernel_h, kernel_w) = kernel_dims;
let groups = options.weight_groups;
let dtype = input.dtype;
let in_c_per_group = in_channels / groups;
let out_c_per_group = out_channels / groups;
let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims)?;
let [col_size_0, col_size_1] = columns.meta.shape().dims();
let col_size_0 = col_size_0 / groups;
let out_grad = swap_dims(out_grad, 0, 1);
let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1]));
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
let columns = swap_dims(columns, 1, 2);
let grad_weight = matmul(out_grad, columns, None, MatmulStrategy::default(), dtype)?;
Ok(reshape(
grad_weight,
Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]),
))
}
type InputGradients<R> = (CubeTensor<R>, CubeTensor<R>, Option<CubeTensor<R>>);
fn backward_gradient_inputs<R: CubeRuntime>(
image: CubeTensor<R>,
weight: CubeTensor<R>,
offset: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
out_grad: CubeTensor<R>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
) -> Result<InputGradients<R>, ConvSetupError> {
let client = out_grad.client.clone();
let device = out_grad.device.clone();
let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();
let [batch_size, _, out_h, out_w] = out_grad.meta.shape().dims();
let groups = options.weight_groups;
let out_c_per_group = out_channels / groups;
let col_shape_0 = in_c_per_group * kernel_h * kernel_w;
let col_shape_1 = batch_size * out_h * out_w;
let col_shape = Shape::new([groups, col_shape_0, col_shape_1]);
let mut columns = empty_device_dtype(client, device, col_shape, weight.dtype);
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));
let out_grad = swap_dims(out_grad, 0, 1);
let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]);
let out_grad = reshape(out_grad, out_grad_shape);
for group in 0..groups {
let dtype = weight.dtype;
let weight = swap_dims(index(weight.clone(), group), 0, 1);
let out_grad = index(out_grad.clone(), group);
let values = matmul(weight, out_grad, None, MatmulStrategy::default(), dtype)?;
let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1]));
columns = slice_assign(
columns,
&[
burn_backend::Slice::from(group..group + 1),
burn_backend::Slice::from(0..col_shape_0),
burn_backend::Slice::from(0..col_shape_1),
],
values,
);
}
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
let input_shape = image.shape();
let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
columns.clone(),
image,
offset.clone(),
mask.clone(),
options,
kernel_dims,
)?;
let input_gradient =
compute_input_grad(columns, offset, mask, options, kernel_dims, input_shape)?;
Ok((input_gradient, offset_gradient, mask_gradient))
}
fn compute_offset_and_mask_gradient<R: CubeRuntime>(
columns: CubeTensor<R>,
image: CubeTensor<R>,
offset: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
) -> Result<(CubeTensor<R>, Option<CubeTensor<R>>), ConvSetupError> {
let client = offset.client.clone();
let device = offset.device.clone();
let (kernel_h, kernel_w) = kernel_dims;
let [batches, _, out_h, out_w] = offset.meta.shape().dims();
let offset_groups = options.offset_groups;
let pos_shape = [batches, offset_groups, kernel_h, kernel_w, 2, out_h, out_w];
let pos_shape = pos_shape
.into_iter()
.map(|s| FastDivmodArgs::new(&client, s))
.collect();
let grad_offset =
empty_device_dtype(client.clone(), device.clone(), offset.shape(), offset.dtype);
let grad_mask = mask
.as_ref()
.map(|mask| empty_device_dtype(client.clone(), device.clone(), mask.shape(), mask.dtype));
let num_elements_offset = offset.meta.num_elements();
let cube_dim = CubeDim::new(&image.client, num_elements_offset);
let cube_count = calculate_cube_count_elemwise(&image.client, num_elements_offset, cube_dim);
let dtype: StorageType = image.dtype.into();
unsafe {
deform_col2img_coord_kernel::launch_unchecked(
&image.client,
cube_count,
cube_dim,
address_type!(image, offset, mask, grad_offset, grad_mask),
image.as_tensor_arg(1),
offset.as_tensor_arg(1),
mask.as_ref().map(|mask| mask.as_tensor_arg(1)).into(),
columns.as_tensor_arg(1),
linear_view(&grad_offset, 1),
grad_mask
.as_ref()
.map(|grad_mask| grad_mask.as_tensor_arg(1))
.into(),
pos_shape,
DeformConv2dCol2ImgCoordArgsLaunch::new(
ScalarArg::new(options.stride[0]),
ScalarArg::new(options.stride[1]),
ScalarArg::new(options.dilation[0]),
ScalarArg::new(options.dilation[1]),
InputScalar::new(options.padding[0] as f32, dtype.elem_type()),
InputScalar::new(options.padding[1] as f32, dtype.elem_type()),
ScalarArg::new(offset_groups),
ScalarArg::new(kernel_h),
ScalarArg::new(kernel_w),
),
dtype,
)
}?;
Ok((grad_offset, grad_mask))
}
#[derive(CubeLaunch, CubeType)]
struct DeformConv2dCol2ImgCoordArgs {
stride_h: usize,
stride_w: usize,
dilation_h: usize,
dilation_w: usize,
pad_h: InputScalar,
pad_w: InputScalar,
offset_groups: usize,
kernel_height: usize,
kernel_width: usize,
}
#[allow(clippy::collapsible_if)]
#[cube(launch_unchecked, address_type = "dynamic")]
fn deform_col2img_coord_kernel<F: Float>(
image: &Tensor<F>,
offset: &Tensor<F>,
mask: &Option<Tensor<F>>,
columns: &Tensor<F>,
grad_offset: &mut LinearView<F, ReadWrite>,
grad_mask: &mut Option<Tensor<F>>,
pos_shape: Sequence<FastDivmod<usize>>,
args: &DeformConv2dCol2ImgCoordArgs,
#[define(F)] _dtype: StorageType,
) {
// Position format: [batch, [offset_groups, kernel_h, kernel_w, 2], out_h, out_w]
// Columns format: [[in_channel, kernel_h, kernel_w], [batch, out_h, out_w]]
// Alternatively : [batch, offset_channels, out_h, out_w]
if ABSOLUTE_POS >= grad_offset.shape() {
terminate!();
}
let out_h = offset.shape(2);
let out_w = offset.shape(3);
let in_channels = image.shape(1);
let height = image.shape(2);
let width = image.shape(3);
let kernel_w = args.kernel_width;
let kernel_h = args.kernel_height;
let mut grad_offset_val = F::new(0.0);
let mut grad_mask_val = F::new(0.0);
let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);
let [batch, offset_group, kernel_y, kernel_x, dir, out_y, out_x] = *pos else {
unreachable!()
};
let channels_per_offset_group = in_channels / args.offset_groups;
let col_n = batch * out_h * out_w + out_y * out_w + out_x;
let col_base_idx =
offset_group * channels_per_offset_group * kernel_h * kernel_w * columns.stride(0)
+ col_n * columns.stride(1);
let mut image_base_idx =
batch * image.stride(0) + offset_group * channels_per_offset_group * image.stride(1);
let offset_pos_1 =
offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;
let offset_base_idx = batch * offset.stride(0)
+ offset_pos_1 * offset.stride(1)
+ out_y * offset.stride(2)
+ out_x * offset.stride(3);
let offset_y_idx = offset_base_idx;
let offset_x_idx = offset_base_idx + offset.stride(1);
let offset_y = offset[offset_y_idx];
let offset_x = offset[offset_x_idx];
let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;
let mask_value = match &mask {
Some(mask) => {
let mask_idx = batch * mask.stride(0)
+ mask_pos_1 * mask.stride(1)
+ out_y * mask.stride(2)
+ out_x * mask.stride(3);
mask[mask_idx]
}
None => F::new(1.0),
};
let is_y_direction = dir == 0;
for col_c in 0..channels_per_offset_group {
let col_pos = col_base_idx + col_c * kernel_h * kernel_w * columns.stride(0);
let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)
- args.pad_h.get::<F>()
+ offset_y;
let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)
- args.pad_w.get::<F>()
+ offset_x;
let weight =
get_coordinate_weight(image, image_base_idx, height, width, y, x, is_y_direction);
let columns_value = columns[col_pos];
grad_offset_val += mask_value * weight * columns_value;
if grad_mask.is_some() && is_y_direction {
grad_mask_val +=
columns_value * bilinear_interpolate(image, height, width, y, x, image_base_idx);
}
image_base_idx += image.stride(1);
}
grad_offset[ABSOLUTE_POS] = grad_offset_val;
if let Some(grad_mask) = grad_mask {
if is_y_direction {
let idx = batch * grad_mask.stride(0)
+ mask_pos_1 * grad_mask.stride(1)
+ out_y * grad_mask.stride(2)
+ out_x * grad_mask.stride(3);
grad_mask[idx] = grad_mask_val
}
}
}
#[cube]
fn get_coordinate_weight<F: Float>(
input: &Tensor<F>,
offset: usize,
height: usize,
width: usize,
y: F,
x: F,
is_y_direction: bool,
) -> F {
let stride_y = input.stride(2);
let stride_x = input.stride(3);
let y = f32::cast_from(y);
let x = f32::cast_from(x);
let y_low = f32::floor(y);
let x_low = f32::floor(x);
let y_high = y_low + 1.;
let x_high = x_low + 1.;
let valid_y_low = y_low >= 0. && y_low < height as f32;
let valid_y_high = y_high >= 0. && y_high < height as f32;
let valid_x_low = x_low >= 0. && x_low < width as f32;
let valid_x_high = x_high >= 0. && x_high < width as f32;
let bottom_left = if valid_y_low && valid_x_low {
input[offset + y_low as usize * stride_y + x_low as usize * stride_x]
} else {
F::new(0.0)
};
let bottom_right = if valid_y_low && valid_x_high {
input[offset + y_low as usize * stride_y + x_high as usize * stride_x]
} else {
F::new(0.0)
};
let top_left = if valid_y_high && valid_x_low {
input[offset + y_high as usize * stride_y + x_low as usize * stride_x]
} else {
F::new(0.0)
};
let top_right = if valid_y_high && valid_x_high {
input[offset + y_high as usize * stride_y + x_high as usize * stride_x]
} else {
F::new(0.0)
};
if is_y_direction {
let delta_x = F::cast_from(x - x_low);
delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left)
} else {
let delta_y = F::cast_from(y - y_low);
delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left)
}
}
fn compute_input_grad<R: CubeRuntime>(
columns: CubeTensor<R>,
offset: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
input_shape: Shape,
) -> Result<CubeTensor<R>, LaunchError> {
let client = offset.client.clone();
let device = offset.device.clone();
let supports_fadd = client
.properties()
.type_usage(StorageType::Atomic(FloatKind::F32.into()))
.contains(TypeUsage::AtomicAdd);
let supports_same_type = client
.properties()
.type_usage(StorageType::Atomic(columns.dtype.into()))
.contains(TypeUsage::AtomicAdd);
let [batches, in_channels, height, width] = input_shape.dims();
let [_, _, out_h, out_w] = offset.meta.shape().dims();
let (kernel_h, kernel_w) = kernel_dims;
let pos_shape = [in_channels, kernel_h, kernel_w, batches, out_h, out_w];
let pos_shape = pos_shape
.into_iter()
.map(|s| FastDivmodArgs::new(&client, s))
.collect();
let shape = Shape::new([batches, in_channels, height, width]);
let grad_in = match supports_fadd && supports_same_type {
// Use type as is to save a cast
true => zeros_client(client.clone(), device.clone(), shape, columns.dtype),
// Force `f32` to enable bitcasting as `u32`, or use intrinsic when supported
false => zeros_client(client.clone(), device.clone(), shape, DType::F32),
};
let grad_arg = grad_in.as_tensor_arg(1);
let num_elements = columns.meta.num_elements();
let cube_dim = CubeDim::new(&offset.client, num_elements);
let cube_count = calculate_cube_count_elemwise(&offset.client, num_elements, cube_dim);
let launch = match supports_fadd {
true => deform_col2img_kernel::launch_unchecked::<IntrinsicFloatAtomicAddFamily, R>,
false => deform_col2img_kernel::launch_unchecked::<CASFloatAtomicAdd, R>,
};
let dtype = offset.dtype;
let dtypes: [StorageType; 2] = match supports_same_type {
true => [dtype.into(), dtype.into()],
false => [dtype.into(), DType::F32.into()],
};
unsafe {
launch(
&offset.client,
cube_count,
cube_dim,
address_type!(offset, mask, columns, grad_in),
offset.as_tensor_arg(1),
mask.as_ref().map(|mask| mask.as_tensor_arg(1)).into(),
linear_view(&columns, 1),
grad_arg,
pos_shape,
DeformConv2dCol2ImgArgsLaunch::new(
ScalarArg::new(options.stride[0]),
ScalarArg::new(options.stride[1]),
ScalarArg::new(options.dilation[0]),
ScalarArg::new(options.dilation[1]),
InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()),
InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()),
ScalarArg::new(options.offset_groups),
ScalarArg::new(kernel_h),
ScalarArg::new(kernel_w),
),
dtypes,
)
}?;
Ok(if !supports_same_type || !supports_fadd {
cast(grad_in, dtype)
} else {
grad_in
})
}
#[derive(CubeLaunch, CubeType)]
struct DeformConv2dCol2ImgArgs {
stride_h: usize,
stride_w: usize,
dilation_h: usize,
dilation_w: usize,
pad_h: InputScalar,
pad_w: InputScalar,
offset_groups: usize,
kernel_height: usize,
kernel_width: usize,
}
#[cube(launch_unchecked, address_type = "dynamic")]
fn deform_col2img_kernel<F: Float, FP: Float, FAdd: FloatAtomicAddFamily>(
offset: &Tensor<F>,
mask: &Option<Tensor<F>>,
columns: &LinearView<F>,
grad_input: &mut Tensor<Atomic<ProxyType<FAdd, FP>>>,
pos_shape: Sequence<FastDivmod<usize>>,
args: &DeformConv2dCol2ImgArgs,
#[define(F, FP)] _dtype: [StorageType; 2],
) {
// Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
if ABSOLUTE_POS >= columns.shape() {
terminate!();
}
let n_in_channels = grad_input.shape(1);
let height = grad_input.shape(2);
let width = grad_input.shape(3);
let kernel_h = args.kernel_height;
let kernel_w = args.kernel_width;
let n_offset_groups = args.offset_groups;
let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);
let [in_channel, kernel_y, kernel_x, batch, out_y, out_x] = *pos else {
unreachable!()
};
let channels_per_offset_group = n_in_channels / n_offset_groups;
let offset_group = in_channel / channels_per_offset_group;
let offset_pos_1 =
offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;
let offset_base_idx = batch * offset.stride(0)
+ offset_pos_1 * offset.stride(1)
+ out_y * offset.stride(2)
+ out_x * offset.stride(3);
let offset_y_idx = offset_base_idx;
let offset_x_idx = offset_base_idx + offset.stride(1);
let offset_y = offset[offset_y_idx];
let offset_x = offset[offset_x_idx];
let mask_value = match mask {
Some(mask) => {
let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;
mask[batch * mask.stride(0)
+ mask_pos_1 * mask.stride(1)
+ out_y * mask.stride(2)
+ out_x * mask.stride(3)]
}
None => F::new(1.0),
};
let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)
- args.pad_h.get::<F>()
+ offset_y;
let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)
- args.pad_w.get::<F>()
+ offset_x;
for dy in -1..=1i32 {
#[unroll]
for dx in -1..=1i32 {
let yp = y.floor() + F::cast_from(dy);
let xp = x.floor() + F::cast_from(dx);
if yp >= F::new(0.0)
&& yp < F::cast_from(height)
&& xp >= F::new(0.0)
&& xp < F::cast_from(width)
&& F::abs(y - yp) < F::new(1.0)
&& F::abs(x - xp) < F::new(1.0)
{
let gradient_pos = batch * grad_input.stride(0)
+ in_channel * grad_input.stride(1)
+ usize::cast_from(yp) * grad_input.stride(2)
+ usize::cast_from(xp) * grad_input.stride(3);
let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp));
let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];
FAdd::Op::<FP>::float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
}
}
}
}
type ProxyType<FADF, FP> = <<FADF as FloatAtomicAddFamily>::Op<FP> as FloatAtomicAdd>::ProxyType;
#[cube]
trait FloatAtomicAddFamily: Send + Sync + 'static {
type Op<ProxyType: Float>: FloatAtomicAdd;
}
#[cube]
trait FloatAtomicAdd: Send + Sync + 'static {
type ProxyType: Numeric;
fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F);
}
#[derive(CubeType)]
struct IntrinsicFloatAtomicAdd<F: Float> {
#[cube(comptime)]
_ty: PhantomData<F>,
}
#[derive(CubeType)]
struct CASFloatAtomicAdd;
struct IntrinsicFloatAtomicAddFamily;
impl FloatAtomicAddFamily for IntrinsicFloatAtomicAddFamily {
type Op<ProxyType: Float> = IntrinsicFloatAtomicAdd<ProxyType>;
}
impl FloatAtomicAddFamily for CASFloatAtomicAdd {
type Op<ProxyType: Float> = Self;
}
#[cube]
impl<FAdd: Float> FloatAtomicAdd for IntrinsicFloatAtomicAdd<FAdd> {
type ProxyType = FAdd;
fn float_atomic_add<F: Float>(ptr: &mut Atomic<FAdd>, value: F) {
let value = FAdd::cast_from(value);
ptr.fetch_add(value);
}
}
#[cube]
impl FloatAtomicAdd for CASFloatAtomicAdd {
type ProxyType = u32;
fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F) {
let value = f32::cast_from(value);
if value != 0.0 {
let mut v = ptr.load();
loop {
let prev = v;
let v_float = f32::from_bits(v);
let new = (v_float + value).to_bits();
v = ptr.compare_exchange_weak(v, new);
if prev == v {
break;
}
}
}
}
}

View File

@@ -0,0 +1,320 @@
use crate::{
CubeRuntime,
kernel::{
into_contiguous_aligned,
utils::{address_type, linear_view},
},
ops::max_line_size,
tensor::CubeTensor,
};
use crate::{kernel::utils::decompose_linear, ops::numeric::empty_device_dtype};
use burn_backend::{
TensorMetadata,
ops::{ConvOptions, conv::calculate_conv_output_sizes},
};
use cubecl::std::{FastDivmod, FastDivmodArgs};
use cubecl::{
calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,
tensor_line_size_parallel,
};
use cubek::convolution::components::ConvSetupError;
#[derive(CubeLaunch, CubeType, Clone)]
pub(crate) struct ConvParam {
pub stride: u32,
pub dilation: u32,
pub padding: i32,
}
#[derive(CubeLaunch, CubeType)]
struct Conv2dArgs {
conv_params: Sequence<ConvParam>,
channels_per_group: u32,
}
#[cube(launch_unchecked, address_type = "dynamic")]
fn direct_conv2d_kernel<E: Numeric>(
input: &Tensor<Line<E>>,
weight: &Tensor<Line<E>>,
bias: Option<Tensor<Line<E>>>,
output: &mut LinearView<Line<E>, ReadWrite>,
args: Conv2dArgs,
shape_out: Sequence<FastDivmod<u32>>,
shape_out_c: FastDivmod<u32>,
#[comptime] has_padding: bool,
#[define(E)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let n_spatial = comptime![shape_out.len()];
let line_size_out = output.line_size();
let pos = ABSOLUTE_POS * line_size_out;
let in_c_per_group = weight.shape(weight.rank() - 1) as u32;
let (rem, out_c) = shape_out_c.div_mod(pos as u32);
let (b, spatial_pos) = decompose_linear(rem, &shape_out);
let g = out_c / args.channels_per_group;
let ic_start = in_c_per_group * g;
let bias: Option<Line<E>> = bias.map(|bias| bias[out_c as usize / line_size_out]);
let mut sum = bias.unwrap_or_else(|| Line::empty(line_size_out).fill(E::from_int(0)));
let in_offs = b as usize * input.stride(0) + ic_start as usize;
let stride_oc = weight.stride(0);
let mut in_shape = Sequence::new();
let mut in_strides = Sequence::new();
let mut kernel_shape = Sequence::new();
let mut kernel_strides = Sequence::new();
#[unroll]
for i in 0..n_spatial {
in_shape.push(input.shape(i + 1) as u32);
in_strides.push(input.stride(i + 1));
kernel_shape.push(weight.shape(i + 1) as u32);
kernel_strides.push(weight.stride(i + 1));
}
let weight_offs = out_c as usize * stride_oc;
let loop_params = LoopParams {
out_pos: spatial_pos,
in_shape,
in_strides,
kernel_shape,
kernel_strides,
conv_params: args.conv_params,
in_c_per_group,
stride_oc,
};
kernel_loop(
input,
weight,
&mut sum,
in_offs,
true,
weight_offs,
&loop_params,
0usize,
has_padding,
);
output[ABSOLUTE_POS] = sum;
}
#[derive(CubeType, Clone)]
struct LoopParams {
out_pos: Sequence<u32>,
in_shape: Sequence<u32>,
in_strides: Sequence<usize>,
kernel_shape: Sequence<u32>,
kernel_strides: Sequence<usize>,
conv_params: Sequence<ConvParam>,
in_c_per_group: u32,
stride_oc: usize,
}
#[cube]
fn kernel_loop<E: Numeric>(
input: &Tensor<Line<E>>,
weight: &Tensor<Line<E>>,
sum: &mut Line<E>,
in_offs: usize,
in_bounds: bool,
weight_offs: usize,
params: &LoopParams,
#[comptime] kernel_dim: usize,
#[comptime] has_padding: bool,
) {
if comptime![kernel_dim < params.kernel_shape.len()] {
let out_idx = *params.out_pos.index(kernel_dim);
let conv = params.conv_params.index(kernel_dim);
let shape = *params.in_shape.index(kernel_dim);
let stride = *params.in_strides.index(kernel_dim);
let k_stride = *params.kernel_strides.index(kernel_dim);
for pos in 0..*params.kernel_shape.index(kernel_dim) {
let in_pos = (out_idx * conv.stride + pos * conv.dilation) as i32 - conv.padding;
let in_offs = in_offs + in_pos as usize * stride;
let weight_offs = weight_offs + pos as usize * k_stride;
let mut in_bounds = in_bounds;
if has_padding {
in_bounds &= in_pos >= 0 && (in_pos as u32) < shape;
}
kernel_loop(
input,
weight,
sum,
in_offs,
in_bounds,
weight_offs,
params,
comptime![kernel_dim + 1],
has_padding,
);
}
} else {
kernel_loop_inner(
input,
weight,
sum,
in_offs,
in_bounds,
weight_offs,
params.in_c_per_group,
params.stride_oc,
);
}
}
#[cube]
fn kernel_loop_inner<E: Numeric>(
input: &Tensor<Line<E>>,
weight: &Tensor<Line<E>>,
sum: &mut Line<E>,
in_offs: usize,
in_bounds: bool,
weight_offs: usize,
in_c_per_group: u32,
stride_oc: usize,
) {
let line_size_in = input.line_size();
let line_size_out = sum.size();
if in_bounds {
for in_c in range_stepped(0, in_c_per_group, line_size_in as u32) {
let in_pos = in_offs + in_c as usize;
let mut weight_pos = weight_offs + in_c as usize;
let val = input[in_pos / line_size_in];
#[unroll]
for v in 0..line_size_out {
let weight = weight[weight_pos / line_size_in];
let val = val * weight;
#[unroll]
for i in 0..line_size_in {
sum[v] += val[i];
}
weight_pos += stride_oc;
}
}
}
}
/// Perform a 2D convolution using the direct convolution algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
pub fn conv_direct<R: CubeRuntime, const N: usize>(
mut input: CubeTensor<R>,
mut weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
let client = input.client.clone();
let out_dtype = input.dtype;
let rank = input.meta.shape().num_dims();
let dim_c = rank - 1;
// We only care about the channels here, everything else can be permuted
if input.meta.strides()[dim_c] != 1 {
input = into_contiguous_aligned(input);
}
if weight.meta.strides()[dim_c] != 1 {
weight = into_contiguous_aligned(weight);
}
let batch_size = input.meta.shape()[0];
let in_shape = &input.meta.shape()[1..dim_c];
let out_channels = weight.meta.shape()[0];
let kernel_shape = &weight.meta.shape()[1..dim_c];
let channels_per_group = out_channels / options.groups;
let out_size = calculate_conv_output_sizes(
kernel_shape,
&options.stride,
&options.padding,
&options.dilation,
in_shape,
);
let mut shape_out = vec![batch_size];
shape_out.extend(out_size.iter().copied());
shape_out.push(out_channels);
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
shape_out.into(),
out_dtype,
);
// Need custom line size calculation here to account for the groups division. Need to vectorize
// over `channels_per_group` instead.
let mut grouped_out_shape = output.shape();
grouped_out_shape[dim_c] = channels_per_group;
let line_size_out = tensor_line_size_parallel(
input.client.io_optimized_line_sizes(input.dtype.size()),
&grouped_out_shape,
output.meta.strides(),
dim_c,
);
// Use channels_per_group instead of in_channels to avoid issues here
let line_size_in = max_line_size(&weight);
let shape_out = output.meta.shape()[1..dim_c]
.iter()
.map(|s| FastDivmodArgs::<u32>::new(&client, *s as u32))
.collect();
let shape_out_c = FastDivmodArgs::<u32>::new(&client, out_channels as u32);
let mut conv_params = SequenceArg::new();
for i in 0..kernel_shape.len() {
conv_params.push(ConvParamLaunch::new(
ScalarArg::new(options.stride[i] as u32),
ScalarArg::new(options.dilation[i] as u32),
ScalarArg::new(options.padding[i] as i32),
));
}
let working_units = output.meta.num_elements() / line_size_out;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
unsafe {
direct_conv2d_kernel::launch_unchecked(
&input.client,
cube_count,
cube_dim,
address_type!(input, weight, bias, output),
input.as_tensor_arg(line_size_in),
weight.as_tensor_arg(line_size_in),
bias.as_ref().map(|b| b.as_tensor_arg(line_size_out)).into(),
linear_view(&output, line_size_out),
Conv2dArgsLaunch::new(conv_params, ScalarArg::new(channels_per_group as u32)),
shape_out,
shape_out_c,
options.padding.iter().any(|it| *it != 0),
out_dtype.into(),
)
}?;
Ok(output)
}

View File

@@ -0,0 +1,167 @@
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes};
use cubek::{
convolution::{
AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
components::ConvSetupError, forward,
},
matmul::{
definition::{MatmulElems, MatmulGlobalElems},
launch::MatmulInputHandleRef,
},
};
/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul
/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
pub fn conv_gemm_simple_sync<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
let read_strategy = match tile_kind {
AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,
AcceleratedTileKind::Mma => ReadingStrategy::Strided,
};
launch_convolution_forward::<R, N>(
&Strategy::Simple {
read_strategy,
tile_kind,
},
input,
weight,
bias,
options,
)
}
pub fn conv_gemm_simple_async<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
let read_strategy = match tile_kind {
AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,
AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,
};
launch_convolution_forward::<R, N>(
&Strategy::Simple {
read_strategy,
tile_kind,
},
input,
weight,
bias,
options,
)
}
/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul
/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
pub fn conv_gemm_simple_tma<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
launch_convolution_forward::<R, N>(
&Strategy::Simple {
read_strategy: ReadingStrategy::Tma,
tile_kind,
},
input,
weight,
bias,
options,
)
}
/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul
/// components, using the specified algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
pub fn launch_convolution_forward<R: CubeRuntime, const N: usize>(
strategy: &Strategy,
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
if options.groups != 1 {
return Err(ConvSetupError::Groups(options.groups));
}
let out_dtype = input.dtype;
let rank = input.meta.shape().num_dims();
let batch_size = input.meta.shape()[0];
let dim_c = rank - 1;
let shape = &input.meta.shape()[1..dim_c];
let out_channels = weight.meta.shape()[0];
let weight_shape = &weight.meta.shape()[1..dim_c];
let mut out_shape = calculate_conv_output_sizes(
weight_shape,
&options.stride,
&options.padding,
&options.dilation,
shape,
);
out_shape.insert(0, batch_size);
out_shape.push(out_channels);
let out = empty_device_dtype(
input.client.clone(),
input.device.clone(),
out_shape.into(),
out_dtype,
);
let bias = bias
.as_ref()
.map(|bias| MatmulInputHandleRef::Normal(bias.as_handle_ref(), bias.dtype.into()));
let client = input.client.clone();
let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
lhs: input.dtype.into(),
rhs: weight.dtype.into(),
out: out_dtype.into(),
});
let input = MatmulInputHandleRef::new(input.as_handle_ref(), input.dtype.into());
let weight = MatmulInputHandleRef::new(weight.as_handle_ref(), weight.dtype.into());
forward::launch_ref::<R, N>(
strategy,
&client,
&input,
&weight,
&bias,
&out.as_handle_ref(),
ConvolutionArgs {
stride: options.stride,
padding: options.padding,
dilation: options.dilation,
},
dtypes,
)?;
Ok(out)
}

View File

@@ -0,0 +1,2 @@
pub mod launch;
pub use launch::*;

View File

@@ -0,0 +1,7 @@
pub mod implicit_gemm;
#[cfg(feature = "autotune")]
pub mod tune;
#[cfg(feature = "autotune")]
pub(crate) use tune::*;

View File

@@ -0,0 +1,174 @@
use burn_backend::ops::ConvOptions;
use cubecl::{
ir::StorageType,
tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},
};
use cubek::convolution::AcceleratedTileKind;
use crate::{
CubeAutotuneKey, CubeRuntime, CubeTuneId,
kernel::conv::{ConvAutotuneKey, conv_direct, conv_im2col_1x1, forward::implicit_gemm::*},
tensor::CubeTensor,
};
/// Executes autotune on convolution operations
pub fn conv_autotune<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
) -> CubeTensor<R> {
let client = input.client.clone();
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
TunableSet::new(create_key::<R, N>, create_conv_input::<R, N>)
.with(Tunable::new("conv_direct", conv_direct::<R, N>))
.with(Tunable::new("conv_im2col_1x1", conv_im2col_1x1::<R, N>))
.with(Tunable::new(
"simple_sync_cmma",
|input, weight, bias, options| {
conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_sync_mma",
|input, weight, bias, options| {
conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Mma)
},
))
.with(Tunable::new(
"simple_async_cmma",
|input, weight, bias, options| {
conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_async_mma",
|input, weight, bias, options| {
conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Mma)
},
))
.with(Tunable::new(
"simple_tma_cmma",
|input, weight, bias, options| {
conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Cmma)
},
))
.with(Tunable::new(
"simple_tma_mma",
|input, weight, bias, options| {
conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Mma)
},
))
});
TUNER.execute(
&CubeTuneId::new(&input.client, &input.device),
&client,
tunables,
(input, weight, bias, options),
)
}
pub fn create_conv_input<R: CubeRuntime, const N: usize>(
_key: &CubeAutotuneKey,
input: &CubeTensor<R>,
weights: &CubeTensor<R>,
bias: &Option<CubeTensor<R>>,
options: &ConvOptions<N>,
) -> (
CubeTensor<R>,
CubeTensor<R>,
Option<CubeTensor<R>>,
ConvOptions<N>,
) {
(
input.clone(),
weights.clone(),
bias.clone(),
options.clone(),
)
}
fn create_key<R: CubeRuntime, const N: usize>(
input: &CubeTensor<R>,
weights: &CubeTensor<R>,
bias: &Option<CubeTensor<R>>,
options: &ConvOptions<N>,
) -> CubeAutotuneKey {
let dtype = input.dtype;
let rank = input.meta.shape().num_dims();
let dim_c = rank - 1;
let batch_size = input.meta.shape()[0];
let in_channels = input.meta.shape()[dim_c];
let out_channels = weights.meta.shape()[0];
let kernel_size = weights.meta.shape()[1..dim_c].to_vec();
let in_shape = input.meta.shape()[1..dim_c]
.iter()
.map(|shape| anchor(*shape, None, None, None))
.collect();
let ConvOptions {
stride,
padding,
dilation,
groups,
} = options.clone();
let lhs_stride_align = if input.meta.strides()[dim_c] == 1 {
stride_align(input.meta.strides(), input.dtype.into())
} else {
0
};
let lhs_shape_align = pow2_factor(in_channels).min(lhs_stride_align);
let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 {
stride_align(weights.meta.strides(), weights.dtype.into())
} else {
0
};
let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);
CubeAutotuneKey::Conv(ConvAutotuneKey::new(
kernel_size,
stride.to_vec(),
padding.to_vec(),
dilation.to_vec(),
groups,
in_channels,
out_channels,
in_shape,
batch_size,
bias.is_some(),
dtype,
lhs_shape_align,
lhs_stride_align,
rhs_shape_align,
rhs_stride_align,
))
}
/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's
/// repeat number, so it's the largest align that can have performance impacts.
const MAX_STRIDE_FACTOR: u32 = 10;
/// Defines the non-contiguous stride alignment in terms of powers of two
fn stride_align(strides: &[usize], elem: StorageType) -> u8 {
let max = MAX_STRIDE_FACTOR;
let dim_c = strides.len() - 1;
let factor = strides[..dim_c]
.iter()
.map(|it| (*it * elem.size_bits()) / 8)
.map(|it| it.trailing_zeros())
.min()
.unwrap_or(max);
factor.min(max) as u8
}
/// Defines the potential vectorization.
fn pow2_factor(axis: usize) -> u8 {
axis.trailing_zeros().min(4) as u8
}

View File

@@ -0,0 +1,187 @@
use burn_backend::{
DType,
ops::{ConvOptions, conv::calculate_conv_output_sizes},
};
use burn_std::{Metadata, Shape};
use core::iter;
use cubecl::{
prelude::*,
std::tensor::{TensorHandle, into_contiguous_pitched_ref},
};
use cubek::convolution::components::ConvSetupError;
use crate::{
CubeRuntime,
kernel::{
AddOp, into_contiguous_aligned, launch_binop,
matmul::{MatmulStrategy, matmul},
utils::split_dim,
},
ops::{reshape, swap_dims},
tensor::CubeTensor,
};
#[cfg(not(test))]
pub(crate) fn batches_per_run(
batch_size: usize,
out_shape: usize,
plane_size: usize,
) -> Result<usize, ConvSetupError> {
use cubek::matmul::definition::MatmulAvailabilityError;
let cube_count_per_batch = out_shape.div_ceil(plane_size);
let max_cube_count = u16::MAX as usize;
let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size);
if max_simultaneous == 0 {
return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(
cube_count_per_batch as u32,
1,
1,
))
.into());
}
Ok((0..=max_simultaneous)
.rev()
.find(|per_run| batch_size.is_multiple_of(*per_run))
.expect("Logically not possible"))
}
#[cfg(test)]
#[allow(unused)]
pub(crate) fn batches_per_run(
batch_size: usize,
out_shape: usize,
plane_size: usize,
) -> Result<usize, ConvSetupError> {
Ok(1)
}
pub fn conv_im2col_1x1<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
mut weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
if options.groups != 1 {
return Err(ConvSetupError::Groups(options.groups));
}
let rank = input.meta.num_dims();
let dim_c = rank - 1;
let batch_size = input.meta.shape()[0];
let in_channels = input.meta.shape()[dim_c];
let in_shape = &input.meta.shape()[1..dim_c];
let out_channels = weight.meta.shape()[0];
let kernel_shape = &weight.meta.shape()[1..dim_c];
if kernel_shape.iter().any(|s| *s != 1) {
return Err(ConvSetupError::Unknown);
}
let out_shape = calculate_conv_output_sizes(
kernel_shape,
&options.stride,
&options.padding,
&options.dilation,
in_shape,
);
let mut split_m = vec![batch_size];
split_m.extend(out_shape.iter().copied());
if kernel_shape.iter().any(|it| *it != 1) || in_shape != out_shape {
return Err(ConvSetupError::Unknown);
}
let input = reshape_input(input); // [(NHW), C] : [M, K]
let dtype = input.dtype;
// Efficient permutation that takes the stride required for TMA into account
let weight = if weight.meta.strides()[dim_c] != 1 {
// Remove kernel dims so padded dim is channels
*weight.meta = Metadata::new(
[out_channels, in_channels], // [N, K]
[weight.meta.strides()[0], weight.meta.strides()[dim_c]],
);
// Pitched contiguous to skip running another kernel for TMA
into_contiguous_aligned(weight)
} else {
// Already compatible, skip initial reshape
*weight.meta = Metadata::new([out_channels, in_channels], [weight.meta.strides()[0], 1]);
weight
};
// Permute to N-major, while keeping memory layout K-major. K-major for both sides is the most
// efficient for matmul, and allows skipping a contiguous kernel
let weight = swap_dims(weight, 0, 1); // [K, N]
let out = matmul(input, weight, None, MatmulStrategy::default(), dtype)?; // [M, N]
// Skip reshape to avoid potential `into_contiguous`. We're only splitting dims so it's safe.
let mut out = split_dim(out, 0, &split_m); // [N, H, W, C]
if let Some(bias) = bias {
let mut bias_shape = iter::repeat_n(1, rank - 1).collect::<Vec<_>>();
bias_shape.push(out_channels);
let bias = reshape(bias, bias_shape.into());
out = launch_binop::<R, AddOp>(out, bias);
}
Ok(out)
}
/// Reshapes NHWC input to [(N, H, W), C]
fn reshape_input<R: CubeRuntime>(mut input: CubeTensor<R>) -> CubeTensor<R> {
let rank = input.meta.num_dims();
let dim_c = rank - 1;
let dtype = input.dtype;
let batch_size = input.meta.shape()[0];
let in_c: usize = input.meta.shape()[dim_c];
let in_shape: Shape = input.meta.shape()[1..dim_c].into();
if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) {
let contiguous =
into_contiguous_pitched_ref(&input.client, &input.as_handle_ref(), dtype.into())
.expect("Kernel to never fail");
input = from_handle(&input.client, &input.device, contiguous, dtype);
}
*input.meta = Metadata::new(
[batch_size * in_shape.num_elements(), in_c], // [M, K]
[input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]],
);
input
}
fn is_spatial_contiguous(shape: &[usize], strides: &[usize]) -> bool {
let rank = shape.len();
let mut ordered = strides.to_vec();
ordered.sort();
if ordered != strides {
return false;
}
for i in (1..rank - 2).rev() {
if strides[i + 1] * shape[i + 1] != strides[i] {
return false;
}
}
true
}
fn from_handle<R: CubeRuntime>(
client: &ComputeClient<R>,
device: &R::Device,
handle: TensorHandle<R>,
dtype: DType,
) -> CubeTensor<R> {
CubeTensor::new(
client.clone(),
handle.handle,
*handle.metadata,
device.clone(),
dtype,
)
}

View File

@@ -0,0 +1,25 @@
mod backward_data;
mod backward_weight;
mod base;
mod conv_transpose2d;
mod conv_transpose3d;
mod deform_conv2d;
mod deform_conv_transpose2d;
mod direct;
mod forward;
mod im2col;
mod tune_key;
pub(crate) use backward_data::*;
pub(crate) use conv_transpose2d::*;
pub(crate) use conv_transpose3d::*;
pub(crate) use deform_conv_transpose2d::*;
pub(crate) use deform_conv2d::*;
pub(crate) use direct::*;
pub(crate) use im2col::*;
pub use base::*;
pub use conv_transpose2d::{ConvTranspose2dStrategy, conv_transpose2d};
pub(crate) use tune_key::*;

View File

@@ -0,0 +1,50 @@
use burn_backend::DType;
use cubecl::AutotuneKey;
use serde::{Deserialize, Serialize};
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct ConvAutotuneKey {
pub kernel_size: Vec<usize>,
pub stride: Vec<usize>,
pub padding: Vec<usize>,
pub dilation: Vec<usize>,
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
pub shape: Vec<usize>,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
pub dtype: DType,
pub lhs_shape_align: u8,
pub lhs_stride_align: u8,
pub rhs_shape_align: u8,
pub rhs_stride_align: u8,
}
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct ConvTranspose2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub padding_out: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
pub dtype: DType,
}

View File

@@ -0,0 +1,101 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, broadcast_shape, linear_view, linear_view_ref},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
#[cube(launch_unchecked, address_type = "dynamic")]
fn cross_kernel<E: Float>(
lhs: &LinearView<Line<E>>,
rhs: &LinearView<Line<E>>,
output: &mut LinearView<Line<E>, ReadWrite>,
#[define(E)] _dtype: StorageType,
) {
// Each thread processes one 3-element vector
let vector_idx = ABSOLUTE_POS;
let base_pos = vector_idx * 3;
if !output.is_in_bounds(base_pos) {
terminate!();
}
// Extract vectors
let a0 = lhs[base_pos];
let a1 = lhs[base_pos + 1];
let a2 = lhs[base_pos + 2];
let b0 = rhs[base_pos];
let b1 = rhs[base_pos + 1];
let b2 = rhs[base_pos + 2];
// Compute cross product: a × b
let x = a1 * b2 - a2 * b1;
let y = a2 * b0 - a0 * b2;
let z = a0 * b1 - a1 * b0;
// Store result
output[base_pos] = x;
output[base_pos + 1] = y;
output[base_pos + 2] = z;
}
pub(crate) fn cross<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
dim: usize,
) -> CubeTensor<R> {
let ndims = lhs.meta.num_dims();
// Validate that the cross dimension has size 3
if lhs.meta.shape()[dim] != 3 || rhs.meta.shape()[dim] != 3 {
panic!(
"Cross product requires dimension {} to have size 3, but got {} and {}",
dim,
lhs.meta.shape()[dim],
rhs.meta.shape()[dim]
);
}
// For now, only support cross on the last dimension
if dim != ndims - 1 {
unimplemented!(
"Cross product on non-last dimension not yet implemented for CubeCL backend"
);
}
let output_shape = broadcast_shape(&[&lhs, &rhs]);
// Since the cross dimension is forced to be size 3, line size would be restricted to 1 anyway
let line_size = 1;
let output = empty_device_dtype(
lhs.client.clone(),
lhs.device.clone(),
output_shape.clone(),
lhs.dtype,
);
// Number of vectors to process
let num_vectors = output_shape.num_elements() / 3;
let cube_dim = CubeDim::new(&lhs.client, num_vectors);
let cube_count = calculate_cube_count_elemwise(&lhs.client, num_vectors, cube_dim);
unsafe {
cross_kernel::launch_unchecked(
&lhs.client,
cube_count,
cube_dim,
address_type!(lhs, rhs, output),
linear_view_ref(&lhs, &output, line_size),
linear_view_ref(&rhs, &output, line_size),
linear_view(&output, line_size),
lhs.dtype.into(),
)
.expect("Kernel to never fail");
};
output
}

View File

@@ -0,0 +1,163 @@
use cubecl::prelude::*;
use crate::{CubeRuntime, tensor::CubeTensor};
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
use super::bilinear::grid_sample_bilinear_launch;
/// Grid sample operation supporting bilinear interpolation
pub fn grid_sample<R: CubeRuntime>(
input: CubeTensor<R>,
grid: CubeTensor<R>,
options: GridSampleOptions,
) -> CubeTensor<R> {
match options.mode {
InterpolateMode::Bilinear => grid_sample_bilinear_launch(input, grid, options),
_ => panic!(
"Unsupported grid_sample interpolation mode: {:?}",
options.mode
),
}
}
/// Compile-time padding mode for kernel specialization
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PaddingMode {
/// Fill with zeros for out-of-bounds coordinates.
Zeros,
/// Clamp coordinates to the border (use nearest edge value).
Border,
/// Reflect coordinates at the boundary.
Reflection,
}
impl From<GridSamplePaddingMode> for PaddingMode {
fn from(mode: GridSamplePaddingMode) -> Self {
match mode {
GridSamplePaddingMode::Zeros => PaddingMode::Zeros,
GridSamplePaddingMode::Border => PaddingMode::Border,
GridSamplePaddingMode::Reflection => PaddingMode::Reflection,
}
}
}
/// Fetch value based on padding mode (dispatch to appropriate handler)
#[cube]
pub(crate) fn fetch_value<F: Float>(
input: &Tensor<F>,
base: usize,
stride_h: usize,
stride_w: usize,
y: i32,
x: i32,
h: i32,
w: i32,
#[comptime] padding_mode: PaddingMode,
) -> F {
match padding_mode {
PaddingMode::Zeros => fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w),
PaddingMode::Border => fetch_with_border(input, base, stride_h, stride_w, y, x, h, w),
PaddingMode::Reflection => {
fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w)
}
}
}
/// Fetch value with zeros padding (return 0 for out-of-bounds).
#[cube]
pub(crate) fn fetch_with_zeros<F: Float>(
input: &Tensor<F>,
base: usize,
stride_h: usize,
stride_w: usize,
y: i32,
x: i32,
h: i32,
w: i32,
) -> F {
let in_bounds = x >= 0 && x < w && y >= 0 && y < h;
let x_clamped = clamp(x, 0, w - 1) as usize;
let y_clamped = clamp(y, 0, h - 1) as usize;
let idx = base + y_clamped * stride_h + x_clamped * stride_w;
select(in_bounds, input[idx], F::new(0.0))
}
/// Fetch value with border padding (clamp to edge).
#[cube]
pub(crate) fn fetch_with_border<F: Float>(
input: &Tensor<F>,
base: usize,
stride_h: usize,
stride_w: usize,
y: i32,
x: i32,
h: i32,
w: i32,
) -> F {
let x_clamped = clamp(x, 0, w - 1) as usize;
let y_clamped = clamp(y, 0, h - 1) as usize;
let idx = base + y_clamped * stride_h + x_clamped * stride_w;
input[idx]
}
/// Fetch value with reflection padding.
/// Assumes float reflection was applied to center, so indices are at most 2 steps out of bounds.
#[cube]
pub(crate) fn fetch_with_reflection<F: Float>(
input: &Tensor<F>,
base: usize,
stride_h: usize,
stride_w: usize,
y: i32,
x: i32,
h: i32,
w: i32,
) -> F {
let x_reflected = reflect_coord_bounded(x, w);
let y_reflected = reflect_coord_bounded(y, h);
let idx = base + y_reflected * stride_h + x_reflected * stride_w;
input[idx]
}
/// Reflect an integer index that may be out of bounds.
/// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear).
#[cube]
fn reflect_coord_bounded(idx: i32, size: i32) -> usize {
let max_idx = size - 1;
let neg_reflected = -idx - 1;
let pos_reflected = 2 * max_idx + 1 - idx;
let result = select(
idx < 0,
neg_reflected,
select(idx > max_idx, pos_reflected, idx),
);
clamp(result, 0, max_idx) as usize
}
/// Reflect a float coordinate into the valid sampling range.
#[cube]
pub(crate) fn reflect_coord<F: Float>(coord: F, size: u32, #[comptime] align_corners: bool) -> F {
let size_f = F::cast_from(size);
if align_corners {
reflect_float_impl::<F>(coord, F::new(0.0), size_f - F::new(1.0))
} else {
reflect_float_impl::<F>(coord, F::new(-0.5), size_f - F::new(0.5))
}
}
/// Reflect a float coordinate into [min_val, max_val] using a triangle wave pattern.
#[cube]
fn reflect_float_impl<F: Float>(coord: F, min_val: F, max_val: F) -> F {
let span = max_val - min_val;
let is_valid = span > F::new(0.0);
let safe_span = select(is_valid, span, F::new(1.0));
// Triangle wave formula: span - |((x mod 2*span) - span)| + min_val
let period = safe_span * F::new(2.0);
let x = (coord - min_val).abs();
let x_mod = x - (x / period).floor() * period;
let reflected = safe_span - (x_mod - safe_span).abs() + min_val;
select(is_valid, reflected, min_val)
}

View File

@@ -0,0 +1,177 @@
use cubecl::std::{FastDivmod, FastDivmodArgs};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
CubeRuntime, kernel::utils::address_type, ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{Shape, ops::GridSampleOptions};
use super::base::{PaddingMode, fetch_value, reflect_coord};
/// Grid sample with bilinear interpolation.
///
/// Each thread processes all channels for one spatial output position:
/// 1. Reading (x, y) coordinates from the grid tensor (once per spatial position)
/// 2. Converting normalized [-1, 1] coords to pixel coordinates (once)
/// 3. For each channel: fetch 4 corner values, interpolate, and write output
#[cube(launch, address_type = "dynamic")]
fn grid_sample_bilinear_kernel<F: Float>(
input: &Tensor<F>, // [N, C, H_in, W_in]
grid: &Tensor<F>, // [N, H_out, W_out, 2]
output: &mut Tensor<F>, // [N, C, H_out, W_out]
shape_spatial: Sequence<FastDivmod<usize>>, // [N, H_out, W_out] for thread decomposition
#[comptime] align_corners: bool,
#[comptime] pad_mode: PaddingMode,
#[define(F)] _dtype: StorageType,
) {
// Thread index maps to spatial position (n, h_out, w_out) only
let spatial_idx = ABSOLUTE_POS;
let num_spatial = output.shape(0) * output.shape(2) * output.shape(3);
if spatial_idx >= num_spatial {
terminate!();
}
// Decompose spatial index into (n, h_out, w_out)
let (rem, w_out) = shape_spatial[2].div_mod(spatial_idx);
let (n, h_out) = shape_spatial[1].div_mod(rem);
let channels = input.shape(1) as u32;
let h_in = input.shape(2) as u32;
let w_in = input.shape(3) as u32;
// Read grid coordinates once per spatial position
let grid_offset = n * grid.stride(0) + h_out * grid.stride(1) + w_out * grid.stride(2);
let gx = grid[grid_offset]; // x coordinate in [-1, 1]
let gy = grid[grid_offset + 1]; // y coordinate in [-1, 1]
// Convert normalized coordinates to pixel coordinates
let (px, py) = if align_corners {
let px = (gx + F::new(1.0)) * F::cast_from((w_in - 1) as f32) / F::new(2.0);
let py = (gy + F::new(1.0)) * F::cast_from((h_in - 1) as f32) / F::new(2.0);
(px, py)
} else {
let px = (gx + F::new(1.0)) * F::cast_from(w_in as f32) / F::new(2.0) - F::new(0.5);
let py = (gy + F::new(1.0)) * F::cast_from(h_in as f32) / F::new(2.0) - F::new(0.5);
(px, py)
};
// For reflection padding, reflect the coordinate into the valid sampling range.
// This ensures integer indices are at most 1 step out of bounds.
let (px, py) = if comptime!(pad_mode == PaddingMode::Reflection) {
let px = reflect_coord::<F>(px, w_in, align_corners);
let py = reflect_coord::<F>(py, h_in, align_corners);
(px, py)
} else {
(px, py)
};
// Compute floor and ceil indices
let x0_f = px.floor();
let y0_f = py.floor();
let x1_f = x0_f + F::new(1.0);
let y1_f = y0_f + F::new(1.0);
// Compute interpolation weights
let wx = px - x0_f;
let wy = py - y0_f;
let wx_ = F::new(1.0) - wx;
let wy_ = F::new(1.0) - wy;
// Convert to integers for indexing
let x0 = i32::cast_from(x0_f);
let y0 = i32::cast_from(y0_f);
let x1 = i32::cast_from(x1_f);
let y1 = i32::cast_from(y1_f);
let w_in = w_in as i32;
let h_in = h_in as i32;
// Pre-compute strides
let stride_n = input.stride(0);
let stride_c = input.stride(1);
let stride_h = input.stride(2);
let stride_w = input.stride(3);
let out_stride_n = output.stride(0);
let out_stride_c = output.stride(1);
let out_stride_h = output.stride(2);
let out_stride_w = output.stride(3);
// Base offsets for this spatial position
let in_base_n = n * stride_n;
let out_base_spatial = n * out_stride_n + h_out * out_stride_h + w_out * out_stride_w;
// Loop over all channels - grid coords and weights are reused
for c in 0..channels {
let in_base = in_base_n + c as usize * stride_c;
let v00 = fetch_value(
input, in_base, stride_h, stride_w, y0, x0, h_in, w_in, pad_mode,
);
let v01 = fetch_value(
input, in_base, stride_h, stride_w, y1, x0, h_in, w_in, pad_mode,
);
let v10 = fetch_value(
input, in_base, stride_h, stride_w, y0, x1, h_in, w_in, pad_mode,
);
let v11 = fetch_value(
input, in_base, stride_h, stride_w, y1, x1, h_in, w_in, pad_mode,
);
// Bilinear interpolation
let result = wx_ * wy_ * v00 + wx_ * wy * v01 + wx * wy_ * v10 + wx * wy * v11;
let out_idx = out_base_spatial + c as usize * out_stride_c;
output[out_idx] = result;
}
}
/// Launch the grid sample bilinear kernel
pub(crate) fn grid_sample_bilinear_launch<R: CubeRuntime>(
input: CubeTensor<R>,
grid: CubeTensor<R>,
options: GridSampleOptions,
) -> CubeTensor<R> {
let [batch_size, channels, _h_in, _w_in] = input.meta.shape().dims();
let [_n, h_out, w_out, two] = grid.meta.shape().dims();
assert_eq!(two, 2, "Grid last dimension must be 2");
// Create output tensor [N, C, H_out, W_out]
let output_shape = Shape::new([batch_size, channels, h_out, w_out]);
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
output_shape,
input.dtype,
);
// Spatial threading: one thread per (n, h_out, w_out)
let spatial_shape = Shape::new([batch_size, h_out, w_out]);
let num_spatial = spatial_shape.num_elements();
let mut shape_spatial = SequenceArg::new();
for dim in spatial_shape.iter() {
shape_spatial.push(FastDivmodArgs::new(&input.client, *dim));
}
let cube_dim = CubeDim::new(&input.client, num_spatial);
let cube_count = calculate_cube_count_elemwise(&input.client, num_spatial, cube_dim);
let padding_mode: PaddingMode = options.padding_mode.into();
grid_sample_bilinear_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, grid, output),
input.as_tensor_arg(1),
grid.as_tensor_arg(1),
output.as_tensor_arg(1),
shape_spatial,
options.align_corners,
padding_mode,
input.dtype.into(),
)
.expect("Grid sample kernel failed");
output
}

View File

@@ -0,0 +1,4 @@
mod base;
mod bilinear;
pub use base::*;

View File

@@ -0,0 +1,99 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn flip_kernel<E: Numeric, Bool: Int>(
input: &Tensor<E>,
output: &mut LinearView<E, ReadWrite>,
in_shape: Sequence<FastDivmod<usize>>,
indices: Sequence<InputScalar>,
#[define(E, Bool)] _dtypes: [StorageType; 2],
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let rank = in_shape.len().comptime();
let mut offset = ABSOLUTE_POS;
let mut offset_input = 0;
#[unroll]
for i in 0..rank {
let dim = rank - i - 1;
let shape = input.shape(dim);
let (rem, offset_local) = in_shape[dim].div_mod(offset);
offset = rem;
let flip = indices.index(dim).get::<Bool>() == Bool::from_int(1);
let offset_local = select(flip, shape - offset_local - 1, offset_local);
offset_input += offset_local * input.stride(dim);
}
output[ABSOLUTE_POS] = input[offset_input];
}
pub(crate) fn flip<R: CubeRuntime>(
tensor: CubeTensor<R>,
indices: &[usize],
dtype_bool: DType,
) -> CubeTensor<R> {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
tensor.dtype,
);
flip_on_output(tensor, output, indices, dtype_bool)
}
pub(crate) fn flip_on_output<R: CubeRuntime>(
tensor: CubeTensor<R>,
output: CubeTensor<R>,
indices: &[usize],
dtype_bool: DType,
) -> CubeTensor<R> {
let dtype_input = tensor.dtype;
let ndims = tensor.meta.num_dims();
let mut indices_sequence = SequenceArg::<'_, R, InputScalar>::new();
for i in 0..ndims {
indices_sequence.push({
let val = indices.contains(&i) as u8;
InputScalar::new(val, dtype_bool)
});
}
let num_elements = output.meta.num_elements();
let cube_dim = CubeDim::new(&tensor.client, num_elements);
let cube_count = calculate_cube_count_elemwise(&tensor.client, num_elements, cube_dim);
unsafe {
flip_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, output),
tensor.as_tensor_arg(1),
linear_view(&output, 1),
shape_divmod(&tensor),
indices_sequence,
[dtype_input.into(), dtype_bool.into()],
)
.expect("Kernel to never fail");
}
output
}

View File

@@ -0,0 +1,76 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, broadcast_strides, linear_view, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::frontend::{ABSOLUTE_POS, Numeric, Tensor};
use cubecl::std::{FastDivmod, tensor::index_offset_contiguous_fastdivmod};
use cubecl::{CubeDim, std::tensor::layout::linear::LinearView};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
#[cube(launch_unchecked, address_type = "dynamic")]
fn gather_kernel<T: Numeric, I: Numeric>(
input: &Tensor<Line<T>>,
indices: &LinearView<Line<I>>,
output: &mut LinearView<Line<T>, ReadWrite>,
in_strides: Sequence<usize>, // zeroed out for broadcast dims and `dim`
out_shape: Sequence<FastDivmod<usize>>,
dim: usize,
#[define(T, I)] _dtypes: [StorageType; 2],
) {
if !indices.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let mut offset = index_offset_contiguous_fastdivmod(
ABSOLUTE_POS,
&out_shape,
&in_strides,
input.line_size(),
);
offset += usize::cast_from(indices[ABSOLUTE_POS]) * input.stride(dim);
output[ABSOLUTE_POS] = input[offset];
}
pub(crate) fn gather<R: CubeRuntime>(
dim: usize,
tensor: CubeTensor<R>,
indices: CubeTensor<R>,
) -> CubeTensor<R> {
let shape_output = indices.shape();
let total_elem = shape_output.num_elements();
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
shape_output,
tensor.dtype,
);
let cube_dim = CubeDim::new(&tensor.client, total_elem);
let cube_count = calculate_cube_count_elemwise(&tensor.client, total_elem, cube_dim);
let mut in_strides = broadcast_strides(&output, &tensor);
in_strides.values[dim] = ScalarArg::new(0); // Zero `dim` to exclude it from the indexing
unsafe {
gather_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, indices, output),
tensor.as_tensor_arg(1),
linear_view(&indices, 1),
linear_view(&output, 1),
in_strides,
shape_divmod(&output),
ScalarArg::new(dim),
[tensor.dtype.into(), indices.dtype.into()],
)
.expect("Kernel to never fail");
}
output
}

View File

@@ -0,0 +1,18 @@
mod flip;
mod gather;
mod repeat_dim;
mod scatter;
mod select;
mod select_assign;
mod slice;
mod slice_assign;
pub(crate) use flip::*;
pub(crate) use repeat_dim::*;
pub(crate) use select::*;
pub(crate) use select_assign::*;
pub use slice::*;
pub(crate) use slice_assign::*;
pub(crate) use gather::*;
pub(crate) use scatter::*;

View File

@@ -0,0 +1,93 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, FastDivmodArgs},
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn repeat_dim_kernel<E: Numeric>(
input: &Tensor<E>,
output: &mut Tensor<E>,
out_shape: Sequence<FastDivmod<usize>>,
in_shape: FastDivmod<usize>,
#[comptime] dim: usize,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= output.len() {
terminate!();
}
let rank = out_shape.len().comptime();
let mut pos = ABSOLUTE_POS;
let mut offset_input = 0;
let mut offset_output = 0;
#[unroll]
for i in 0..rank {
let i = rank - i - 1;
let (rem, mut local_pos) = out_shape[i].div_mod(pos);
pos = rem;
offset_output += local_pos * output.stride(i);
if i == dim {
local_pos = in_shape.modulo(local_pos);
}
offset_input += local_pos * input.stride(i);
}
output[offset_output] = input[offset_input];
}
pub(crate) fn repeat_dim<R: CubeRuntime>(
mut input: CubeTensor<R>,
dim: usize,
times: usize,
) -> CubeTensor<R> {
if input.meta.shape()[dim] == 1 {
input.meta.strides[dim] = 0;
input.meta.shape = input.meta.shape.repeat(dim, times).unwrap();
return input;
}
let shape = input.meta.shape.clone().repeat(dim, times).unwrap();
// Create output handle
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
shape,
input.dtype,
);
let working_units = output.meta.num_elements();
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
unsafe {
repeat_dim_kernel::launch_unchecked(
&input.client,
cube_count,
cube_dim,
address_type!(input, output),
input.as_tensor_arg(1),
output.as_tensor_arg(1),
shape_divmod(&output),
FastDivmodArgs::new(&input.client, input.meta.shape()[dim]),
dim,
output.dtype.into(),
)
.expect("Kernel to never fail");
};
output
}

View File

@@ -0,0 +1,109 @@
use crate::{
CubeRuntime,
kernel::{
AddOp, BinaryOp, BinaryOpFamily, OrOp,
utils::{address_type, shape_divmod},
},
tensor::CubeTensor,
};
use cubecl::{CubeDim, calculate_cube_count_elemwise};
use cubecl::{prelude::*, std::FastDivmod};
#[cube(launch_unchecked, address_type = "dynamic")]
fn scatter_kernel<T: Numeric, I: Int, Op: BinaryOpFamily>(
input: &mut Tensor<T>,
indices: &Tensor<I>,
value: &Tensor<T>,
in_shape: Sequence<FastDivmod<usize>>,
#[comptime] dim: usize,
#[define(T, I)] _dtypes: [StorageType; 2],
) {
let rank = in_shape.len().comptime();
let stride_input = input.stride(dim);
let stride_value = value.stride(dim);
let stride_indices = indices.stride(dim);
let shape_value = value.shape(dim);
let mut offset = ABSOLUTE_POS;
let mut offset_input = 0;
let mut offset_indices = 0;
let mut offset_value = 0;
let mut num_elems = 1;
#[unroll]
for i in 0..rank {
let i = rank - i - 1;
if i != dim {
let shape_input_loop = input.shape(i);
let (rem, local_pos) = in_shape[i].div_mod(offset);
offset = rem;
offset_input += local_pos * input.stride(i);
offset_indices += local_pos * indices.stride(i);
offset_value += local_pos * value.stride(i);
num_elems *= shape_input_loop;
}
}
let should_stop = ABSOLUTE_POS >= num_elems;
if should_stop {
terminate!();
}
for i in 0..shape_value {
let value_idx = (stride_value * i) + offset_value;
let index_idx = (stride_indices * i) + offset_indices;
let value = value[value_idx];
let index = usize::cast_from(indices[index_idx]);
let input_idx = (stride_input * index) + offset_input;
let value =
Op::BinaryOp::<T>::execute(Line::cast_from(input[input_idx]), Line::cast_from(value));
input[input_idx] = value[0];
}
}
pub(crate) fn scatter<R: CubeRuntime>(
dim: usize,
tensor: CubeTensor<R>,
indices: CubeTensor<R>,
value: CubeTensor<R>,
is_bool: bool,
) -> CubeTensor<R> {
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
true => tensor,
false => tensor.copy(),
};
let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim];
let working_units = num_elems;
let cube_dim = CubeDim::new(&indices.client, working_units);
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
let launch = match is_bool {
true => scatter_kernel::launch_unchecked::<OrOp, R>,
false => scatter_kernel::launch_unchecked::<AddOp, R>,
};
unsafe {
launch(
&indices.client.clone(),
cube_count,
cube_dim,
address_type!(tensor, indices, value),
tensor.as_tensor_arg(1),
indices.as_tensor_arg(1),
value.as_tensor_arg(1),
shape_divmod(&tensor),
dim,
[tensor.dtype.into(), indices.dtype.into()],
)
.expect("Kernel to never fail");
}
tensor
}

View File

@@ -0,0 +1,82 @@
use crate::{CubeRuntime, kernel::utils::address_type, tensor::CubeTensor};
use crate::{
kernel::utils::{linear_view, shape_divmod},
ops::numeric::empty_device_dtype,
};
use burn_backend::TensorMetadata;
use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView};
use cubecl::{prelude::*, std::FastDivmod};
#[cube(launch_unchecked, address_type = "dynamic")]
fn select_kernel<T: Numeric, I: Numeric>(
input: &Tensor<T>,
indices: &LinearView<I>,
output: &mut LinearView<T, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
dim: usize,
#[define(T, I)] _dtypes: [StorageType; 2],
) {
if ABSOLUTE_POS >= output.shape() {
terminate!();
}
let rank = out_shape.len().comptime();
let mut offset = ABSOLUTE_POS;
let mut offset_input = 0;
#[unroll]
for i in 0..rank {
let i = rank - i - 1;
let (rem, offset_local) = out_shape[i].div_mod(offset);
offset = rem;
let offset_local = cubecl::prelude::select(
i == dim,
usize::cast_from(indices[offset_local]),
offset_local,
);
offset_input += offset_local * input.stride(i);
}
output[ABSOLUTE_POS] = input[offset_input];
}
pub(crate) fn select<R: CubeRuntime>(
tensor: CubeTensor<R>,
dim: usize,
indices: CubeTensor<R>,
) -> CubeTensor<R> {
let mut shape_output = tensor.shape();
shape_output[dim] = indices.meta.shape()[0];
let total_elem = shape_output.num_elements();
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
shape_output,
tensor.dtype,
);
let working_units = total_elem;
let cube_dim = CubeDim::new(&indices.client, working_units);
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
unsafe {
select_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, indices, output),
tensor.as_tensor_arg(1),
linear_view(&indices, 1),
linear_view(&output, 1),
shape_divmod(&output),
ScalarArg::new(dim),
[tensor.dtype.into(), indices.dtype.into()],
)
.expect("Kernel to never fail");
};
output
}

View File

@@ -0,0 +1,98 @@
use crate::kernel::{
AddOp, BinaryOp, BinaryOpFamily, OrOp,
utils::{address_type, linear_view, shape_divmod},
};
use crate::{CubeRuntime, tensor::CubeTensor};
use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView};
use cubecl::{prelude::*, std::FastDivmod};
#[cube(launch_unchecked, address_type = "dynamic")]
fn select_assign_kernel<F: Numeric, I: Numeric, Op: BinaryOpFamily>(
tensor: &mut Tensor<F>,
indices: &LinearView<I>,
value: &Tensor<F>,
value_shape: Sequence<FastDivmod<usize>>,
num_elems: usize,
#[comptime] dim: usize,
#[define(F, I)] _dtypes: [StorageType; 2],
) {
if ABSOLUTE_POS >= num_elems {
terminate!();
}
let rank = value_shape.len().comptime();
let mut offset = ABSOLUTE_POS;
let mut offset_tensor = 0;
let mut offset_value = 0;
// Calculate offsets and num_elems
#[unroll]
for i in 0..rank {
let i = rank - i - 1;
if i != dim {
let (rem, local_pos) = value_shape[i].div_mod(offset);
offset = rem;
offset_tensor += local_pos * tensor.stride(i);
offset_value += local_pos * value.stride(i);
}
}
let strides_tensor_dim = tensor.stride(dim);
let strides_value_dim = value.stride(dim);
// Main operation
for i in 0..value.shape(dim) {
let index_tensor = usize::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
let index_value = i * strides_value_dim + offset_value;
let value = Op::BinaryOp::<F>::execute(
Line::cast_from(tensor[index_tensor]),
Line::cast_from(value[index_value]),
);
tensor[index_tensor] = F::cast_from(value);
}
}
pub(crate) fn select_assign<R: CubeRuntime>(
tensor: CubeTensor<R>,
dim: usize,
indices: CubeTensor<R>,
value: CubeTensor<R>,
is_bool: bool,
) -> CubeTensor<R> {
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
true => tensor,
false => tensor.copy(),
};
let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim];
let working_units = num_elems;
let cube_dim = CubeDim::new(&indices.client, working_units);
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
let launch = match is_bool {
true => select_assign_kernel::launch_unchecked::<OrOp, R>,
false => select_assign_kernel::launch_unchecked::<AddOp, R>,
};
unsafe {
launch(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, indices, value),
tensor.as_tensor_arg(1),
linear_view(&indices, 1),
value.as_tensor_arg(1),
shape_divmod(&value),
ScalarArg::new(num_elems),
dim,
[tensor.dtype.into(), indices.dtype.into()],
)
.expect("Kernel to never fail");
};
tensor
}

View File

@@ -0,0 +1,247 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use burn_backend::{Slice, TensorMetadata};
use burn_std::{Metadata, SliceOps};
use cubecl::{
calculate_cube_count_elemwise, intrinsic,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
use std::ops::Range;
/// Slice a jit tensor with a set of ranges
pub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {
let mut dims = tensor.shape();
let mut offset_start = 0u64;
let mut offset_end = 0u64;
for i in 0..indices.len() {
offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64;
offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64;
dims[i] = indices[i].end - indices[i].start;
}
let offset_start = offset_start * tensor.dtype.size() as u64;
let offset_end = offset_end * tensor.dtype.size() as u64;
let memory_offset_alignment = tensor.client.properties().memory.alignment;
if offset_start.is_multiple_of(memory_offset_alignment)
&& offset_end.is_multiple_of(memory_offset_alignment)
{
CubeTensor::new(
tensor.client,
tensor
.handle
.offset_start(offset_start)
.offset_end(offset_end),
Metadata::new(dims, tensor.meta.strides),
tensor.device,
tensor.dtype,
)
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
dims,
tensor.dtype,
);
slice_on_output(tensor, output, indices)
}
}
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_kernel<E: Numeric>(
input: &Tensor<E>,
output: &mut LinearView<E, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
indices: Sequence<usize>,
#[define(E)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let rank = comptime![out_shape.len()];
let mut offset_output = ABSOLUTE_POS;
let mut offset_input = 0;
#[unroll]
for i in 0..rank {
// Iterate in reverse to use divmod
let dim = rank - i - 1;
let range_start = indices[dim];
let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
offset_output = rem;
let offset_local = offset_local + range_start;
offset_input += offset_local * input.stride(dim);
}
output[ABSOLUTE_POS] = input[offset_input];
}
pub(crate) fn slice_on_output<R: CubeRuntime>(
tensor: CubeTensor<R>,
output: CubeTensor<R>,
indices: &[Range<usize>],
) -> CubeTensor<R> {
let ndims = tensor.meta.num_dims();
let mut indices_sequence = SequenceArg::<R, usize>::new();
for i in 0..ndims {
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
indices_sequence.push(ScalarArg::new(start));
}
let working_units = output.meta.num_elements();
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
slice_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, output),
tensor.as_tensor_arg(1),
linear_view(&output, 1),
shape_divmod(&output),
indices_sequence,
tensor.dtype.into(),
)
.expect("Kernel to never fail");
};
output
}
/// Kernel for slicing with steps
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_with_steps_kernel<E: Numeric>(
input: &Tensor<E>,
output: &mut LinearView<E, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
starts: Sequence<usize>,
ends: Sequence<usize>,
steps: Sequence<i32>,
#[define(E)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let rank = comptime![out_shape.len()];
let mut output_offset = ABSOLUTE_POS;
let mut input_offset = 0;
// Calculate the input offset based on output position and slice info
#[unroll]
for i in 0..rank {
// Iterate in reverse to use divmod
let dim = rank - i - 1;
let start = starts[dim];
let end = ends[dim];
let step = steps[dim];
let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
output_offset = rem;
let input_idx = if step > 0 {
// Forward stepping
start + output_idx * (step as usize)
} else {
// Backward stepping - start from end-1
let abs_step = (-step) as usize;
let end_minus_1 = end - 1;
end_minus_1 - output_idx * abs_step
};
input_offset += input_idx * input.stride(dim);
}
output[ABSOLUTE_POS] = input[input_offset];
}
/// Slice a tensor with steps
pub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {
// Check if all steps are 1 - if so, use the optimized regular slice
let all_steps_one = slices.iter().all(|info| info.step == 1);
if all_steps_one {
// Convert Slice to Range for step=1
let simple_ranges: Vec<Range<usize>> = slices
.iter()
.enumerate()
.map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
.collect();
return slice(tensor, &simple_ranges);
}
// Calculate output shape
let shape_output = tensor.shape().slice(slices).unwrap();
// Create output tensor
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
shape_output.clone(),
tensor.dtype,
);
// Prepare three separate sequences for kernel
let mut starts = SequenceArg::<R, usize>::new();
let mut ends = SequenceArg::<R, usize>::new();
let mut steps = SequenceArg::<R, i32>::new();
for (dim, slice) in slices.iter().enumerate() {
let range = slice.to_range(tensor.meta.shape()[dim]);
starts.push(ScalarArg::new(range.start));
ends.push(ScalarArg::new(range.end));
steps.push(ScalarArg::new(slice.step as i32));
}
// Pad with default values if needed to match tensor dimensions
for dim in slices.len()..tensor.meta.num_dims() {
starts.push(ScalarArg::new(0));
ends.push(ScalarArg::new(tensor.meta.shape()[dim]));
steps.push(ScalarArg::new(1));
}
// Launch kernel
let working_units = shape_output.num_elements();
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
slice_with_steps_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, output),
tensor.as_tensor_arg(1),
linear_view(&output, 1),
shape_divmod(&output),
starts,
ends,
steps,
tensor.dtype.into(),
)
.expect("Kernel to never fail");
}
output
}
/// This is annoying and we need to find a way to do this automatically at some point
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
intrinsic!(|_| value.constant().unwrap().as_u32())
}

View File

@@ -0,0 +1,258 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, shape_divmod},
tensor::CubeTensor,
};
use cubecl::{
calculate_cube_count_elemwise, intrinsic,
prelude::*,
std::{FastDivmod, FastDivmodArgs, tensor::layout::linear::LinearView},
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_assign_kernel<E: Numeric>(
input: &mut Tensor<Line<E>>,
value: &LinearView<Line<E>>,
slice_shape: Sequence<FastDivmod<usize>>,
slice_offsets: Sequence<usize>,
#[define(E)] _dtype: StorageType,
) {
if !value.is_in_bounds(ABSOLUTE_POS) {
terminate!()
}
let rank = comptime!(slice_shape.len());
let line_size = input.line_size();
let mut offset_remainder = ABSOLUTE_POS * line_size;
let mut offset_input = 0;
#[allow(clippy::explicit_counter_loop)]
#[unroll]
for i in 0..rank {
let dim = rank - i - 1;
let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);
let range_start = slice_offsets[dim];
let offset_local_input = offset_local + range_start;
offset_input += offset_local_input * input.stride(dim);
offset_remainder = rem;
}
// Value tensor is accessed linearly since it's a LinearView
input[offset_input / line_size] = value[ABSOLUTE_POS];
}
/// Kernel for slice assign with steps
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_assign_with_steps_kernel<E: Numeric>(
input: &mut Tensor<E>,
value: &LinearView<E>,
value_shape: Sequence<FastDivmod<usize>>,
starts: Sequence<usize>,
ends: Sequence<usize>,
steps: Sequence<i32>,
#[define(E)] _dtype: StorageType,
) {
if !value.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let rank = comptime![value_shape.len()];
let mut value_offset = ABSOLUTE_POS;
let mut input_offset = 0;
// Calculate the input offset based on value position and slice info
#[unroll]
for i in 0..rank {
// Iterate in reverse to use divmod
let dim = rank - i - 1;
let start = starts[dim];
let end = ends[dim];
let step = steps[dim];
let (rem, value_idx) = value_shape[dim].div_mod(value_offset);
value_offset = rem;
let input_idx = if step > 0 {
// Forward stepping
start + value_idx * (step as usize)
} else if step < 0 {
// Backward stepping - start from end-1
// For negative steps, we iterate backwards through the selected indices
let abs_step = (-step) as usize;
let end_minus_1 = end - 1;
end_minus_1 - value_idx * abs_step
} else {
// step == 0, shouldn't happen
value_idx
};
input_offset += input_idx * input.stride(dim);
}
input[input_offset] = value[ABSOLUTE_POS];
}
pub(crate) fn slice_assign<R: CubeRuntime>(
tensor: CubeTensor<R>,
indices: &[burn_backend::Slice],
value: CubeTensor<R>,
) -> CubeTensor<R> {
// Check if any slice has non-unit step
let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0);
if has_non_unit_step {
// Use slice_assign_with_steps
return slice_assign_with_steps(tensor, indices, value);
}
let client = tensor.client.clone();
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
true => tensor,
false => tensor.copy(),
};
let ndims = tensor.meta.num_dims();
let line_size = if tensor.meta.strides()[ndims - 1] == 1 && value.meta.strides()[ndims - 1] == 1
{
let last = indices
.get(ndims - 1)
.cloned()
.unwrap_or(burn_backend::Slice {
start: 0,
end: Some(tensor.meta.shape()[ndims - 1] as isize),
step: 1,
});
let end = last.end.unwrap_or(tensor.meta.shape()[ndims - 1] as isize);
let shape = (end - last.start) as usize;
let offset = last.start as usize;
client
.io_optimized_line_sizes(tensor.dtype.size())
.filter(|&it| {
shape.is_multiple_of(it)
&& strides_compatible(tensor.meta.strides(), it)
&& strides_compatible(value.meta.strides(), it)
&& offset.is_multiple_of(it)
})
.max()
.unwrap_or(1)
} else {
1
};
let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();
let mut offsets = SequenceArg::<R, usize>::new();
for i in 0..ndims {
let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {
start: 0,
end: Some(tensor.meta.shape()[i] as isize),
step: 1,
});
let start = slice.start as usize;
let end = slice.end.unwrap_or(tensor.meta.shape()[i] as isize);
let length = (end - slice.start) as usize;
shape.push(FastDivmodArgs::<usize>::new(&client, length));
offsets.push(ScalarArg::new(start));
}
let working_units = value.meta.num_elements() / line_size;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
slice_assign_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, value),
tensor.as_tensor_arg(line_size),
linear_view(&value, line_size),
shape,
offsets,
tensor.dtype.into(),
)
.expect("Kernel to never fail");
}
tensor
}
/// Slice assign with steps support
///
/// This function handles slice assignment with arbitrary step values, including negative steps.
/// It follows NumPy/PyTorch semantics where values[i] is assigned to selected_indices[i].
///
/// For example, with s![0..6;-1] which selects indices [5,4,3,2,1,0]:
/// - values[0] goes to index 5
/// - values[1] goes to index 4
/// - etc.
pub(crate) fn slice_assign_with_steps<R: CubeRuntime>(
tensor: CubeTensor<R>,
slices: &[burn_backend::Slice],
value: CubeTensor<R>,
) -> CubeTensor<R> {
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
true => tensor,
false => tensor.copy(),
};
// Prepare sequences for kernel
let mut starts = SequenceArg::<R, usize>::new();
let mut ends = SequenceArg::<R, usize>::new();
let mut steps = SequenceArg::<R, i32>::new();
for (dim, slice) in slices.iter().enumerate() {
let range = slice.to_range(tensor.meta.shape()[dim]);
starts.push(ScalarArg::new(range.start));
ends.push(ScalarArg::new(range.end));
steps.push(ScalarArg::new(slice.step as i32));
}
// Pad with default values if needed to match tensor dimensions
for dim in slices.len()..tensor.meta.num_dims() {
starts.push(ScalarArg::new(0));
ends.push(ScalarArg::new(tensor.meta.shape()[dim]));
steps.push(ScalarArg::new(1));
}
// Launch kernel
let working_units = value.meta.num_elements();
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
slice_assign_with_steps_kernel::launch_unchecked(
&tensor.client,
cube_count,
cube_dim,
address_type!(tensor, value),
tensor.as_tensor_arg(1),
linear_view(&value, 1),
shape_divmod(&value),
starts,
ends,
steps,
tensor.dtype.into(),
)
.expect("Kernel to never fail");
}
tensor
}
fn strides_compatible(strides: &[usize], vec: usize) -> bool {
strides
.iter()
.all(|stride| *stride % vec == 0 || *stride == 1)
}
/// Helper function for unwrap
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
intrinsic!(|_| value.constant().unwrap().as_u32())
}

View File

@@ -0,0 +1,79 @@
use crate::{
CubeRuntime,
kernel::into_contiguous,
ops::{numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::{
Shape, TensorMetadata,
ops::{InterpolateMode, InterpolateOptions},
};
use super::{
bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch,
nearest::interpolate_nearest_launch, nearest_backward::interpolate_nearest_backward_launch,
};
/// Interpolate operation
///
/// Supports nearest, bilinear and bicubic modes
pub fn interpolate<R: CubeRuntime>(
input: CubeTensor<R>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> CubeTensor<R> {
let [batch_size, channels, _, _] = input.meta.shape().dims();
let [out_height, out_width] = output_size;
let input = into_contiguous(permute_nchw_to_nhwc(input));
let shape_out = Shape::new([batch_size, out_height, out_width, channels]);
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
shape_out,
input.dtype,
);
let align_corners = options.align_corners;
let output = match options.mode {
InterpolateMode::Nearest => interpolate_nearest_launch(input, output),
InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output, align_corners),
InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output, align_corners),
};
permute_nhwc_to_nchw(output)
}
/// Backward interpolate operation
///
/// Note: only nearest mode is supported
pub fn interpolate_backward<R: CubeRuntime>(
input: CubeTensor<R>,
out_grad: CubeTensor<R>,
_output_size: [usize; 2],
options: InterpolateOptions,
) -> CubeTensor<R> {
let input = permute_nchw_to_nhwc(input);
let out_grad = permute_nchw_to_nhwc(out_grad);
let output_shape = input.shape();
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
output_shape,
input.dtype,
);
let output = match options.mode {
InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output),
InterpolateMode::Bilinear => {
panic!("bilinear interpolation backward is not supported by JIT backend")
}
InterpolateMode::Bicubic => {
panic!("bicubic interpolation backward is not supported by JIT backend")
}
};
permute_nhwc_to_nchw(output)
}

View File

@@ -0,0 +1,194 @@
use cubecl::std::{
FastDivmod,
tensor::layout::{linear::LinearLayout, *},
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_layout, shape_divmod},
ops::max_line_size,
tensor::CubeTensor,
};
#[cube(launch, address_type = "dynamic")]
fn interpolate_bicubic_kernel<F: Float>(
input: &Tensor<Line<F>>,
output: &mut Tensor<Line<F>>,
shape_out: Sequence<FastDivmod<usize>>,
out_layout: LinearLayout,
#[comptime] align_corners: bool,
#[define(F)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= output.len() {
terminate!();
}
let line_size = input.line_size();
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
let (rem, x) = shape_out[2].div_mod(rem);
let (b, y) = shape_out[1].div_mod(rem);
let input_height = input.shape(1) - 1;
let input_height_f = input_height as f32;
let frac = if align_corners {
let output_height = clamp_min(output.shape(1) - 1, 1) as f32;
(y * input_height) as f32 / output_height
} else {
let in_size = (input_height + 1) as f32;
let out_size = output.shape(1) as f32;
(y as f32 + 0.5) * (in_size / out_size) - 0.5
};
let y_in_f = frac.floor();
let yw = Line::empty(line_size).fill(F::cast_from(frac - y_in_f));
// Clamp indices in float space to handle negative coordinates from half_pixel
let y0 = clamp(y_in_f - 1.0, 0.0, input_height_f) as usize;
let y1 = clamp(y_in_f, 0.0, input_height_f) as usize;
let y2 = clamp(y_in_f + 1.0, 0.0, input_height_f) as usize;
let y3 = clamp(y_in_f + 2.0, 0.0, input_height_f) as usize;
let input_width = input.shape(2) - 1;
let input_width_f = input_width as f32;
let frac = if align_corners {
let output_width = clamp_min(output.shape(2) - 1, 1) as f32;
(x * input_width) as f32 / output_width
} else {
let in_size = (input_width + 1) as f32;
let out_size = output.shape(2) as f32;
(x as f32 + 0.5) * (in_size / out_size) - 0.5
};
let x_in_f = frac.floor();
let xw = Line::empty(line_size).fill(F::cast_from(frac - x_in_f));
// Clamp indices in float space to handle negative coordinates from half_pixel
let x0 = clamp(x_in_f - 1.0, 0.0, input_width_f) as usize;
let x1 = clamp(x_in_f, 0.0, input_width_f) as usize;
let x2 = clamp(x_in_f + 1.0, 0.0, input_width_f) as usize;
let x3 = clamp(x_in_f + 2.0, 0.0, input_width_f) as usize;
let index_base = b * input.stride(0) + c * input.stride(3);
let in_stride_y = input.stride(1);
let in_stride_x = input.stride(2);
let y0_stride = y0 * in_stride_y;
let y1_stride = y1 * in_stride_y;
let y2_stride = y2 * in_stride_y;
let y3_stride = y3 * in_stride_y;
let x0_stride = x0 * in_stride_x;
let x1_stride = x1 * in_stride_x;
let x2_stride = x2 * in_stride_x;
let x3_stride = x3 * in_stride_x;
let inp_0 = input[(index_base + y0_stride + x0_stride) / line_size];
let inp_1 = input[(index_base + y0_stride + x1_stride) / line_size];
let inp_2 = input[(index_base + y0_stride + x2_stride) / line_size];
let inp_3 = input[(index_base + y0_stride + x3_stride) / line_size];
let coefficients0 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
let inp_0 = input[(index_base + y1_stride + x0_stride) / line_size];
let inp_1 = input[(index_base + y1_stride + x1_stride) / line_size];
let inp_2 = input[(index_base + y1_stride + x2_stride) / line_size];
let inp_3 = input[(index_base + y1_stride + x3_stride) / line_size];
let coefficients1 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
let inp_0 = input[(index_base + y2_stride + x0_stride) / line_size];
let inp_1 = input[(index_base + y2_stride + x1_stride) / line_size];
let inp_2 = input[(index_base + y2_stride + x2_stride) / line_size];
let inp_3 = input[(index_base + y2_stride + x3_stride) / line_size];
let coefficients2 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
let inp_0 = input[(index_base + y3_stride + x0_stride) / line_size];
let inp_1 = input[(index_base + y3_stride + x1_stride) / line_size];
let inp_2 = input[(index_base + y3_stride + x2_stride) / line_size];
let inp_3 = input[(index_base + y3_stride + x3_stride) / line_size];
let coefficients3 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
let val = cubic_interp_1d::<F>(
coefficients0,
coefficients1,
coefficients2,
coefficients3,
yw,
);
output[out_idx] = val;
}
#[cube]
fn cubic_interp_1d<F: Float>(
x0: Line<F>,
x1: Line<F>,
x2: Line<F>,
x3: Line<F>,
t: Line<F>,
) -> Line<F> {
let a = lined(&x0, -0.75);
let coeffs0 = cubic_convolution_2::<F>(t + lined(&x0, 1.0), a);
let coeffs1 = cubic_convolution_1::<F>(t, a);
let coeffs2 = cubic_convolution_1::<F>(lined(&x0, 1.0) - t, a);
let coeffs3 = cubic_convolution_2::<F>(lined(&x0, 2.0) - t, a);
x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3
}
#[cube]
fn cubic_convolution_1<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> {
let conv = (a + lined(&x, 2.0)) * x;
let tmp = a + lined(&x, 3.0);
(conv - tmp) * x * x + lined(&x, 1.0)
}
#[cube]
fn cubic_convolution_2<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> {
let conv = a * x;
let conv = (conv - lined(&x, 5.0) * a) * x;
let tmp = lined(&x, 8.0) * a;
let conv = (conv + tmp) * x;
conv - lined(&x, 4.0) * a
}
#[cube]
fn lined<F: Float>(x: &Line<F>, #[comptime] v: f32) -> Line<F> {
Line::empty(x.size()).fill(F::new(v))
}
pub(crate) fn interpolate_bicubic_launch<R: CubeRuntime>(
input: CubeTensor<R>,
output: CubeTensor<R>,
align_corners: bool,
) -> CubeTensor<R> {
let line_size = max_line_size(&input);
let out_shape = shape_divmod(&output);
let out_layout = linear_layout(&output, line_size);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
interpolate_bicubic_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, output),
input.as_tensor_arg(line_size),
output.as_tensor_arg(line_size),
out_shape,
out_layout,
align_corners,
output.dtype.into(),
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,149 @@
use cubecl::std::{
FastDivmod,
tensor::layout::{linear::LinearLayout, *},
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_layout, shape_divmod},
ops::max_line_size,
tensor::CubeTensor,
};
#[cube(launch, address_type = "dynamic")]
fn interpolate_bilinear_kernel<F: Float>(
input: &Tensor<Line<F>>,
output: &mut Tensor<Line<F>>,
shape_out: Sequence<FastDivmod<usize>>,
out_layout: LinearLayout,
#[comptime] align_corners: bool,
#[define(F)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= output.len() {
terminate!();
}
let line_size = input.line_size();
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
let (rem, x) = shape_out[2].div_mod(rem);
let (b, y) = shape_out[1].div_mod(rem);
let frac = if align_corners {
let numerator = (input.shape(1) - 1) as f32;
let denominator = clamp_min(output.shape(1) - 1, 1) as f32;
y as f32 * (numerator / denominator)
} else {
let in_size = input.shape(1) as f32;
let out_size = output.shape(1) as f32;
clamp(
(y as f32 + 0.5) * (in_size / out_size) - 0.5,
0.0,
in_size - 1.0,
)
};
let v0 = frac.floor();
let v1 = frac.ceil();
let yw = F::cast_from(frac - v0);
let yw_ = Line::empty(line_size).fill(F::new(1.0) - yw);
let yw = Line::empty(line_size).fill(yw);
let y0_ok = v0 >= 0.0;
let y0 = v0 as usize;
let y1 = v1 as usize;
let frac = if align_corners {
let numerator = (input.shape(2) - 1) as f32;
let denominator = clamp_min(output.shape(2) - 1, 1) as f32;
x as f32 * (numerator / denominator)
} else {
let in_size = input.shape(2) as f32;
let out_size = output.shape(2) as f32;
clamp(
(x as f32 + 0.5) * (in_size / out_size) - 0.5,
0.0,
in_size - 1.0,
)
};
let v0 = frac.floor();
let v1 = frac.ceil();
let xw = F::cast_from(frac - v0);
let xw_ = Line::empty(line_size).fill(F::new(1.0) - xw);
let xw = Line::empty(line_size).fill(xw);
let x0_ok = v0 >= 0.0;
let x0 = v0 as usize;
let x1 = v1 as usize;
let index_base = b * input.stride(0) + c * input.stride(3);
let in_stride_y = input.stride(1);
let in_stride_x = input.stride(2);
let y0_stride = y0 * in_stride_y;
let y1_stride = y1 * in_stride_y;
let x0_stride = x0 * in_stride_x;
let x1_stride = x1 * in_stride_x;
let height = input.shape(1);
let width = input.shape(2);
let y1_ok = y1 < height;
let x1_ok = x1 < width;
let zero = Line::empty(line_size).fill(F::new(0.0));
let p_a = select(
x0_ok && y0_ok,
input[(index_base + y0_stride + x0_stride) / line_size] * xw_ * yw_,
zero,
);
let p_b = select(
x1_ok && y0_ok,
input[(index_base + y0_stride + x1_stride) / line_size] * xw * yw_,
zero,
);
let p_c = select(
x0_ok && y1_ok,
input[(index_base + y1_stride + x0_stride) / line_size] * xw_ * yw,
zero,
);
let p_d = select(
x1_ok && y1_ok,
input[(index_base + y1_stride + x1_stride) / line_size] * xw * yw,
zero,
);
output[out_idx] = p_a + p_b + p_c + p_d;
}
pub(crate) fn interpolate_bilinear_launch<R: CubeRuntime>(
input: CubeTensor<R>,
output: CubeTensor<R>,
align_corners: bool,
) -> CubeTensor<R> {
let line_size = max_line_size(&input);
let out_shape = shape_divmod(&output);
let out_layout = linear_layout(&output, line_size);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
interpolate_bilinear_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, output),
input.as_tensor_arg(line_size),
output.as_tensor_arg(line_size),
out_shape,
out_layout,
align_corners,
output.dtype.into(),
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,7 @@
mod base;
mod bicubic;
mod bilinear;
mod nearest;
mod nearest_backward;
pub use base::*;

View File

@@ -0,0 +1,80 @@
use cubecl::std::{
FastDivmod,
tensor::layout::{linear::LinearLayout, *},
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_layout, shape_divmod},
ops::max_line_size,
tensor::CubeTensor,
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn interpolate_nearest_kernel<F: Float>(
input: &Tensor<Line<F>>,
output: &mut Tensor<Line<F>>,
shape_out: Sequence<FastDivmod<usize>>,
out_layout: LinearLayout,
#[define(F)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= output.len() {
terminate!();
}
let line_size = input.line_size();
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
let out_pos = ABSOLUTE_POS * line_size;
let (h_in, w_in) = (input.shape(1) as f32, input.shape(2) as f32);
let (h_out, w_out) = (output.shape(1) as f32, output.shape(2) as f32);
let (rem, c) = shape_out[3].div_mod(out_pos);
let (rem, x) = shape_out[2].div_mod(rem);
let (b, y) = shape_out[1].div_mod(rem);
let y = y as f32 * (h_in / h_out);
let x = x as f32 * (w_in / w_out);
let in_idx = b * input.stride(0)
+ y as usize * input.stride(1)
+ x as usize * input.stride(2)
+ c * input.stride(3);
output[out_idx] = input[in_idx / line_size];
}
pub(crate) fn interpolate_nearest_launch<R: CubeRuntime>(
input: CubeTensor<R>,
output: CubeTensor<R>,
) -> CubeTensor<R> {
let client = input.client.clone();
let line_size = max_line_size(&input);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
let shape_out = shape_divmod(&output);
let out_layout = linear_layout(&output, line_size);
unsafe {
interpolate_nearest_kernel::launch_unchecked(
&client,
cube_count,
cube_dim,
address_type!(input, output),
input.as_tensor_arg(line_size),
output.as_tensor_arg(line_size),
shape_out,
out_layout,
output.dtype.into(),
)
.expect("Kernel to never fail");
};
output
}

View File

@@ -0,0 +1,103 @@
use cubecl::std::{
FastDivmod,
tensor::layout::{linear::LinearLayout, *},
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_layout, shape_divmod},
ops::max_line_size,
tensor::CubeTensor,
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn interpolate_nearest_backward_kernel<F: Float>(
grad: &Tensor<Line<F>>,
output: &mut Tensor<Line<F>>,
shape_out: Sequence<FastDivmod<usize>>,
out_layout: LinearLayout,
#[define(F)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= output.len() {
terminate!();
}
let line_size = grad.line_size();
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
let out_h = output.shape(1);
let out_w = output.shape(2);
let grad_h = grad.shape(1);
let grad_w = grad.shape(2);
let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
let (rem, out_x) = shape_out[2].div_mod(rem);
let (b, out_y) = shape_out[1].div_mod(rem);
let grad_y_start = start_index::<F>(out_y, grad_h, out_h);
let grad_y_end = end_index::<F>(out_y, grad_h, out_h);
let grad_x_start = start_index::<F>(out_x, grad_w, out_w);
let grad_x_end = end_index::<F>(out_x, grad_w, out_w);
let index_grad_base = b * grad.stride(0) + c * grad.stride(3);
let mut sum = Line::empty(line_size).fill(F::new(0.0));
for grad_y in grad_y_start..grad_y_end {
for grad_x in grad_x_start..grad_x_end {
let index_grad = index_grad_base + grad_y * grad.stride(1) + grad_x * grad.stride(2);
sum += grad[index_grad];
}
}
output[out_idx] = sum;
}
#[cube]
fn start_index<F: Float>(input_index: usize, output_size: usize, input_size: usize) -> usize {
let numerator = F::cast_from(input_index * output_size);
let div = (numerator / F::cast_from(input_size)).ceil();
usize::cast_from(div)
}
#[cube]
fn end_index<F: Float>(input_index: usize, output_size: usize, input_size: usize) -> usize {
let numerator = F::cast_from((input_index + 1) * output_size);
let div = (numerator / F::cast_from(input_size)).ceil();
let index = usize::cast_from(div);
clamp_max(index, output_size)
}
pub(crate) fn interpolate_nearest_backward_launch<R: CubeRuntime>(
out_grad: CubeTensor<R>,
output: CubeTensor<R>,
) -> CubeTensor<R> {
let line_size = max_line_size(&out_grad);
let out_shape = shape_divmod(&output);
let out_layout = linear_layout(&output, line_size);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&out_grad.client, working_units);
let cube_count = calculate_cube_count_elemwise(&out_grad.client, working_units, cube_dim);
unsafe {
interpolate_nearest_backward_kernel::launch_unchecked(
&out_grad.client,
cube_count,
cube_dim,
address_type!(out_grad, output),
out_grad.as_tensor_arg(line_size),
output.as_tensor_arg(line_size),
out_shape,
out_layout,
output.dtype.into(),
)
.expect("Kernel to never fail");
};
output
}

View File

@@ -0,0 +1,39 @@
use burn_backend::DType;
use cubecl::prelude::InputScalar;
use super::{MaskFillStrategy, mask_where::MaskWhereStrategy};
use crate::{CubeRuntime, tensor::CubeTensor};
/// Execute the mask fill kernel.
pub(crate) fn mask_fill_auto<R: CubeRuntime>(
tensor: CubeTensor<R>,
mask: CubeTensor<R>,
value: InputScalar,
dtype_bool: DType,
) -> CubeTensor<R> {
let strategy = if tensor.can_mut() && tensor.is_nonoverlapping() {
MaskFillStrategy::Inplace
} else {
MaskFillStrategy::Readonly
};
super::mask_fill(tensor, mask, value, strategy, dtype_bool)
}
/// Execute the mask where kernel.
pub(crate) fn mask_where_auto<R: CubeRuntime>(
tensor: CubeTensor<R>,
mask: CubeTensor<R>,
value: CubeTensor<R>,
dtype_bool: DType,
) -> CubeTensor<R> {
let strategy = if tensor.can_mut_broadcast(&value) {
MaskWhereStrategy::InplaceLhs
} else if value.can_mut_broadcast(&tensor) {
MaskWhereStrategy::InplaceRhs
} else {
MaskWhereStrategy::Readonly
};
super::mask_where(tensor, mask, value, strategy, dtype_bool)
}

View File

@@ -0,0 +1,88 @@
use burn_backend::{DType, TensorMetadata};
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, linear_view_alias, linear_view_ref},
ops::{max_line_size_many, numeric::empty_device_dtype},
tensor::CubeTensor,
};
#[cube(launch_unchecked, address_type = "dynamic")]
fn mask_fill_kernel<T: Numeric, B: Int>(
input: &LinearView<Line<T>>,
mask: &LinearView<Line<B>>,
output: &mut LinearView<Line<T>, ReadWrite>,
value: InputScalar,
#[define(T, B)] _dtypes: [StorageType; 2],
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
let mask = Line::cast_from(mask[ABSOLUTE_POS]);
let input = input[ABSOLUTE_POS];
let value = Line::new(value.get::<T>());
output[ABSOLUTE_POS] = select_many(mask, value, input);
}
#[derive(Clone, Copy, Debug)]
/// Define how to run the mask fill kernel.
///
/// # Notes
///
/// All assertions should be done before choosing the strategy.
pub enum MaskFillStrategy {
/// Don't mutate any input.
Readonly,
/// Reuse the input tensor inplace.
Inplace,
}
/// Execute the mask fill kernel with the given strategy.
pub fn mask_fill<R: CubeRuntime>(
input: CubeTensor<R>,
mask: CubeTensor<R>,
value: InputScalar,
strategy: MaskFillStrategy,
dtype_bool: DType,
) -> CubeTensor<R> {
let ndims = input.meta.num_dims();
let output = match strategy {
MaskFillStrategy::Readonly => empty_device_dtype(
input.client.clone(),
input.device.clone(),
input.shape(),
input.dtype,
),
MaskFillStrategy::Inplace => input.clone(),
};
let line_size = max_line_size_many(&[&input, &mask], ndims - 1);
let working_units = input.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
let out_arg = match strategy {
MaskFillStrategy::Readonly => linear_view(&output, line_size),
MaskFillStrategy::Inplace => linear_view_alias(&output, line_size, 0),
};
unsafe {
mask_fill_kernel::launch_unchecked(
&input.client,
cube_count,
cube_dim,
address_type!(input, mask, output),
linear_view(&input, line_size),
linear_view_ref(&mask, &input, line_size),
out_arg,
value,
[output.dtype.into(), dtype_bool.into()],
)
.expect("Kernel to never fail");
}
output
}

View File

@@ -0,0 +1,91 @@
use burn_backend::DType;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
use crate::{
CubeRuntime,
kernel::utils::{
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
},
ops::{max_line_size_many, numeric::empty_device_dtype},
tensor::CubeTensor,
};
#[cube(launch, address_type = "dynamic")]
fn mask_where_kernel<T: Numeric, B: Int>(
input: &LinearView<Line<T>>,
value: &LinearView<Line<T>>,
mask: &LinearView<Line<B>>,
output: &mut LinearView<Line<T>, ReadWrite>,
#[define(T, B)] _dtypes: [StorageType; 2],
) {
let pos = ABSOLUTE_POS;
if !output.is_in_bounds(pos) {
terminate!();
}
output[pos] = select_many(Line::cast_from(mask[pos]), value[pos], input[pos]);
}
#[derive(Clone, Copy, Debug)]
/// Define how to run the mask where kernel.
///
/// # Notes
///
/// All assertions should be done before choosing the strategy.
pub enum MaskWhereStrategy {
/// Don't mutate any input.
Readonly,
/// Reuse the lhs tensor inplace.
InplaceLhs,
/// Reuse the rhs tensor inplace.
InplaceRhs,
}
/// Execute the mask where kernel with the given strategy.
pub fn mask_where<R: CubeRuntime>(
input: CubeTensor<R>,
mask: CubeTensor<R>,
value: CubeTensor<R>,
strategy: MaskWhereStrategy,
dtype_bool: DType,
) -> CubeTensor<R> {
let line_size = max_line_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1);
let working_units = input.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
let out_shape = broadcast_shape(&[&input, &mask, &value]);
let output = match strategy {
MaskWhereStrategy::Readonly => empty_device_dtype(
input.client.clone(),
input.device.clone(),
out_shape,
input.dtype,
),
MaskWhereStrategy::InplaceLhs => input.clone(),
MaskWhereStrategy::InplaceRhs => value.clone(),
};
let out = match strategy {
MaskWhereStrategy::Readonly => linear_view(&output, line_size),
MaskWhereStrategy::InplaceLhs => linear_view_alias(&output, line_size, 0),
MaskWhereStrategy::InplaceRhs => linear_view_alias(&output, line_size, 1),
};
mask_where_kernel::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, value, mask, output),
linear_view_ref(&input, &output, line_size),
linear_view_ref(&value, &output, line_size),
linear_view_ref(&mask, &output, line_size),
out,
[output.dtype.into(), dtype_bool.into()],
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,8 @@
mod base;
mod mask_fill;
mod mask_where;
pub(crate) use base::*;
pub use mask_fill::*;
pub use mask_where::*;

View File

@@ -0,0 +1,155 @@
use super::init_matmul_output;
use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};
use burn_backend::{DType, QTensorPrimitive};
use burn_std::QuantLevel;
use cubek::matmul::{
definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
launch::{MatmulInputHandleRef, Strategy},
};
#[cfg(feature = "autotune")]
use super::matmul_autotune;
/// The strategy to be used when launching a matmul kernel.
pub enum MatmulStrategy {
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
Autotune,
/// Cube implementation of matmul.
Cube,
}
impl Default for MatmulStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return MatmulStrategy::Autotune;
#[cfg(not(feature = "autotune"))]
MatmulStrategy::Cube
}
}
/// Launch a matmul kernel using the given strategy.
pub fn matmul<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
out: Option<CubeTensor<R>>,
strategy: MatmulStrategy,
out_dtype: DType,
) -> Result<CubeTensor<R>, MatmulSetupError> {
match strategy {
MatmulStrategy::Cube => {
let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
launch_matmul(&Default::default(), lhs, rhs, out.clone())?;
Ok(out)
}
#[cfg(feature = "autotune")]
MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),
}
}
pub(crate) fn launch_matmul_naive<R: CubeRuntime>(
strategy: &Strategy,
mut lhs: CubeTensor<R>,
mut rhs: CubeTensor<R>,
out: CubeTensor<R>,
) -> Result<(), MatmulSetupError> {
// Naive has very specific layout requirements for block scaled tensors, so we need to manually
// dequantize if it fails to launch normally. This is because naive is assumed to always work.
if lhs.qparams.is_some() || rhs.qparams.is_some() {
match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {
Err(_) => {
if lhs.qparams.is_some() {
lhs = dequantize(lhs, out.dtype);
}
if rhs.qparams.is_some() {
rhs = dequantize(rhs, out.dtype);
}
launch_matmul(strategy, lhs, rhs, out)
}
Ok(_) => Ok(()),
}
} else {
launch_matmul(strategy, lhs, rhs, out)
}
}
pub(crate) fn launch_matmul<R: CubeRuntime>(
strategy: &Strategy,
lhs: CubeTensor<R>,
mut rhs: CubeTensor<R>,
out: CubeTensor<R>,
) -> Result<(), MatmulSetupError> {
let client = &lhs.client;
let lhs_quant_handles = lhs.quantized_handles();
let out_dtype: DType = out.dtype;
let (lhs_dtype, lhs_handle) = match &lhs_quant_handles {
None => (
lhs.dtype,
MatmulInputHandleRef::new(lhs.as_handle_ref(), lhs.dtype.into()),
),
Some((data, scale)) => (
out_dtype,
MatmulInputHandleRef::quantized(
data.as_handle_ref(),
scale.as_handle_ref(),
lhs.meta.shape(),
lhs.scheme(),
data.dtype.into(),
scale.dtype.into(),
),
),
};
let rhs_quant_handles = rhs.quantized_handles();
let (rhs_dtype, rhs_handle) = match &rhs_quant_handles {
None => (
lhs.dtype,
MatmulInputHandleRef::new(rhs.as_handle_ref(), lhs.dtype.into()),
),
Some((data, scale)) => {
// Extremely hacky fix to ensure naive can run in every case
if matches!(strategy, Strategy::Naive)
&& matches!(rhs.scheme().level, QuantLevel::Block(_))
{
rhs = dequantize(rhs.clone(), lhs.dtype);
(
lhs.dtype,
MatmulInputHandleRef::new(rhs.as_handle_ref(), rhs.dtype.into()),
)
} else {
(
out_dtype,
MatmulInputHandleRef::quantized(
data.as_handle_ref(),
scale.as_handle_ref(),
rhs.meta.shape(),
rhs.scheme(),
data.dtype.into(),
scale.dtype.into(),
),
)
}
}
};
let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
lhs: lhs_dtype.into(),
rhs: rhs_dtype.into(),
out: out_dtype.into(),
});
cubek::matmul::launch::launch_ref(
strategy,
client,
&lhs_handle,
&rhs_handle,
&out.as_handle_ref(),
&mut dtypes,
)?;
Ok(())
}

View File

@@ -0,0 +1,10 @@
mod base;
mod tune;
/// Contains utilities for matmul operation
pub mod utils;
pub use base::*;
#[cfg(feature = "autotune")]
pub use tune::*;
pub use utils::*;

View File

@@ -0,0 +1,409 @@
use crate::{
CubeRuntime, CubeTuneId,
kernel::matmul::{launch_matmul, launch_matmul_naive, utils::init_matmul_output},
tensor::CubeTensor,
};
use burn_backend::DType;
use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};
use cubek::matmul::{
definition::MatmulKind,
launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering},
routines::{
BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs,
double_unit::DoubleUnitSelectionArgs, ordered_double_buffering::OrderedSelectionArgs,
simple::SimpleArgs, simple_unit::SimpleUnitSelectionArgs,
},
};
fn matmul_input_gen<R: CubeRuntime>(
_key: &MatmulAutotuneKey,
lhs: &CubeTensor<R>,
rhs: &CubeTensor<R>,
out: &CubeTensor<R>,
) -> (CubeTensor<R>, CubeTensor<R>, CubeTensor<R>) {
(lhs.clone(), rhs.clone(), out.copy())
}
/// Executes autotune on matmul operations
pub fn matmul_autotune<R: CubeRuntime>(
lhs: CubeTensor<R>,
rhs: CubeTensor<R>,
out: Option<CubeTensor<R>>,
out_dtype: DType,
) -> CubeTensor<R> {
let output = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
let client = lhs.client.clone();
static TUNER: LocalTuner<MatmulAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
const PRIORITY_MAX: i8 = 3;
const PRIORITY_HIGH: i8 = 2;
const PRIORITY_MEDIUM: i8 = 1;
const PRIORITY_MIN: i8 = 0;
const PRIORITY_NEVER: i8 = -1;
let cmma = TuneGroup::<MatmulAutotuneKey>::new("cmma", |key| {
if matches!(
key.analysis.kind,
MatmulKind::General
// Those variants are just because the unit alternatives aren't very good yet.
| MatmulKind::VecMat | MatmulKind::MatVec
) {
PRIORITY_HIGH
} else {
PRIORITY_MEDIUM
}
});
let mma = TuneGroup::<MatmulAutotuneKey>::new("mma", |key| {
if matches!(
key.analysis.kind,
// General is usually bad, but I think shapes like 16x8196 would be classed as
// general and are very good with MMA
// Should highly degenerated matrices that aren't VecMat have their own class?
MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec
) {
PRIORITY_HIGH
} else {
PRIORITY_MEDIUM
}
});
let unit = TuneGroup::<MatmulAutotuneKey>::new("unit", |key| {
if !matches!(key.analysis.kind, MatmulKind::General)
|| matches!(key.analysis.scale_global, MatmulGlobalScale::Small)
{
PRIORITY_HIGH
} else {
PRIORITY_MIN
}
});
let tma = TuneGroup::<MatmulAutotuneKey>::new("tma", |key| {
// For large matmul, we set the max priority to TMA kernels, higher than any other
// matmuls, since they are the best kernels no matter what.
//
// But only when all axis are large.
let max_axis = usize::max(key.definition.m, key.definition.n);
let max_axis = usize::max(key.definition.k, max_axis);
let min_axis = usize::min(key.definition.m, key.definition.n);
let min_axis = usize::min(key.definition.k, min_axis);
let skewed_factor = max_axis / min_axis;
let priority_max = if matches!(key.analysis.kind, MatmulKind::General)
&& matches!(key.analysis.scale_global, MatmulGlobalScale::Large)
&& skewed_factor < 4
{
PRIORITY_MAX
} else {
PRIORITY_HIGH
};
if key.definition.lhs_stride_factor >= 4 && key.definition.rhs_stride_factor >= 4 {
priority_max
} else {
PRIORITY_NEVER
}
});
fn double_buffering_priority(key: &MatmulAutotuneKey, max: i8, min: i8) -> i8 {
if should_tune_double_buffering(false, key) {
max
} else {
min
}
}
let mut set = TunableSet::new(create_key::<R>, matmul_input_gen::<R>);
// First entry should always work, since it is considered the fallback.
set = set.with(
Tunable::new("matmul_naive", |lhs, rhs, out| {
launch_matmul_naive::<R>(&Strategy::Naive, lhs, rhs, out)
.map_err(|err| std::format!("{err:?}"))
})
.group(&unit, |key| {
if matches!(key.analysis.scale_global, MatmulGlobalScale::Small)
|| matches!(key.analysis.kind, MatmulKind::InnerProduct)
{
PRIORITY_MAX
} else {
PRIORITY_MIN
}
}),
);
// Unit VecMat
for (strategy, double_buf) in [
(
Strategy::SimpleVecMat(BlueprintStrategy::Inferred(().into())),
false,
),
(
Strategy::DoubleVecMat(BlueprintStrategy::Inferred(().into())),
true,
),
] {
set = set.with(
Tunable::new(strategy.to_string(), move |lhs, rhs, out| {
launch_matmul::<R>(&strategy, lhs, rhs, out)
.map_err(|err| std::format!("{err:?}"))
})
.group(&unit, move |key| match double_buf {
false => PRIORITY_MAX,
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
}),
);
}
// Unit matmuls
for tile_size in [
TileSizeSelection::MaxTileSize,
TileSizeSelection::MinTileSize,
] {
for (strategy, double_buf) in [
(
Strategy::SimpleUnit(BlueprintStrategy::Inferred(SimpleUnitSelectionArgs {
tile_size,
})),
false,
),
(
Strategy::DoubleUnit(BlueprintStrategy::Inferred(DoubleUnitSelectionArgs {
tile_size,
})),
true,
),
] {
set = set.with(
Tunable::new(strategy.to_string(), move |lhs, rhs, out| {
launch_matmul::<R>(&strategy, lhs, rhs, out)
.map_err(|err| format!("{err:?}"))
})
.group(&unit, move |key| match double_buf {
false => PRIORITY_MAX,
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
}),
)
}
}
// Accelerated matmuls
for (strategy, double_buf, group_extra, tile_group) in [
(
Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: false,
})),
false,
None,
&cmma,
),
(
Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: false,
})),
false,
None,
&mma,
),
(
Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: true,
})),
false,
None,
&cmma,
),
(
Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: true,
})),
false,
None,
&mma,
),
(
Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
partition_k: Some(2),
row_count: Some(4),
rows_per_plane: Some(2),
})),
true,
None,
&cmma,
),
(
Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
partition_k: Some(2),
row_count: Some(4),
rows_per_plane: Some(2),
})),
true,
None,
&mma,
),
(
Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
partition_k: Some(2),
row_count: Some(8),
rows_per_plane: Some(2),
})),
true,
None,
&cmma,
),
(
Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
partition_k: Some(2),
row_count: Some(8),
rows_per_plane: Some(2),
})),
true,
None,
&mma,
),
(
Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
specialized: false,
})),
true,
None,
&cmma,
),
(
Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
specialized: false,
})),
true,
None,
&mma,
),
(
Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
specialized: true,
})),
true,
None,
&cmma,
),
(
Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
specialized: true,
})),
true,
None,
&mma,
),
(
Strategy::SpecializedCyclicCmma(BlueprintStrategy::Inferred(().into())),
true,
None,
&cmma,
),
(
Strategy::SpecializedCyclicMma(BlueprintStrategy::Inferred(().into())),
true,
None,
&mma,
),
(
Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: false,
})),
false,
Some(&tma),
&cmma,
),
(
Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: false,
})),
false,
Some(&tma),
&mma,
),
(
Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: true,
})),
false,
Some(&tma),
&cmma,
),
(
Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs {
multi_rows: true,
})),
false,
Some(&tma),
&mma,
),
(
Strategy::SpecializedTmaCmma(BlueprintStrategy::Inferred(().into())),
true,
Some(&tma),
&cmma,
),
(
Strategy::SpecializedTmaMma(BlueprintStrategy::Inferred(().into())),
true,
Some(&tma),
&mma,
),
] {
let priority_within_group = |key: &MatmulAutotuneKey, double_buf: bool| match double_buf
{
false => PRIORITY_MAX,
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
};
let mut tunable = Tunable::new(strategy.to_string(), move |lhs, rhs, out| {
launch_matmul::<R>(&strategy, lhs, rhs, out).map_err(|err| format!("{err:?}"))
});
// tile group
tunable = tunable.group(tile_group, move |key| {
priority_within_group(key, double_buf)
});
// extra group
if let Some(group) = group_extra {
tunable = tunable.group(group, move |key| priority_within_group(key, double_buf));
}
set = set.with(tunable);
}
set
});
TUNER.execute(
&CubeTuneId::new(&lhs.client, &lhs.device),
&client,
tunables,
(lhs, rhs, output.clone()),
);
output
}
fn create_key<R: CubeRuntime>(
lhs: &CubeTensor<R>,
rhs: &CubeTensor<R>,
out: &CubeTensor<R>,
) -> MatmulAutotuneKey {
MatmulAutotuneKey::generate(
&lhs.client,
lhs.meta.shape(),
rhs.meta.shape(),
lhs.meta.strides(),
rhs.meta.strides(),
lhs.dtype.into(),
rhs.dtype.into(),
out.dtype.into(),
lhs.try_scheme(),
rhs.try_scheme(),
)
}

View File

@@ -0,0 +1,5 @@
#[cfg(feature = "autotune")]
mod base;
#[cfg(feature = "autotune")]
pub use base::matmul_autotune;

View File

@@ -0,0 +1,16 @@
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, calculate_matmul_output};
/// Creates an empty output tensor with matmul output shape
pub fn init_matmul_output<R: CubeRuntime>(
lhs: &CubeTensor<R>,
rhs: &CubeTensor<R>,
dtype: DType,
) -> CubeTensor<R> {
empty_device_dtype(
lhs.client.clone(),
lhs.device.clone(),
calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(),
dtype,
)
}

View File

@@ -0,0 +1,51 @@
mod binary;
mod binary_float;
mod binary_int;
mod cast;
mod clamp;
mod comparison;
mod contiguous;
mod cross;
mod index;
mod mask;
mod unary_float;
mod unary_int;
mod unary_numeric;
pub(crate) use binary::*;
pub(crate) use binary_float::*;
pub(crate) use binary_int::*;
pub use cast::*;
pub use contiguous::*;
pub(crate) use cross::*;
pub use mask::*;
pub(crate) use unary_float::*;
pub(crate) use unary_int::*;
pub(crate) use unary_numeric::*;
pub use crate::cubecl::prelude::KernelMetadata;
/// Attention kernels
pub mod attention;
/// Convolution kernels
pub mod conv;
/// Grid sampling kernels
pub mod grid_sample;
/// Interpolation kernels
pub mod interpolate;
/// Matmul kernels
pub mod matmul;
/// Pooling kernels
pub mod pool;
/// Pseudo-random number generator kernels
pub mod prng;
/// Quantization operations
pub mod quantization;
/// Reduction algorithms
pub mod reduce;
pub(crate) use clamp::*;
pub(crate) use comparison::*;
pub use index::*;
pub(crate) mod utils;

View File

@@ -0,0 +1,117 @@
use crate::{
CubeRuntime,
kernel::{
into_contiguous_aligned,
pool::pool2d::{Position, view4d},
utils::{address_type, decompose_linear, shape_divmod},
},
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::Shape;
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::View},
};
#[cube(launch, address_type = "dynamic")]
fn adaptive_avg_pool2d_direct<E: Numeric>(
input: &Tensor<Line<E>>,
output: &mut View<Line<E>, Position, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
working_units: usize,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= working_units {
terminate!();
}
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
let [b, oh, ow, c] = *pos else { unreachable!() };
let (_, out_h, out_w, _) = output.shape();
let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));
let (in_h, in_w) = (input.shape(1), input.shape(2));
let ih_start = start_index(oh, out_h, in_h);
let ih_end = end_index(oh, out_h, in_h);
let iw_start = start_index(ow, out_w, in_w);
let iw_end = end_index(ow, out_w, in_w);
let mut sum = Line::empty(input.line_size()).fill(E::from_int(0));
let index_input_base = b * input.stride(0) + c * input.stride(3);
for ih in ih_start..ih_end {
let index_input_2 = ih * in_stride_h;
for iw in iw_start..iw_end {
let index_input_3 = iw * in_stride_w;
let index_input = index_input_base + index_input_2 + index_input_3;
sum += input[index_input / input.line_size()];
}
}
let num_ih = ih_end - ih_start;
let num_iw = iw_end - iw_start;
output[(b, oh, ow, c)] = sum / Line::cast_from(num_ih * num_iw);
}
#[cube]
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
(output_size_index * input_size) / output_size
}
#[cube]
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
let index = (output_size_index + 1) * input_size;
let index = index.div_ceil(output_size);
if input_size < index {
input_size
} else {
index
}
}
pub(crate) fn adaptive_avg_pool2d<R: CubeRuntime>(
input: CubeTensor<R>,
output_size: [usize; 2],
) -> CubeTensor<R> {
let [batch_size, channels, _, _] = input.meta.shape().dims();
let input = into_contiguous_aligned(permute_nchw_to_nhwc(input));
let line_size = max_line_size(&input);
let output_shape = Shape::new([batch_size, output_size[0], output_size[1], channels]);
let num_elems: usize = output_shape.num_elements();
let output = empty_device_dtype(
input.client.clone(),
input.device.clone(),
output_shape,
input.dtype,
);
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&input.client, working_units);
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
adaptive_avg_pool2d_direct::launch(
&input.client,
cube_count,
cube_dim,
address_type!(input, output),
input.as_tensor_arg(line_size),
view4d(&output, line_size),
shape_divmod(&output),
ScalarArg::new(working_units),
output.dtype.into(),
)
.expect("Kernel to never fail");
permute_nhwc_to_nchw(output)
}

View File

@@ -0,0 +1,119 @@
use crate::{
CubeRuntime,
kernel::{
into_contiguous_aligned,
pool::pool2d::{Position, view4d},
utils::{address_type, decompose_linear, shape_divmod},
},
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::Shape;
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::View},
};
#[cube(launch, address_type = "dynamic")]
fn adaptive_avg_pool2d_backward_direct<E: Numeric>(
grad: &Tensor<Line<E>>,
output: &mut View<Line<E>, Position, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
working_units: usize,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= working_units {
terminate!();
}
let (_, out_h, out_w, _) = output.shape();
let (grad_stride_h, grad_stride_w) = (grad.stride(1), grad.stride(2));
let (grad_h, grad_w) = (grad.shape(1), grad.shape(2));
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
let [b, ih, iw, c] = *pos else { unreachable!() };
let oh_start = start_index(ih, out_h, grad_h);
let oh_end = end_index(ih, out_h, grad_h);
let ow_start = start_index(iw, out_w, grad_w);
let ow_end = end_index(iw, out_w, grad_w);
let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
let index_base = b * grad.stride(0) + (c * grad.stride(3));
for oh in oh_start..oh_end {
let ih_start = start_index(oh, grad_h, out_h);
let ih_end = end_index(oh, grad_h, out_h);
if ih >= ih_start && ih < ih_end {
for ow in ow_start..ow_end {
let iw_start = start_index(ow, grad_w, out_w);
let iw_end = end_index(ow, grad_w, out_w);
if iw >= iw_start && iw < iw_end {
let num_ih = ih_end - ih_start;
let num_iw = iw_end - iw_start;
let index = index_base + (oh * grad_stride_h) + (ow * grad_stride_w);
grad_acc += grad[index / grad.line_size()] / Line::cast_from(num_iw * num_ih);
}
}
}
}
output[(b, ih, iw, c)] = grad_acc;
}
#[cube]
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
(output_size_index * input_size) / output_size
}
#[cube]
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
let index = (output_size_index + 1) * input_size;
let index = index.div_ceil(output_size);
if input_size < index {
input_size
} else {
index
}
}
pub(crate) fn adaptive_avg_pool2d_backward<R: CubeRuntime>(
x: CubeTensor<R>,
out_grad: CubeTensor<R>,
) -> CubeTensor<R> {
let [batches, channels, height, width] = x.meta.shape().dims();
let out_grad = into_contiguous_aligned(permute_nchw_to_nhwc(out_grad));
let line_size = max_line_size(&out_grad);
let out_shape = Shape::new([batches, height, width, channels]);
let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);
let num_elems = output.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&x.client, working_units);
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
adaptive_avg_pool2d_backward_direct::launch(
&x.client,
cube_count,
cube_dim,
address_type!(out_grad, output),
out_grad.as_tensor_arg(line_size),
view4d(&output, line_size),
shape_divmod(&output),
ScalarArg::new(working_units),
output.dtype.into(),
)
.expect("Kernel to never fail");
permute_nhwc_to_nchw(output)
}

View File

@@ -0,0 +1,166 @@
use super::pool2d::{
Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,
};
use crate::{
CubeRuntime,
kernel::{
into_contiguous_aligned,
pool::pool2d::{Position, view4d},
utils::{address_type, shape_divmod},
},
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::{Shape, ops::conv::calculate_pool_output_size};
use cubecl::{CubeDim, calculate_cube_count_elemwise, prelude::ScalarArg};
use cubecl::{prelude::*, std::tensor::View};
struct AvgPoolStrategy;
impl Pool2dDirectStrategyFamily for AvgPoolStrategy {
type Indices = ();
type Config = AvgPoolStrategyConfig;
type Pool2d<N: Numeric> = Self;
}
#[derive(CubeType, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct AvgPoolStrategyConfig {
count_include_pad: bool,
/// Total padded height (input_height + 2 * padding_0)
padded_h: u32,
/// Total padded width (input_width + 2 * padding_1)
padded_w: u32,
}
#[cube]
impl<N: Numeric> Pool2dDirectStrategy<N> for AvgPoolStrategy {
type Accumulator = (Line<N>, u32);
type Config = AvgPoolStrategyConfig;
type Indices = ();
fn initialize(
#[comptime] _config: &Self::Config,
#[comptime] line_size: LineSize,
) -> Self::Accumulator {
let sum = Line::empty(line_size).fill(N::from_int(0));
// Count will be set dynamically: either by accumulate (count_include_pad=false)
// or by set_padded_count (count_include_pad=true)
let count = 0u32;
(sum, count)
}
fn accumulate(
#[comptime] config: &Self::Config,
accumulator: &mut Self::Accumulator,
_index: usize,
result: Line<N>,
) {
let (sum, count) = accumulator;
// Only count valid positions when count_include_pad=false
if comptime![!config.count_include_pad] {
*count += 1;
}
*sum += result;
}
fn count_position(
#[comptime] config: &Self::Config,
accumulator: &mut Self::Accumulator,
ih: u32,
iw: u32,
) {
// When count_include_pad=true, count positions within padded bounds
// (excludes ceil_mode extensions beyond the padded input)
if comptime![config.count_include_pad] && ih < config.padded_h && iw < config.padded_w {
let (_sum, count) = accumulator;
*count += 1;
}
}
fn store(
#[comptime] _config: &Self::Config,
position: Position,
output: &mut View<Line<N>, Position, ReadWrite>,
_output_indices: &mut (),
accumulator: Self::Accumulator,
) {
let (sum, count) = accumulator;
output[position] = sum / Line::cast_from(count);
}
}
pub(crate) fn avg_pool2d<R: CubeRuntime>(
x: CubeTensor<R>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> CubeTensor<R> {
let [batch_size, channels, in_h, in_w] = x.meta.shape().dims();
let dilation = 1;
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation,
in_h,
ceil_mode,
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation,
in_w,
ceil_mode,
);
// Padded dimensions (for count_include_pad with ceil_mode)
let padded_0 = in_h + 2 * padding[0];
let padded_1 = in_w + 2 * padding[1];
let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
let line_size = max_line_size(&x);
let shape_out = Shape::new([batch_size, size_0, size_1, channels]);
let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&x.client, working_units);
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
pool2d_direct::launch::<AvgPoolStrategy, R>(
&x.client,
cube_count,
cube_dim,
address_type!(x, output),
x.as_tensor_arg(line_size),
view4d(&output, line_size),
(),
shape_divmod(&output),
ScalarArg::new(working_units),
Pool2dDirectArgsLaunch::new(
ScalarArg::new(stride[0] as u32),
ScalarArg::new(stride[1] as u32),
ScalarArg::new(dilation as u32),
ScalarArg::new(dilation as u32),
ScalarArg::new(padding[0] as u32),
ScalarArg::new(padding[1] as u32),
),
(kernel_size[0] as u32, kernel_size[1] as u32),
AvgPoolStrategyConfig {
count_include_pad,
padded_h: padded_0 as u32,
padded_w: padded_1 as u32,
},
output.dtype.into(),
)
.expect("Kernel to never fail");
permute_nhwc_to_nchw(output)
}

View File

@@ -0,0 +1,183 @@
use crate::{
CubeRuntime,
kernel::{
pool::pool2d::{Position, view4d},
utils::{address_type, decompose_linear, shape_divmod},
},
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::Shape;
use cubecl::{
calculate_cube_count_elemwise,
prelude::*,
std::{FastDivmod, tensor::View},
};
#[derive(CubeLaunch, CubeType)]
pub(crate) struct PoolBackwardArgs {
pub stride_0: i32,
pub stride_1: i32,
pub dilation_0: i32,
pub dilation_1: i32,
pub padding_0: i32,
pub padding_1: i32,
}
#[cube(launch_unchecked, address_type = "dynamic")]
fn avg_pool2d_backward_kernel<E: Numeric>(
grad: &Tensor<Line<E>>,
output: &mut View<Line<E>, Position, ReadWrite>,
out_shape: Sequence<FastDivmod<usize>>,
working_units: usize,
args: &PoolBackwardArgs,
#[comptime] kernel_size_0: i32,
#[comptime] kernel_size_1: i32,
#[comptime] count_include_pad: bool,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= working_units {
terminate!();
}
let line_size = grad.line_size();
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
let [batch, ih, iw, channel] = *pos else {
unreachable!()
};
let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
ih as i32,
iw as i32,
grad.shape(1) as u32,
grad.shape(2) as u32,
args,
kernel_size_0,
kernel_size_1,
);
let padding_0 = args.padding_0 as u32;
let padding_1 = args.padding_1 as u32;
let stride_0 = args.stride_0 as u32;
let stride_1 = args.stride_1 as u32;
let kernel_size_0 = comptime![kernel_size_0 as u32];
let kernel_size_1 = comptime![kernel_size_1 as u32];
let index_base = batch * grad.stride(0) + channel * grad.stride(3);
let border_bottom = output.shape().1 as u32 + padding_0;
let border_right = output.shape().2 as u32 + padding_1;
let begin_h = ih as u32 + padding_0;
let begin_w = iw as u32 + padding_1;
for oh in oh_start..oh_end {
let ih_start = oh * stride_0;
let ih_end = clamp_max(ih_start + kernel_size_0, border_bottom);
let ih_start = clamp_min(ih_start, padding_0);
if begin_h >= ih_start && (ih as u32) < ih_end {
for ow in ow_start..ow_end {
let index =
index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);
let iw_start = ow * stride_1;
let iw_end = clamp_max(iw_start + kernel_size_1, border_right);
let iw_start = clamp_min(iw_start, padding_1);
if begin_w >= iw_start && (iw as u32) < iw_end {
if count_include_pad {
grad_acc += grad[index / line_size]
/ Line::cast_from(kernel_size_0 * kernel_size_1);
} else {
let ih_diff = ih_end - ih_start;
let iw_diff = iw_end - iw_start;
let count = Line::cast_from(ih_diff * iw_diff);
grad_acc += grad[index / line_size] / count;
}
}
}
}
}
output[(batch, ih, iw, channel)] = grad_acc;
}
#[cube]
fn loop_ranges(
ih: i32,
iw: i32,
grad_h: u32,
grad_w: u32,
args: &PoolBackwardArgs,
#[comptime] kernel_size_0: i32,
#[comptime] kernel_size_1: i32,
) -> (u32, u32, u32, u32) {
let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0;
let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1;
let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32;
let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32;
let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1;
let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1;
(oh_start, oh_end, ow_start, ow_end)
}
pub(crate) fn avg_pool2d_backward<R: CubeRuntime>(
x: CubeTensor<R>,
grad: CubeTensor<R>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
_ceil_mode: bool,
) -> CubeTensor<R> {
let [batches, channels, height, width] = x.meta.shape().dims();
let grad = permute_nchw_to_nhwc(grad);
let line_size = if x.meta.strides()[3] == grad.meta.strides()[3] {
max_line_size(&x)
} else {
1
};
let dilation = 1;
let out_shape = Shape::new([batches, height, width, channels]);
let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&x.client, working_units);
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
unsafe {
avg_pool2d_backward_kernel::launch_unchecked(
&grad.client,
cube_count,
cube_dim,
address_type!(grad, output),
grad.as_tensor_arg(line_size),
view4d(&output, line_size),
shape_divmod(&output),
ScalarArg::new(working_units),
PoolBackwardArgsLaunch::new(
ScalarArg::new(stride[0] as i32),
ScalarArg::new(stride[1] as i32),
ScalarArg::new(dilation),
ScalarArg::new(dilation),
ScalarArg::new(padding[0] as i32),
ScalarArg::new(padding[1] as i32),
),
kernel_size[0] as i32,
kernel_size[1] as i32,
count_include_pad,
output.dtype.into(),
)
}
.expect("Kernel to never fail");
permute_nhwc_to_nchw(output)
}

View File

@@ -0,0 +1,255 @@
use super::pool2d::{
Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,
};
use crate::{
CubeRuntime,
kernel::{
into_contiguous_aligned,
pool::pool2d::{Position, view4d},
utils::{address_type, shape_divmod},
},
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::{DType, Shape, ops::conv::calculate_pool_output_size};
use cubecl::{CubeDim, calculate_cube_count_elemwise, prelude::*, std::tensor::View};
struct MaxPoolStrategy;
struct MaxPoolWithIndicesStrategy;
impl Pool2dDirectStrategyFamily for MaxPoolStrategy {
type Indices = ();
type Config = ();
type Pool2d<N: Numeric> = Self;
}
impl Pool2dDirectStrategyFamily for MaxPoolWithIndicesStrategy {
type Indices = View<Line<i32>, Position, ReadWrite>;
type Config = ();
type Pool2d<N: Numeric> = Self;
}
#[cube]
impl<N: Numeric> Pool2dDirectStrategy<N> for MaxPoolStrategy {
type Accumulator = Line<N>;
type Config = ();
type Indices = ();
fn initialize(
#[comptime] _config: &Self::Config,
#[comptime] line_size: LineSize,
) -> Self::Accumulator {
Line::empty(line_size).fill(N::min_value())
}
fn accumulate(
#[comptime] _config: &Self::Config,
accumulator: &mut Self::Accumulator,
_index: LineSize,
result: Line<N>,
) {
*accumulator = max(*accumulator, result);
}
fn count_position(
#[comptime] _config: &Self::Config,
_accumulator: &mut Self::Accumulator,
_ih: u32,
_iw: u32,
) {
}
fn store(
#[comptime] _config: &Self::Config,
position: Position,
output: &mut View<Line<N>, Position, ReadWrite>,
_output_indices: &mut (),
accumulator: Self::Accumulator,
) {
output[position] = accumulator;
}
}
#[cube]
impl<N: Numeric> Pool2dDirectStrategy<N> for MaxPoolWithIndicesStrategy {
type Accumulator = (Line<N>, Line<i32>);
type Config = ();
type Indices = View<Line<i32>, Position, ReadWrite>;
fn initialize(
#[comptime] _config: &Self::Config,
#[comptime] line_size: LineSize,
) -> Self::Accumulator {
let val = Line::empty(line_size).fill(N::min_value());
let idx = Line::empty(line_size).fill(0i32);
(val, idx)
}
fn accumulate(
#[comptime] _config: &Self::Config,
accumulator: &mut Self::Accumulator,
index: usize,
result: Line<N>,
) {
let indices = Line::cast_from(index);
accumulator.1 = select_many(result.greater_than(accumulator.0), indices, accumulator.1);
accumulator.0 = max(result, accumulator.0);
}
fn count_position(
#[comptime] _config: &Self::Config,
_accumulator: &mut Self::Accumulator,
_ih: u32,
_iw: u32,
) {
}
fn store(
#[comptime] _config: &Self::Config,
position: Position,
output: &mut View<Line<N>, Position, ReadWrite>,
output_indices: &mut View<Line<i32>, Position, ReadWrite>,
accumulator: Self::Accumulator,
) {
output[position] = accumulator.0;
output_indices[position] = accumulator.1;
}
}
pub(crate) fn max_pool2d<R: CubeRuntime>(
x: CubeTensor<R>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> CubeTensor<R> {
let [batch_size, channels, height, width] = x.meta.shape().dims();
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
height,
ceil_mode,
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
width,
ceil_mode,
);
let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
let line_size = max_line_size(&x);
let shape_out = Shape::new([batch_size, size_0, size_1, channels]);
let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&x.client, working_units);
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
pool2d_direct::launch::<MaxPoolStrategy, R>(
&x.client,
cube_count,
cube_dim,
address_type!(x, output),
x.as_tensor_arg(line_size),
view4d(&output, line_size),
(),
shape_divmod(&output),
ScalarArg::new(working_units),
Pool2dDirectArgsLaunch::new(
ScalarArg::new(stride[0] as u32),
ScalarArg::new(stride[1] as u32),
ScalarArg::new(dilation[0] as u32),
ScalarArg::new(dilation[1] as u32),
ScalarArg::new(padding[0] as u32),
ScalarArg::new(padding[1] as u32),
),
(kernel_size[0] as u32, kernel_size[1] as u32),
(),
output.dtype.into(),
)
.expect("Kernel to never fail");
permute_nhwc_to_nchw(output)
}
pub(crate) fn max_pool2d_with_indices<R: CubeRuntime>(
x: CubeTensor<R>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
dtype_indices: DType,
) -> (CubeTensor<R>, CubeTensor<R>) {
let [batch_size, channels, size_0, size_1] = x.meta.shape().dims();
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
size_0,
ceil_mode,
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
size_1,
ceil_mode,
);
let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
let line_size = max_line_size(&x);
let shape_out = Shape::new([batch_size, size_0, size_1, channels]);
let output = empty_device_dtype(
x.client.clone(),
x.device.clone(),
shape_out.clone(),
x.dtype,
);
let indices = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, dtype_indices);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&x.client, working_units);
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
pool2d_direct::launch::<MaxPoolWithIndicesStrategy, R>(
&x.client,
cube_count,
cube_dim,
address_type!(x, output, indices),
x.as_tensor_arg(line_size),
view4d(&output, line_size),
view4d(&indices, line_size),
shape_divmod(&output),
ScalarArg::new(working_units),
Pool2dDirectArgsLaunch::new(
ScalarArg::new(stride[0] as u32),
ScalarArg::new(stride[1] as u32),
ScalarArg::new(dilation[0] as u32),
ScalarArg::new(dilation[1] as u32),
ScalarArg::new(padding[0] as u32),
ScalarArg::new(padding[1] as u32),
),
(kernel_size[0] as u32, kernel_size[1] as u32),
(),
output.dtype.into(),
)
.expect("Kernel to never fail");
let output = permute_nhwc_to_nchw(output);
let indices = permute_nhwc_to_nchw(indices);
(output, indices)
}

View File

@@ -0,0 +1,156 @@
use crate::{
CubeRuntime,
kernel::{
into_contiguous_aligned,
utils::{address_type, decompose_linear, shape_divmod},
},
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use burn_backend::Shape;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::FastDivmod};
use super::{PoolBackwardArgs, PoolBackwardArgsLaunch};
#[cube(launch_unchecked, address_type = "dynamic")]
fn max_pool2d_with_indices_backward_kernel<E: Numeric, I: Int>(
grad: &Tensor<Line<E>>,
indices: &Tensor<Line<I>>,
output: &mut Tensor<Line<E>>,
out_shape: Sequence<FastDivmod<usize>>,
working_units: usize,
args: &PoolBackwardArgs,
#[comptime] kernel_size_0: i32,
#[comptime] kernel_size_1: i32,
#[define(E, I)] _dtypes: [StorageType; 2],
) {
if ABSOLUTE_POS >= working_units {
terminate!();
}
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
let [batch, ih, iw, channel] = *pos else {
unreachable!()
};
let line_size = grad.line_size();
let index_current = ih * output.shape(2) + iw;
let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
ih as i32,
iw as i32,
grad.shape(1) as u32,
grad.shape(2) as u32,
args,
kernel_size_0,
kernel_size_1,
);
let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
let grad_idx_base = batch * grad.stride(0) + channel * grad.stride(3);
let ind_idx_base = batch * indices.stride(0) + channel * indices.stride(3);
for oh in oh_start..oh_end {
for ow in ow_start..ow_end {
let grad_index =
grad_idx_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);
let indices_index =
ind_idx_base + oh as usize * indices.stride(1) + ow as usize * indices.stride(2);
let index_max = Line::<u32>::cast_from(indices[indices_index / line_size]);
grad_acc += select_many(
index_max.equal(Line::cast_from(index_current)),
grad[grad_index / line_size],
Line::new(E::from_int(0)),
);
}
}
let index_output = batch * output.stride(0)
+ ih * output.stride(1)
+ iw * output.stride(2)
+ channel * output.stride(3);
output[index_output / output.line_size()] = grad_acc;
}
#[cube]
fn loop_ranges(
ih: i32,
iw: i32,
grad_h: u32,
grad_w: u32,
args: &PoolBackwardArgs,
#[comptime] kernel_size_0: i32,
#[comptime] kernel_size_1: i32,
) -> (u32, u32, u32, u32) {
let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0;
let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1;
let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32;
let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32;
let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1;
let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1;
(oh_start, oh_end, ow_start, ow_end)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn max_pool2d_with_indices_backward<R: CubeRuntime>(
x: CubeTensor<R>,
grad: CubeTensor<R>,
indices: CubeTensor<R>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
_ceil_mode: bool,
) -> CubeTensor<R> {
let [batches, channels, height, width] = x.meta.shape().dims();
let grad = into_contiguous_aligned(permute_nchw_to_nhwc(grad));
let indices = into_contiguous_aligned(permute_nchw_to_nhwc(indices));
let line_size = if grad.meta.strides()[3] == indices.meta.strides()[3] {
max_line_size(&grad)
} else {
1
};
let out_shape = Shape::new([batches, height, width, channels]);
let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);
let working_units = output.meta.num_elements() / line_size as usize;
let cube_dim = CubeDim::new(&x.client, working_units);
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
unsafe {
max_pool2d_with_indices_backward_kernel::launch_unchecked(
&x.client,
cube_count,
cube_dim,
address_type!(grad, indices, output),
grad.as_tensor_arg(line_size),
indices.as_tensor_arg(line_size),
output.as_tensor_arg(line_size),
shape_divmod(&output),
ScalarArg::new(working_units),
PoolBackwardArgsLaunch::new(
ScalarArg::new(stride[0] as i32),
ScalarArg::new(stride[1] as i32),
ScalarArg::new(dilation[0] as i32),
ScalarArg::new(dilation[1] as i32),
ScalarArg::new(padding[0] as i32),
ScalarArg::new(padding[1] as i32),
),
kernel_size[0] as i32,
kernel_size[1] as i32,
[x.dtype.into(), indices.dtype.into()],
)
.expect("Kernel to never fail")
};
permute_nhwc_to_nchw(output)
}

View File

@@ -0,0 +1,15 @@
mod adaptive_avg_pool2d;
mod adaptive_avg_pool2d_backward;
mod avg_pool2d;
mod avg_pool2d_backward;
mod max_pool2d;
mod max_pool2d_backward;
pub(super) mod pool2d;
pub(crate) use adaptive_avg_pool2d::*;
pub(crate) use adaptive_avg_pool2d_backward::*;
pub(crate) use avg_pool2d::*;
pub(crate) use avg_pool2d_backward::*;
pub(crate) use max_pool2d::*;
pub(crate) use max_pool2d_backward::*;

View File

@@ -0,0 +1,155 @@
use core::hash::Hash;
use cubecl::{
prelude::*,
std::{
FastDivmod,
tensor::{
View,
launch::ViewArg,
layout::fixed_dim::{FixedDimLayout, FixedDimLayoutLaunch},
},
},
};
use crate::{CubeRuntime, kernel::utils::decompose_linear, tensor::CubeTensor};
pub trait Pool2dDirectStrategyFamily: Send + Sync + 'static {
type Indices: LaunchArg;
type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;
type Pool2d<N: Numeric>: Pool2dDirectStrategy<N, Config = Self::Config, Indices = Self::Indices>;
}
pub(super) type Position = (usize, usize, usize, usize);
#[cube]
pub(crate) trait Pool2dDirectStrategy<N: Numeric>: Send + Sync + 'static {
type Accumulator: CubeType;
type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;
type Indices: LaunchArg;
fn initialize(
#[comptime] config: &Self::Config,
#[comptime] line_size: LineSize,
) -> Self::Accumulator;
fn accumulate(
#[comptime] config: &Self::Config,
accumulator: &mut Self::Accumulator,
index: usize,
result: Line<N>,
);
/// Count a position within the kernel window (for avg_pool count_include_pad).
/// Called for each position in the kernel window with the current ih/iw coordinates.
/// Only avg_pool uses this; max_pool implements as no-op.
fn count_position(
#[comptime] config: &Self::Config,
accumulator: &mut Self::Accumulator,
ih: u32,
iw: u32,
);
fn store(
#[comptime] config: &Self::Config,
position: Position,
output: &mut View<Line<N>, Position, ReadWrite>,
output_indices: &mut Self::Indices,
accumulator: Self::Accumulator,
);
}
#[derive(CubeLaunch, CubeType)]
pub struct Pool2dDirectArgs {
pub strides_0: u32,
pub strides_1: u32,
pub dilation_0: u32,
pub dilation_1: u32,
pub padding_0: u32,
pub padding_1: u32,
}
#[cube(launch, address_type = "dynamic")]
pub fn pool2d_direct<E: Numeric, S: Pool2dDirectStrategyFamily>(
input: &Tensor<Line<E>>,
output: &mut View<Line<E>, Position, ReadWrite>,
indices: &mut S::Indices,
out_shape: Sequence<FastDivmod<usize>>,
working_units: usize,
args: &Pool2dDirectArgs,
#[comptime] kernel_size: (u32, u32),
#[comptime] config: &S::Config,
#[define(E)] _dtype: StorageType,
) {
if ABSOLUTE_POS >= working_units {
terminate!();
}
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
let [b, oh, ow, c] = *pos else { unreachable!() };
let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));
let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32);
let mut accumulator = S::Pool2d::<E>::initialize(config, input.line_size());
let in_b_off = b * input.stride(0);
let in_c_off = c * input.stride(3);
let border_bottom = in_h + args.padding_0;
let border_right = in_w + args.padding_1;
for kh in 0..kernel_size.0 {
let ih = oh as u32 * args.strides_0 + kh * args.dilation_0;
let within_padding_h = ih >= args.padding_0 && ih < border_bottom;
for kw in 0..kernel_size.1 {
let iw = ow as u32 * args.strides_1 + kw * args.dilation_1;
let within_padding_w = iw >= args.padding_1 && iw < border_right;
// Let strategy handle position counting (only used by avg_pool)
S::Pool2d::<E>::count_position(config, &mut accumulator, ih, iw);
// Only accumulate values from valid input positions
if within_padding_h && within_padding_w {
let ih_pad = ih - args.padding_0;
let iw_pad = iw - args.padding_1;
let in_h_off = ih_pad as usize * in_stride_h;
let in_w_off = iw_pad as usize * in_stride_w;
let index_input = in_b_off + in_c_off + in_h_off + in_w_off;
S::Pool2d::<E>::accumulate(
config,
&mut accumulator,
ih_pad as usize * in_w as usize + iw_pad as usize,
input[index_input / input.line_size()],
);
}
}
}
S::Pool2d::<E>::store(config, (b, oh, ow, c), output, indices, accumulator);
}
pub(super) fn view4d<R: CubeRuntime>(
tensor: &CubeTensor<R>,
line_size: LineSize,
) -> ViewArg<'_, Position, R> {
let shape = tensor.meta.shape();
let shape = (
ScalarArg::new(shape[0]),
ScalarArg::new(shape[1]),
ScalarArg::new(shape[2]),
ScalarArg::new(shape[3]),
);
let handle = tensor.as_handle_ref();
let len = handle.shape.iter().product::<usize>();
let layout =
FixedDimLayoutLaunch::<Position, R>::from_shape_handle_unchecked(&handle, shape, line_size);
let buffer = unsafe {
ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
};
ViewArg::new::<FixedDimLayout<Position>>(buffer, layout)
}

View File

@@ -0,0 +1,18 @@
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, Shape};
/// Pseudo-random generator with bernoulli distribution
pub fn random_bernoulli<R: CubeRuntime>(
shape: Shape,
device: &R::Device,
probability: f32,
dtype: DType,
) -> CubeTensor<R> {
let client = R::client(device);
let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
cubek::random::random_bernoulli(&client, probability, output.as_handle_ref(), dtype.into())
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,7 @@
mod bernoulli;
mod normal;
mod uniform;
pub use bernoulli::*;
pub use normal::*;
pub use uniform::*;

View File

@@ -0,0 +1,20 @@
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, Shape};
/// Pseudo-random generator with uniform distribution
pub fn random_normal<R: CubeRuntime>(
shape: Shape,
device: &R::Device,
mean: f32,
std: f32,
dtype: DType,
) -> CubeTensor<R> {
let client = R::client(device);
let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
let output_handle = output.as_handle_ref();
cubek::random::random_normal(&client, mean, std, output_handle, dtype.into())
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,43 @@
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, Shape, TensorMetadata};
/// Pseudo-random generator with uniform distribution
pub fn random_uniform<R: CubeRuntime>(
shape: Shape,
device: &R::Device,
lower_bound: f32,
upper_bound: f32,
dtype: DType,
) -> CubeTensor<R> {
let client = R::client(device);
let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
let output_handle = output.as_handle_ref();
cubek::random::random_uniform(
&client,
lower_bound,
upper_bound,
output_handle,
dtype.into(),
)
.expect("Kernel to never fail");
output
}
/// Pseudo-random generator for uniform distribution, based on
/// another tensor.
pub fn random_like_uniform<R: CubeRuntime>(
tensor: &CubeTensor<R>,
lower_bound: f32,
upper_bound: f32,
dtype: DType,
) -> CubeTensor<R> {
random_uniform(
tensor.shape(),
&tensor.device,
lower_bound,
upper_bound,
dtype,
)
}

View File

@@ -0,0 +1,34 @@
use crate::tensor::CubeTensor;
use crate::{CubeRuntime, ops::numeric::empty_device_dtype};
use burn_backend::{DType, TensorMetadata};
/// Convert the tensor back to a higher precision data type.
pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
where
R: CubeRuntime,
{
let scheme = match tensor.dtype {
DType::QFloat(scheme) => scheme,
_ => return tensor,
};
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
dtype,
);
let (values, params) = tensor.quantized_handles().unwrap();
cubek::quantization::dequantize::launch_ref(
&values.client,
&values.as_handle_ref(),
&output.as_handle_ref(),
&params.as_handle_ref(),
&scheme,
dtype.into(),
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,5 @@
mod dequantize;
mod quantize;
pub use dequantize::*;
pub use quantize::*;

View File

@@ -0,0 +1,29 @@
use crate::CubeRuntime;
use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
use burn_backend::{TensorMetadata, quantization::QuantScheme};
/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
pub fn quantize<R>(
tensor: CubeTensor<R>,
scheme: &QuantScheme,
scale: CubeTensor<R>,
) -> CubeTensor<R>
where
R: CubeRuntime,
{
let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device);
let (out_values, out_params) = output.clone().quantized_handles().unwrap();
cubek::quantization::quantize::launch_ref(
&tensor.client,
&tensor.as_handle_ref(),
&out_values.as_handle_ref(),
&scale.as_handle_ref(),
&out_params.as_handle_ref(),
scheme,
tensor.dtype.into(),
)
.expect("Kernel to never fail");
output
}

View File

@@ -0,0 +1,242 @@
#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
CubeRuntime,
ops::numeric::{empty_device_contiguous_dtype, zeros_client},
tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use burn_std::Metadata;
use cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType};
use cubek::reduce::{
ReduceDtypes, ReduceError, ReduceStrategy,
components::instructions::ReduceOperationConfig,
launch::{LineSizeStrategy, RoutineStrategy},
routines::{BlueprintStrategy, unit::UnitStrategy},
shared_sum,
};
use serde::{Deserialize, Serialize};
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of sum versions
pub struct SumAutotuneKey {
/// The type of the tensor
dtype: burn_backend::DType,
/// The anchored length of the tensor
#[autotune(anchor)]
length: usize,
}
/// Check if the client supports atomic add for the given element type.
fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
client
.properties()
.type_usage(StorageType::Atomic(dtype.into()))
.contains(TypeUsage::AtomicAdd)
}
/// [Sum](sum) with fallback when `client` doesn't support atomic add for the type `E`.
pub fn sum_fallback<R: CubeRuntime>(
tensor: CubeTensor<R>,
mut strategy: SumStrategy,
) -> Result<CubeTensor<R>, ReduceError> {
// Early check before creating output and fallback
if matches!(strategy, SumStrategy::OneShot(_))
&& !supports_atomic_add(&tensor.client, tensor.dtype)
{
strategy = SumStrategy::Chained(Default::default());
}
sum(tensor, strategy)
}
/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return
/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`.
///
/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction.
///
/// Return an error if the `client` doesn't support atomic add for the type `E`.
pub fn sum<Run: CubeRuntime>(
tensor: CubeTensor<Run>,
strategy: SumStrategy,
) -> Result<CubeTensor<Run>, ReduceError> {
let client = tensor.client.clone();
let device = tensor.device.clone();
match strategy {
SumStrategy::OneShot(cube_count) => {
let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);
shared_sum::<Run>(
&client,
tensor.as_handle_ref(),
output.as_handle_ref(),
cube_count,
tensor.dtype.into(),
)?;
Ok(output)
}
SumStrategy::Chained(strategy) => {
reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
}
#[cfg(feature = "autotune")]
SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
}
}
/// Select a strategy to perform a sum.
pub enum SumStrategy {
/// Run a single kernel with many cubes working in parallel to sum all elements.
/// The provided value is the number of elements summed per unit (up-to-rounding )
OneShot(u32),
/// Use multiple kernels
Chained(KernelReduceStrategy),
/// Use autotune to find the best cube count given the hardware and the input.
#[cfg(feature = "autotune")]
Autotune,
}
impl Default for SumStrategy {
fn default() -> Self {
#[cfg(feature = "autotune")]
return Self::Autotune;
#[cfg(not(feature = "autotune"))]
return Self::OneShot(4);
}
}
/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
///
/// If there is no error, the output is a tensor with decreasing strides
/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
pub fn reduce<Run: CubeRuntime>(
mut tensor: CubeTensor<Run>,
output_dtype: Option<DType>,
strategy: KernelReduceStrategy,
config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
// In practice, it looks like starting by the axis with the smallest shape
// and going in increasing order lead to the fastest calculation.
let sorted_axis = argsort(tensor.meta.shape());
for axis in sorted_axis {
tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;
}
// reshape to scalar tensor
*tensor.meta = Metadata::new([1], [1]);
Ok(tensor)
}
fn argsort(shape: &[usize]) -> Vec<usize> {
let mut indices = (0..shape.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &shape[i]);
indices
}
/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
///
/// If there is no error, the output is a tensor with decreasing strides
/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
pub fn reduce_dim<Run: CubeRuntime>(
input: CubeTensor<Run>,
output_dtype: Option<DType>,
dim: usize,
strategy: KernelReduceStrategy,
config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
debug_assert!(
!matches!(
config,
ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin
) || output_dtype.is_some(),
"The `output_dtype` has to be `Some` only when the `config` is `ArgMax` or `ArgMin`.
"
);
let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
let client = input.client.clone();
let output = init_reduce_output::<Run>(&input, dim, &dtypes).ok_or(
cubek::reduce::ReduceError::InvalidAxis {
axis: dim,
rank: input.meta.num_dims(),
},
)?;
let result = match strategy {
KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
&client,
input.as_handle_ref(),
output.as_handle_ref(),
dim,
ReduceStrategy {
routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
line_size: LineSizeStrategy {
parallel_output_vectorization: false,
},
},
config,
dtypes,
),
KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
&client,
input.as_handle_ref(),
output.as_handle_ref(),
dim,
strategy,
config,
dtypes,
),
#[cfg(feature = "autotune")]
KernelReduceStrategy::Autotune => {
autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
Ok(())
}
};
result.map(|_| output)
}
/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input`
/// or return `None` if `axis` is out-of-bound.
pub fn init_reduce_output<Run: CubeRuntime>(
input: &CubeTensor<Run>,
dim: usize,
dtypes: &ReduceDtypes,
) -> Option<CubeTensor<Run>> {
(dim < input.meta.num_dims()).then(|| {
let mut shape_out = input.shape();
shape_out[dim] = 1;
empty_device_contiguous_dtype(
input.client.clone(),
input.device.clone(),
shape_out,
dtypes.output.elem_type().into(),
)
})
}
/// Select a strategy to perform a reduction.
#[derive(Clone, Debug)]
pub enum KernelReduceStrategy {
/// Use a best-effort strategy based on the hardware capacity.
/// This differs from Autotune as it doesn't try and compare many strategies to select the best.
Unspecified,
/// Fix the exact strategy for the reduction.
Specific(cubek::reduce::launch::ReduceStrategy),
/// Use autotune to find the best strategy given the hardware and the inputs.
#[cfg(feature = "autotune")]
Autotune,
}
impl Default for KernelReduceStrategy {
fn default() -> Self {
#[cfg(feature = "autotune")]
return Self::Autotune;
#[cfg(not(feature = "autotune"))]
return Self::Unspecified;
}
}

View File

@@ -0,0 +1,7 @@
mod base;
#[cfg(feature = "autotune")]
mod tune;
pub use base::*;
#[cfg(feature = "autotune")]
pub use tune::*;

View File

@@ -0,0 +1,286 @@
#![allow(missing_docs)]
use super::SumAutotuneKey;
use crate::{CubeAutotuneKey, CubeRuntime, CubeTuneId, tensor::CubeTensor};
use cubecl::{
client::ComputeClient,
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
};
use cubek::reduce::{
ReduceDtypes, ReduceStrategy,
components::instructions::ReduceOperationConfig,
launch::{LineSizeStrategy, RoutineStrategy, tune_key::ReduceAutotuneKey},
routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy},
};
/// Executes autotune on reduce operations.
pub fn autotune_reduce<R: CubeRuntime>(
client: &ComputeClient<R>,
input: CubeTensor<R>,
output: CubeTensor<R>,
axis: usize,
config: ReduceOperationConfig,
dtypes: ReduceDtypes,
) {
use reduce_ops::*;
static TUNER: LocalTuner<ReduceAutotuneKey, CubeTuneId> = local_tuner!("reduce-dim");
let tunables = TUNER.init(|| {
const PRIORITY_MAX: i8 = 2;
const PRIORITY_MIN: i8 = 1;
const PRIORITY_SKIP: i8 = -1;
let mut set = TunableSet::new(create_key::<R>, reduce_input_gen::<R>);
let default_group =
TuneGroup::<ReduceAutotuneKey>::new("default_reduce", |_key| PRIORITY_MAX);
let vectorized_parallel_group =
TuneGroup::<ReduceAutotuneKey>::new("vectorized_parallel_reduce", |key| {
if key.axis_is_contiguous {
PRIORITY_MAX
} else {
// We disable the tunable with the setting [line_size.parallel_output_vectorization]
// when the reduce isn't parallel, since it would duplicate tunables.
PRIORITY_SKIP
}
});
enum ReduceProps {
GreatWithLowReduceCount,
GreatWithHighReduceCount,
Balanced,
}
for (line_size, line_size_ident) in [
(
LineSizeStrategy {
parallel_output_vectorization: true,
},
"_vectorized_parallel_reduce",
),
(
LineSizeStrategy {
parallel_output_vectorization: false,
},
"",
),
] {
for (name, routine, props) in [
(
"unit",
RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
ReduceProps::GreatWithHighReduceCount,
),
(
"plane",
RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy {
independent: true,
})),
ReduceProps::Balanced,
),
(
"cube",
RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy {
use_planes: true,
})),
ReduceProps::GreatWithLowReduceCount,
),
] {
let name = format!("{name}{line_size_ident}");
let mut tunable = Tunable::new(
name,
move |(input, output, axis, config, dtypes): (
CubeTensor<R>,
CubeTensor<R>,
usize,
ReduceOperationConfig,
ReduceDtypes,
)| {
let strategy = ReduceStrategy {
routine: routine.clone(),
line_size,
};
cubek::reduce::reduce::<R>(
&input.client,
input.as_handle_ref(),
output.as_handle_ref(),
axis,
strategy,
config,
dtypes,
)
.map_err(|e| format!("{e}"))
},
);
if line_size.parallel_output_vectorization {
tunable = tunable.group(&vectorized_parallel_group, |_| PRIORITY_MAX);
}
tunable = tunable.group(&default_group, move |key| match props {
ReduceProps::GreatWithLowReduceCount => {
if key.vector_count < 128 {
PRIORITY_MAX
} else {
// When you have a high level of vector to reduce, it is normally
// better to use another routine.
PRIORITY_MIN
}
}
ReduceProps::GreatWithHighReduceCount => {
if key.vector_count > 64 {
PRIORITY_MAX
} else {
// Bellow 64 it is normally better to use another routine
PRIORITY_MIN
}
}
ReduceProps::Balanced => PRIORITY_MAX,
});
set = set.with(tunable);
}
}
set
});
TUNER.execute(
&CubeTuneId::new(&input.client, &input.device),
client,
tunables,
(input, output, axis, config, dtypes),
);
}
pub(crate) fn create_key<Run: CubeRuntime>(
input: &CubeTensor<Run>,
output: &CubeTensor<Run>,
axis: &usize,
_config: &ReduceOperationConfig,
dtypes: &ReduceDtypes,
) -> ReduceAutotuneKey {
let elem_input = input.dtype.into();
let elem_output = output.dtype.into();
let elem_acc = dtypes.accumulation.elem_type();
ReduceAutotuneKey::generate(
elem_input,
elem_output,
elem_acc,
input.meta.shape(),
input.meta.strides()[*axis] == 1,
*axis,
)
}
mod reduce_ops {
#![allow(missing_docs)]
use cubek::reduce::ReduceDtypes;
use super::*;
pub(crate) fn reduce_input_gen<Run: CubeRuntime>(
_key: &ReduceAutotuneKey,
input: &CubeTensor<Run>,
output: &CubeTensor<Run>,
dim: &usize,
config: &ReduceOperationConfig,
dtypes: &ReduceDtypes,
) -> (
CubeTensor<Run>,
CubeTensor<Run>,
usize,
ReduceOperationConfig,
ReduceDtypes,
) {
(input.clone(), output.copy(), *dim, *config, *dtypes)
}
}
/// Executes autotune on reduce operations.
#[cfg(feature = "autotune")]
pub fn autotune_sum<R: CubeRuntime>(
client: &ComputeClient<R>,
input: CubeTensor<R>,
) -> CubeTensor<R> {
use sum_ops::*;
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!("autotune-sum");
let tunables = TUNER.init(|| {
TunableSet::new(create_key_sum::<R>, sum_input_gen::<R>)
.with(Tunable::new("sum_chained", sum_chained::<R>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 1>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 2>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 4>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 8>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 16>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 32>))
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 64>))
});
TUNER.execute(
&CubeTuneId::new(&input.client, &input.device),
client,
tunables,
input,
)
}
pub(crate) fn create_key_sum<Run: CubeRuntime>(input: &CubeTensor<Run>) -> CubeAutotuneKey {
CubeAutotuneKey::Sum(SumAutotuneKey::generate(input))
}
impl SumAutotuneKey {
#[allow(unused)]
pub(crate) fn generate<Run: CubeRuntime>(input: &CubeTensor<Run>) -> Self {
let dtype = input.dtype;
let length = input.meta.num_elements();
Self::new(dtype, length)
}
}
mod sum_ops {
#![allow(missing_docs)]
use crate::ops::numeric::zeros_client;
use super::*;
pub(crate) fn sum_input_gen<Run: CubeRuntime>(
_key: &CubeAutotuneKey,
input: &CubeTensor<Run>,
) -> CubeTensor<Run> {
input.clone()
}
pub(crate) fn sum_one_shot<Run: CubeRuntime, const C: u32>(
input: CubeTensor<Run>,
) -> Result<CubeTensor<Run>, String> {
let client = input.client.clone();
let device = input.device.clone();
let output = zeros_client(client.clone(), device, [1].into(), input.dtype);
cubek::reduce::shared_sum::<Run>(
&input.client,
input.as_handle_ref(),
output.as_handle_ref(),
C,
input.dtype.into(),
)
.map_err(|e| e.to_string())
.map(|_| output)
}
#[cfg(feature = "autotune")]
pub(crate) fn sum_chained<Run: CubeRuntime>(
input: CubeTensor<Run>,
) -> Result<CubeTensor<Run>, String> {
crate::kernel::reduce::reduce::<Run>(
input,
None,
crate::kernel::reduce::KernelReduceStrategy::Autotune,
cubek::reduce::components::instructions::ReduceOperationConfig::Sum,
)
.map_err(|e| e.to_string())
}
}

View File

@@ -0,0 +1,191 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, linear_view_alias},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync {
type Options: LaunchArg;
type Unary<F: Float>: FloatUnaryOp<F, Options = Self::Options>;
}
#[cube]
pub(crate) trait FloatUnaryOp<F: Float>: 'static + Send + Sync {
type Options: LaunchArg;
fn execute(input: Line<F>, options: &Self::Options) -> Line<F>;
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn unary_float<F: Float, O: FloatUnaryOpFamily>(
input: &LinearView<Line<F>>,
output: &mut LinearView<Line<F>, ReadWrite>,
options: &O::Options,
#[define(F)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = O::Unary::<F>::execute(input[ABSOLUTE_POS], options);
}
pub(crate) fn launch_unary_float<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
where
// Magic fix for lifetime, the closure is supposed to capture everything required to create the
// argument.
for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>,
R: CubeRuntime,
O: FloatUnaryOpFamily,
{
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
if tensor.can_mut() && tensor.is_nonoverlapping() {
unary_float::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor),
linear_view(&tensor, line_size),
linear_view_alias(&tensor, line_size, 0),
args(&()),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
tensor
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
tensor.dtype,
);
unary_float::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
linear_view(&output, line_size),
args(&()),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}
/// Use comptime enum to implement all unary operations that don't have any input argument in the
/// kernel definition.
pub(crate) mod unary_basic {
use super::*;
pub(crate) fn launch<R, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
where
R: CubeRuntime,
for<'a> Args: FnOnce(&'a ()) -> BasicFloatUnaryKind,
{
launch_unary_float::<R, BasicFloatUnary, _>(tensor, |input| {
BasicFloatUnaryOptionsLaunch::new(args(input))
})
}
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum BasicFloatUnaryKind {
Exp,
Log,
Log1p,
Sqrt,
Abs,
Sign,
ArcCos,
ArcCosh,
ArcSin,
ArcSinh,
ArcTan,
ArcTanh,
Cos,
Cosh,
Sin,
Sinh,
Tan,
Tanh,
Round,
Floor,
Ceil,
Trunc,
Erf,
Recip,
}
#[derive(CubeLaunch, CubeType)]
struct BasicFloatUnaryOptions {
#[cube(comptime)]
kind: BasicFloatUnaryKind,
}
struct BasicFloatUnary;
#[cube]
impl<F: Float> FloatUnaryOp<F> for BasicFloatUnary {
type Options = BasicFloatUnaryOptions;
fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
match comptime![options.kind] {
BasicFloatUnaryKind::Exp => Line::exp(input),
BasicFloatUnaryKind::Log => Line::ln(input),
BasicFloatUnaryKind::Log1p => Line::log1p(input),
BasicFloatUnaryKind::Sqrt => Line::sqrt(input),
BasicFloatUnaryKind::Abs => Line::abs(input),
BasicFloatUnaryKind::Sign => {
let zero = Line::new(F::new(0.0));
let one = Line::new(F::new(1.0));
let minus_one = Line::new(F::new(-1.0));
let is_positive = input.greater_than(zero);
let is_negative = input.less_than(zero);
let sign = select_many(is_negative, minus_one, zero);
select_many(is_positive, one, sign)
}
BasicFloatUnaryKind::Cos => Line::cos(input),
BasicFloatUnaryKind::Sin => Line::sin(input),
BasicFloatUnaryKind::Tan => Line::tan(input),
BasicFloatUnaryKind::Cosh => Line::cosh(input),
BasicFloatUnaryKind::Sinh => Line::sinh(input),
BasicFloatUnaryKind::Tanh => Line::tanh(input),
BasicFloatUnaryKind::Round => Line::round(input),
BasicFloatUnaryKind::Floor => Line::floor(input),
BasicFloatUnaryKind::Ceil => Line::ceil(input),
BasicFloatUnaryKind::Trunc => Line::trunc(input),
BasicFloatUnaryKind::Erf => Line::erf(input),
BasicFloatUnaryKind::Recip => Line::recip(input),
BasicFloatUnaryKind::ArcCos => Line::acos(input),
BasicFloatUnaryKind::ArcCosh => Line::acosh(input),
BasicFloatUnaryKind::ArcSin => Line::asin(input),
BasicFloatUnaryKind::ArcSinh => Line::asinh(input),
BasicFloatUnaryKind::ArcTan => Line::atan(input),
BasicFloatUnaryKind::ArcTanh => Line::atanh(input),
}
}
}
impl FloatUnaryOpFamily for BasicFloatUnary {
type Options = BasicFloatUnaryOptions;
type Unary<F: Float> = Self;
}
}

View File

@@ -0,0 +1,142 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, linear_view_alias},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync {
type Options: LaunchArg;
type Unary<I: Int>: IntUnaryOp<I, Options = Self::Options>;
}
#[cube]
pub(crate) trait IntUnaryOp<I: CubePrimitive>: 'static + Send + Sync {
type Options: LaunchArg;
fn execute(input: Line<I>, options: &Self::Options) -> Line<I>;
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn unary_int<I: Int, O: IntUnaryOpFamily>(
input: &LinearView<Line<I>>,
output: &mut LinearView<Line<I>, ReadWrite>,
options: &O::Options,
#[define(I)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = O::Unary::<I>::execute(input[ABSOLUTE_POS], options);
}
pub(crate) fn launch_unary_int<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
where
for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>,
R: CubeRuntime,
O: IntUnaryOpFamily,
{
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
if tensor.can_mut() && tensor.is_nonoverlapping() {
unary_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor),
linear_view(&tensor, line_size),
linear_view_alias(&tensor, line_size, 0),
args(&()),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
tensor
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
tensor.dtype,
);
unary_int::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
linear_view(&output, line_size),
args(&()),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}
pub(crate) mod unary_basic_int {
use super::*;
pub(crate) fn launch<R, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
where
R: CubeRuntime,
for<'a> Args: FnOnce(&'a ()) -> BasicIntUnaryKind,
{
launch_unary_int::<R, BasicIntUnary, _>(tensor, |input| {
BasicIntUnaryOptionsLaunch::new(args(input))
})
}
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum BasicIntUnaryKind {
BitwiseNot,
Sign,
}
#[derive(CubeLaunch, CubeType)]
struct BasicIntUnaryOptions {
#[cube(comptime)]
kind: BasicIntUnaryKind,
}
struct BasicIntUnary;
#[cube]
impl<I: Int> IntUnaryOp<I> for BasicIntUnary {
type Options = BasicIntUnaryOptions;
fn execute(input: Line<I>, options: &Self::Options) -> Line<I> {
match comptime![options.kind] {
BasicIntUnaryKind::BitwiseNot => !input,
BasicIntUnaryKind::Sign => {
let zero = Line::new(I::new(0));
let one = Line::new(I::new(1));
let minus_one = Line::new(I::new(-1));
let is_positive = input.greater_than(zero);
let is_negative = input.less_than(zero);
let sign = select_many(is_negative, minus_one, zero);
select_many(is_positive, one, sign)
}
}
}
}
impl IntUnaryOpFamily for BasicIntUnary {
type Options = BasicIntUnaryOptions;
type Unary<I: Int> = Self;
}
}

View File

@@ -0,0 +1,90 @@
use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, linear_view_alias},
ops::{max_line_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
pub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync {
type Options: LaunchArg;
type Unary<N: Numeric>: NumericUnaryOp<N, Options = Self::Options>;
}
#[cube]
pub(crate) trait NumericUnaryOp<N: CubePrimitive>: 'static + Send + Sync {
type Options: LaunchArg;
fn execute(input: Line<N>, options: &Self::Options) -> Line<N>;
}
#[cube(launch_unchecked, address_type = "dynamic")]
pub(crate) fn unary_numeric<N: Numeric, O: NumericUnaryOpFamily>(
input: &LinearView<Line<N>>,
output: &mut LinearView<Line<N>, ReadWrite>,
options: &O::Options,
#[define(N)] _dtype: StorageType,
) {
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
output[ABSOLUTE_POS] = O::Unary::<N>::execute(input[ABSOLUTE_POS], options);
}
pub(crate) fn launch_unary_numeric<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
where
// Magic fix for lifetime, the closure is supposed to capture everything required to create the
// argument.
for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>,
R: CubeRuntime,
O: NumericUnaryOpFamily,
{
let line_size = max_line_size(&tensor);
let client = tensor.client.clone();
let num_elems = tensor.meta.num_elements();
let working_units = num_elems / line_size as usize;
let cube_dim = CubeDim::new(&tensor.client, working_units);
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
unsafe {
if tensor.can_mut() && tensor.is_nonoverlapping() {
unary_numeric::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor),
linear_view(&tensor, line_size),
linear_view_alias(&tensor, line_size, 0),
args(&()),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
tensor
} else {
let output = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape(),
tensor.dtype,
);
unary_numeric::launch_unchecked::<O, R>(
&client,
cube_count,
cube_dim,
address_type!(tensor, output),
linear_view(&tensor, line_size),
linear_view(&output, line_size),
args(&()),
tensor.dtype.into(),
)
.expect("Kernel to never fail");
output
}
}
}

View File

@@ -0,0 +1,198 @@
use burn_backend::Shape;
use cubecl::{
ir::LineSize,
prelude::*,
std::{
FastDivmod, FastDivmodArgs, FastDivmodInt,
tensor::layout::linear::{LinearLayoutArgs, LinearViewLaunch},
},
};
use cubecl::{prelude::SequenceArg, std::tensor::layout::linear::LinearLayout};
use crate::{CubeRuntime, tensor::CubeTensor};
pub fn shape_divmod<'a, R: CubeRuntime>(
tensor: &CubeTensor<R>,
) -> SequenceArg<'a, R, FastDivmod<usize>> {
let mut arg = SequenceArg::new();
for dim in tensor.meta.shape().iter() {
arg.push(FastDivmodArgs::<usize>::new(&tensor.client, *dim));
}
arg
}
pub fn linear_layout<'a, R: CubeRuntime>(
tensor: &'a CubeTensor<R>,
line_size: LineSize,
) -> LinearLayoutArgs<'a, R> {
LinearLayoutArgs::from_shape_strides(
&tensor.client,
tensor.meta.shape(),
tensor.meta.strides(),
line_size,
)
}
pub fn linear_layout_ref<'a, R: CubeRuntime>(
tensor: &'a CubeTensor<R>,
reference: &'a CubeTensor<R>,
line_size: LineSize,
) -> LinearLayoutArgs<'a, R> {
LinearLayoutArgs::from_shape_strides_with_reference(
&tensor.client,
tensor.meta.shape(),
reference.meta.shape(),
tensor.meta.strides(),
line_size,
)
}
pub fn linear_view<'a, R: CubeRuntime>(
tensor: &'a CubeTensor<R>,
line_size: LineSize,
) -> LinearViewLaunch<'a, R> {
let len = tensor.meta.num_elements();
let layout = linear_layout(tensor, line_size);
let buffer = unsafe {
ArrayArg::from_raw_parts_and_size(&tensor.handle, len, line_size, tensor.elem_size())
};
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
pub fn linear_view_ref<'a, R: CubeRuntime>(
tensor: &'a CubeTensor<R>,
reference: &'a CubeTensor<R>,
line_size: LineSize,
) -> LinearViewLaunch<'a, R> {
let len = tensor.meta.num_elements();
let layout = linear_layout_ref(tensor, reference, line_size);
let buffer = unsafe {
ArrayArg::from_raw_parts_and_size(&tensor.handle, len, line_size, tensor.elem_size())
};
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
pub fn linear_view_alias<'a, R: CubeRuntime>(
tensor: &'a CubeTensor<R>,
line_size: LineSize,
pos: usize,
) -> LinearViewLaunch<'a, R> {
let layout = linear_layout(tensor, line_size);
let buffer = ArrayArg::Alias { input_pos: pos };
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
pub fn split_dim<R: CubeRuntime>(
mut tensor: CubeTensor<R>,
dim: usize,
shape: &[usize],
) -> CubeTensor<R> {
let mut stride = tensor.meta.strides()[dim];
tensor.meta.remove(dim);
for size in shape.iter().rev() {
tensor.meta.insert(dim, *size, stride);
stride *= size;
}
tensor
}
pub fn broadcast_shape<R: CubeRuntime>(tensors: &[&CubeTensor<R>]) -> Shape {
let rank = tensors[0].meta.num_dims();
debug_assert!(
tensors.iter().all(|it| it.meta.num_dims() == rank),
"Broadcast tensors must have the same rank"
);
let dims = (0..rank).map(|dim| {
let max = tensors.iter().map(|it| it.meta.shape()[dim]).max();
let max = max.unwrap_or(1);
debug_assert!(
tensors
.iter()
.all(|it| it.meta.shape()[dim] == max || it.meta.shape()[dim] == 1),
"Broadcast dims must be size 1"
);
max
});
Shape::from(dims)
}
pub fn broadcast_strides<'a, R: CubeRuntime>(
reference: &CubeTensor<R>,
tensor: &'a CubeTensor<R>,
) -> SequenceArg<'a, R, usize> {
if reference.meta.shape() != tensor.meta.shape() {
tensor
.meta
.strides()
.iter()
.zip(
tensor
.meta
.shape()
.iter()
.zip(reference.meta.shape().iter()),
)
.map(|(stride, (shape, ref_shape))| if *shape == *ref_shape { *stride } else { 0 })
.map(ScalarArg::new)
.collect()
} else {
tensor
.meta
.strides()
.iter()
.copied()
.map(ScalarArg::new)
.collect()
}
}
#[cube]
pub(crate) fn decompose_linear<I: FastDivmodInt>(
pos: I,
shape: &Sequence<FastDivmod<I>>,
) -> (I, Sequence<I>) {
let rank = comptime![shape.len()];
let mut offs = pos;
let mut out = Sequence::new();
#[unroll]
for i in 0..rank {
let dim = comptime![rank - i - 1];
let (rem, offs_local) = shape.index(dim).div_mod(offs);
out.push(offs_local);
offs = rem;
}
(offs, out.rev())
}
pub(crate) trait RequiredAddrType {
fn required_address_type(&self) -> AddressType;
}
impl<R: CubeRuntime> RequiredAddrType for CubeTensor<R> {
fn required_address_type(&self) -> AddressType {
self.required_address_type()
}
}
impl<R: CubeRuntime> RequiredAddrType for Option<CubeTensor<R>> {
fn required_address_type(&self) -> AddressType {
self.as_ref()
.map(|it| it.required_address_type())
.unwrap_or_default()
}
}
macro_rules! address_type {
($($tensor: tt),*) => {
[$($crate::kernel::utils::RequiredAddrType::required_address_type(&$tensor)),*]
.into_iter()
.max()
.unwrap_or_default()
};
}
pub(crate) use address_type;

View File

@@ -0,0 +1,50 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! Burn JIT Backend
#[macro_use]
extern crate derive_new;
extern crate alloc;
/// Utilities for implementing JIT kernels
pub mod ops;
/// Kernel module
pub mod kernel;
/// Tensor module.
pub mod tensor;
/// Elements for JIT backend
pub mod element;
use cubecl::{CubeTask, Runtime};
pub use element::{BoolElement, CubeElement, FloatElement, IntElement};
mod backend;
pub use backend::*;
// Re-export cubecl.
pub use cubecl;
mod tune_key;
pub use tune_key::CubeAutotuneKey;
#[cfg(any(feature = "fusion", test))]
/// Module for interacting with fusion
pub mod fusion;
#[cfg(feature = "template")]
/// Module for compiling custom non-jit kernels
pub mod template;
/// Just-in-Time runtime extending the [cube runtime](Runtime).
pub trait CubeRuntime: Runtime<Device = Self::CubeDevice, Server = Self::CubeServer> {
/// The device that should also implement [burn_backend::backend::DeviceOps].
type CubeDevice: burn_backend::DeviceOps;
/// The cube server with the [CubeAutotuneKey].
type CubeServer: cubecl::server::ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;
}
pub use cubecl::CubeTuneId;

View File

@@ -0,0 +1,11 @@
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
use burn_backend::ops::ActivationOps;
impl<R, F, I, BT> ActivationOps<Self> for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
}

View File

@@ -0,0 +1,432 @@
use crate::{CubeRuntime, kernel, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{
DType, ExecutionError, QTensorPrimitive, Shape, TensorData,
quantization::{QuantLevel, QuantStore, params_shape},
};
use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
use burn_std::{
Metadata, strides,
tensor::{ReshapeAction, contiguous_strides, reshape_action},
};
use cubecl::{ir::LineSize, server::CopyDescriptor};
use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
let client = R::client(device);
let alloc = client.create_tensor(data.bytes, &data.shape, data.dtype.size());
let shape: Shape = (&data.shape).into();
CubeTensor::new(
client,
alloc.handle,
Metadata::new(shape, alloc.strides),
device.clone(),
data.dtype,
)
}
pub(crate) async fn into_data<R: CubeRuntime>(
tensor: CubeTensor<R>,
) -> Result<TensorData, ExecutionError> {
let tensor = kernel::into_contiguous_aligned(tensor);
let elem_size = tensor.elem_size();
let shape = tensor.meta.shape();
let strides = tensor.meta.strides();
let binding = CopyDescriptor::new(tensor.handle.binding(), shape, strides, elem_size);
let bytes = tensor
.client
.read_one_tensor_async(binding)
.await
.map_err(|err| ExecutionError::WithContext {
reason: format!("{err}"),
})?;
Ok(TensorData::from_bytes(
bytes,
tensor.meta.shape,
tensor.dtype,
))
}
/// Read data from a `CubeTensor` synchronously
#[allow(unused, reason = "useful for debugging kernels")]
pub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {
burn_std::future::block_on(into_data(tensor)).unwrap()
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensor, device))
)]
pub(crate) fn to_device<R: CubeRuntime>(
tensor: CubeTensor<R>,
device: &R::Device,
) -> CubeTensor<R> {
if &tensor.device == device {
return tensor;
}
let tensor = kernel::into_contiguous_aligned(tensor);
let client = R::client(device);
tensor.to_client(client, device.clone())
}
pub(crate) fn empty<R: CubeRuntime>(
shape: Shape,
device: &R::Device,
dtype: DType,
) -> CubeTensor<R> {
let client = R::client(device);
let alloc = client.empty_tensor(&shape, dtype.size());
CubeTensor::new(
client,
alloc.handle,
Metadata::new(shape, alloc.strides),
device.clone(),
dtype,
)
}
pub(crate) fn swap_dims<R: CubeRuntime>(
mut tensor: CubeTensor<R>,
dim1: usize,
dim2: usize,
) -> CubeTensor<R> {
tensor.meta.swap(dim1, dim2);
if let DType::QFloat(scheme) = tensor.dtype
&& let QuantLevel::Block(block_size) = scheme.level
{
let rank = tensor.rank();
let qparams = tensor.qparams.as_mut().unwrap();
let mut block_size = block_size.to_dim_vec(rank);
block_size.swap(dim1, dim2);
// Truncate unit dims from the start
let block_size = BlockSize::new_trim(block_size);
if block_size.len() > BlockSize::MAX_DIMS {
panic!("Swapped block size would exceed max dims");
}
qparams.scales.metadata.swap(dim1, dim2);
tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))
}
if let DType::QFloat(scheme) = &mut tensor.dtype
&& let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
&mut scheme.store
{
let rank = tensor.meta.num_dims();
if *packed_dim == rank - dim1 - 1 {
*packed_dim = rank - dim2 - 1;
} else if *packed_dim == rank - dim2 - 1 {
*packed_dim = rank - dim1 - 1;
}
}
tensor
}
/// Permute a tensor's dimensions
pub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {
tensor.meta.permute(axes).unwrap();
if let DType::QFloat(scheme) = tensor.dtype
&& let QuantLevel::Block(block_size) = scheme.level
{
let rank = tensor.rank();
let qparams = tensor.qparams.as_mut().unwrap();
let mut block_size = block_size.to_dim_vec(rank);
block_size = axes.iter().map(|i| block_size[*i]).collect();
// Truncate unit dims from the start
let block_size = block_size
.into_iter()
.skip_while(|it| *it == 1)
.collect::<Vec<_>>();
if block_size.len() > BlockSize::MAX_DIMS {
panic!("Swapped block size would exceed max dims");
}
qparams.scales.metadata.permute(axes).unwrap();
tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))
}
if let DType::QFloat(scheme) = &mut tensor.dtype
&& let QuantStore::PackedU32(packed_dim) = &mut scheme.store
{
let rank = tensor.meta.num_dims();
let new_pos = axes
.iter()
.position(|axis| *axis == rank - *packed_dim - 1)
.unwrap_or(0);
*packed_dim = rank - new_pos - 1;
}
tensor
}
/// Permute a tensor's dimensions from NCHW to NHWC, or the N-dimensional equivalent
pub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
let rank = tensor.meta.num_dims();
let c_dim = 1;
let mut dims = vec![0];
dims.extend(2..rank);
dims.push(c_dim);
permute(tensor, &dims)
}
/// Permute a shape's dimensions from NCHW to NHWC, or the N-dimensional equivalent
pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {
let rank = shape.num_dims();
let c_dim = 1;
let mut dims = vec![0];
dims.extend(2..rank);
dims.push(c_dim);
shape.permuted(&dims).expect("Shape permute should succeed")
}
/// Permute a tensor's dimensions from NHWC to NCHW, or the N-dimensional equivalent
pub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
let rank = tensor.meta.num_dims();
let c_dim = rank - 1;
let mut dims = vec![0];
dims.push(c_dim);
dims.extend(1..c_dim);
permute(tensor, &dims)
}
/// Permute a shape's dimensions from NHWC to NCHW, or the N-dimensional equivalent
pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {
let rank = shape.num_dims();
let c_dim = rank - 1;
let mut dims = vec![0];
dims.push(c_dim);
dims.extend(1..c_dim);
shape.permuted(&dims).expect("Shape permute should succeed")
}
pub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {
let ndims_in = tensor.meta.shape().num_dims();
let ndims_out = target_shape.num_dims();
// Initialize new strides with zeros
let mut new_strides = strides![0usize; ndims_out];
// Calculate the difference in dimensions
let dim_diff = ndims_out.saturating_sub(ndims_in);
// Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones
let mut tensor_dim_iter = tensor.meta.shape().iter().rev();
for i in (0..ndims_out).rev() {
if i >= dim_diff {
if let Some(&tensor_dim) = tensor_dim_iter.next() {
if tensor_dim == target_shape[i] || tensor_dim == 1 {
// Copy stride for non-broadcast dimensions or set to 0 for broadcast ones
new_strides[i] = if tensor_dim == target_shape[i] {
tensor.meta.strides()[i - dim_diff]
} else {
0
};
} else {
// Error handling: Dimension mismatch for broadcasting
panic!(
"Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape"
);
}
} else {
// If the input tensor has fewer dimensions, treat missing dimensions as 1
// and set stride to 0 (broadcasting)
new_strides[i] = 0;
}
} else {
// For extra dimensions in the target shape, set stride to 0 (broadcasting)
new_strides[i] = 0;
}
}
// Extra check to ensure block scales must be properly handled once they're added
if tensor.qparams.is_some() {
match tensor.scheme().level {
QuantLevel::Tensor => {}
QuantLevel::Block(_) => todo!(),
}
}
CubeTensor {
client: tensor.client,
device: tensor.device,
meta: Box::new(Metadata::new(target_shape, new_strides)),
handle: tensor.handle,
dtype: tensor.dtype,
qparams: tensor.qparams,
}
}
/// Reshape a jit tensor to a new shape
pub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
let analysis = reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape);
match analysis {
ReshapeAction::UpdateStrides { strides } => {
*tensor.meta = Metadata::new(shape, strides);
return tensor;
}
ReshapeAction::NoChange => return tensor,
ReshapeAction::Recompute => (),
}
let out = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
shape,
tensor.dtype,
);
cubecl::std::tensor::copy_into(
&tensor.client,
&tensor.as_handle_ref(),
&out.as_handle_ref(),
tensor.dtype.into(),
)
.expect("Kernel should not fail");
out
}
/// Reshape a jit tensor to a new shape
pub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
let scheme = *tensor.scheme();
let shape_values = {
let rank = shape.num_dims();
let mut shape = shape.clone();
shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants());
shape
};
let shape_scales = params_shape(&shape, scheme.level);
let (values, scales) = tensor.quantized_handles().unwrap();
let analysis_values = reshape_action(values.meta.shape(), values.meta.strides(), &shape_values);
let analysis_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales);
match (analysis_values, analysis_scales) {
(
ReshapeAction::UpdateStrides { strides },
ReshapeAction::UpdateStrides {
strides: scales_strides,
},
) => {
let qparams = tensor.qparams.as_mut().unwrap();
*tensor.meta = Metadata::new(shape, strides);
qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
}
(ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {
*tensor.meta = Metadata::new(shape, strides);
}
(
ReshapeAction::NoChange,
ReshapeAction::UpdateStrides {
strides: scales_strides,
},
) => {
let qparams = tensor.qparams.as_mut().unwrap();
qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
}
(ReshapeAction::NoChange, ReshapeAction::NoChange) => {}
_ => {
tensor = kernel::into_contiguous(tensor);
*tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values));
let qparams = tensor.qparams.as_mut().unwrap();
let strides = contiguous_strides(&shape_scales);
qparams.scales.metadata = Metadata::new(shape_scales, strides);
}
}
tensor
}
pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
tensor_line_size_parallel(
tensor.client.io_optimized_line_sizes(tensor.dtype.size()),
tensor.meta.shape(),
tensor.meta.strides(),
tensor.meta.num_dims() - 1,
)
}
pub(crate) fn max_line_size_many<R: CubeRuntime>(
tensors: &[&CubeTensor<R>],
axis: usize,
) -> LineSize {
let vec = tensors
.iter()
.map(|tensor| {
tensor_line_size_parallel(
tensor.client.io_optimized_line_sizes(tensor.dtype.size()),
tensor.meta.shape(),
tensor.meta.strides(),
axis,
)
})
.min();
vec.unwrap_or(0)
}
/// Unfold windows along a dimension.
///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
///
/// The new view will have the unfolded dimension replaced by two dimensions;
/// one in the position of the original dimension, with size equal to the number of windows,
/// and one appended to the right-most position, with size equal to `size`.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the dimension to unfold.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
pub fn unfold<R: CubeRuntime>(
tensor: CubeTensor<R>,
dim: usize,
size: usize,
step: usize,
) -> CubeTensor<R> {
let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
let d_stride = tensor.meta.strides()[dim];
let mut strides = tensor.meta.strides.clone();
strides[dim] = step * d_stride;
strides.push(d_stride);
CubeTensor {
meta: Box::new(Metadata::new(shape, strides)),
..tensor
}
}

Some files were not shown because too many files have changed in this diff Show More