feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,88 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Generic backend that can be compiled just-in-time to any shader language target"
|
||||
documentation = "https://docs.rs/burn-cubecl"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "gpu"]
|
||||
license.workspace = true
|
||||
name = "burn-cubecl"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"autotune",
|
||||
"std",
|
||||
"fusion",
|
||||
"cubecl/default",
|
||||
"burn-fusion?/default",
|
||||
"burn-cubecl-fusion?/default",
|
||||
]
|
||||
std = [
|
||||
"cubecl/std",
|
||||
"burn-backend/std",
|
||||
"burn-fusion?/std",
|
||||
"burn-cubecl-fusion?/std",
|
||||
]
|
||||
doc = ["default"]
|
||||
memory-checks = ["burn-fusion?/memory-checks"]
|
||||
tracing = [
|
||||
"dep:tracing",
|
||||
"cubecl/tracing",
|
||||
"burn-std/tracing",
|
||||
"burn-backend/tracing",
|
||||
"burn-fusion?/tracing",
|
||||
"burn-cubecl-fusion?/tracing",
|
||||
]
|
||||
|
||||
autotune = ["burn-cubecl-fusion?/autotune"]
|
||||
autotune-checks = [
|
||||
"autotune",
|
||||
"cubecl/autotune-checks",
|
||||
"burn-cubecl-fusion?/autotune-checks",
|
||||
]
|
||||
|
||||
fusion = ["burn-fusion", "burn-cubecl-fusion"]
|
||||
fusion-experimental = ["fusion"]
|
||||
|
||||
template = []
|
||||
|
||||
[dependencies]
|
||||
burn-cubecl-fusion = { path = "../burn-cubecl-fusion", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", default-features = false, optional = true }
|
||||
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false, features = [
|
||||
"cubecl",
|
||||
] }
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false, features = [
|
||||
"cubecl",
|
||||
] }
|
||||
cubecl = { workspace = true, features = ["stdlib"] }
|
||||
cubek = { workspace = true, features = [
|
||||
"attention",
|
||||
"matmul",
|
||||
"convolution",
|
||||
"reduce",
|
||||
"random",
|
||||
"quantization",
|
||||
] }
|
||||
tracing = { workspace = true, features = ["attributes"], optional = true }
|
||||
|
||||
derive-new = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# Async
|
||||
futures-lite = { workspace = true, features = ["std"] }
|
||||
|
||||
# Template
|
||||
serde = { workspace = true }
|
||||
text_placeholder = { workspace = true, features = ["struct_context"] }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-cubecl/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-cubecl/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,3 @@
|
||||
# Burn CubeCL Backend
|
||||
|
||||
Generic backend that can be compiled just-in-time (JIT) to any shader language target.
|
||||
@@ -0,0 +1,196 @@
|
||||
use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};
|
||||
use burn_backend::{Backend, DTypeUsage, DTypeUsageSet, DeviceOps, ExecutionError, TensorData};
|
||||
use burn_std::DType;
|
||||
use cubecl::{
|
||||
features::{MmaConfig, TypeUsage},
|
||||
server::ComputeServer,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
use burn_ir::{BackendIr, TensorHandle};
|
||||
|
||||
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
|
||||
#[derive(new)]
|
||||
pub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_float_elem: PhantomData<F>,
|
||||
_int_elem: PhantomData<I>,
|
||||
_bool_elem: PhantomData<BT>,
|
||||
}
|
||||
|
||||
impl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
R::Server: ComputeServer,
|
||||
R::Device: DeviceOps,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
type Device = R::Device;
|
||||
|
||||
type FloatElem = F;
|
||||
type IntElem = I;
|
||||
type BoolElem = BT;
|
||||
|
||||
type FloatTensorPrimitive = CubeTensor<R>;
|
||||
type IntTensorPrimitive = CubeTensor<R>;
|
||||
type BoolTensorPrimitive = CubeTensor<R>;
|
||||
type QuantizedTensorPrimitive = CubeTensor<R>;
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
let client = R::client(device);
|
||||
format!("cubecl<{}>", R::name(&client))
|
||||
}
|
||||
|
||||
fn seed(_device: &Self::Device, seed: u64) {
|
||||
cubek::random::seed(seed);
|
||||
}
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
|
||||
let client = R::client(device);
|
||||
futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {
|
||||
reason: format!("{err}"),
|
||||
})
|
||||
}
|
||||
|
||||
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
|
||||
device: &Self::Device,
|
||||
input: Input,
|
||||
func: Func,
|
||||
) -> Output {
|
||||
let client = R::client(device);
|
||||
client.memory_persistent_allocation(input, func)
|
||||
}
|
||||
|
||||
fn memory_cleanup(device: &Self::Device) {
|
||||
let client = R::client(device);
|
||||
client.memory_cleanup();
|
||||
}
|
||||
|
||||
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
|
||||
where
|
||||
Iter: Iterator<Item = &'a mut TensorData>,
|
||||
{
|
||||
let client = R::client(device);
|
||||
client.staging(data.map(|td| &mut td.bytes), false);
|
||||
}
|
||||
|
||||
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
|
||||
let client = R::client(device);
|
||||
|
||||
let type_usage = client.properties().type_usage(dtype.into());
|
||||
// Same as `TypeUsage::all_scalar()`, but we make the usage explicit here
|
||||
type_usage.is_superset(
|
||||
TypeUsage::Buffer
|
||||
| TypeUsage::Conversion
|
||||
| TypeUsage::Arithmetic
|
||||
| TypeUsage::DotProduct,
|
||||
)
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet {
|
||||
let client = R::client(device);
|
||||
|
||||
let props = client.properties();
|
||||
let storage = dtype.into();
|
||||
let usage = props.type_usage(storage);
|
||||
|
||||
let mut out = DTypeUsageSet::new();
|
||||
|
||||
if usage.is_superset(TypeUsage::Buffer | TypeUsage::Conversion) {
|
||||
out |= DTypeUsage::Storage;
|
||||
}
|
||||
|
||||
if usage.contains(TypeUsage::Arithmetic) {
|
||||
out |= DTypeUsage::Arithmetic;
|
||||
}
|
||||
|
||||
let has_mma = |cfg: &MmaConfig| {
|
||||
cfg.a_type == storage || cfg.b_type == storage || cfg.cd_type == storage
|
||||
};
|
||||
if props.features.cmma.iter().any(has_mma) || props.features.mma.iter().any(has_mma) {
|
||||
out |= DTypeUsage::Accelerated;
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
|
||||
for CubeBackend<R, F, I, BT>
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("CubeCLBackend")
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
|
||||
for CubeBackend<R, F, I, BT>
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
|
||||
for CubeBackend<R, F, I, BT>
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: cubecl::Runtime> CubeRuntime for R
|
||||
where
|
||||
R::Device: DeviceOps,
|
||||
{
|
||||
type CubeDevice = R::Device;
|
||||
type CubeServer = R::Server;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
|
||||
for CubeBackend<R, F, I, BT>
|
||||
{
|
||||
type Handle = CubeTensor<R>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
use burn_backend::{Element, bf16, f16};
|
||||
use cubecl::{
|
||||
CubeElement as CubeElem, flex32,
|
||||
prelude::{Float, Int, Numeric},
|
||||
};
|
||||
use cubek::{
|
||||
matmul::definition::{MatmulPrecision, MatrixPrecision},
|
||||
reduce::ReducePrecision,
|
||||
};
|
||||
|
||||
/// The base element trait for the jit backend.
|
||||
pub trait CubeElement: Element + CubeElem + PartialEq + Numeric {}
|
||||
|
||||
/// Element that can be used for matrix multiplication. Includes ints and floats.
|
||||
pub trait MatmulElement:
|
||||
CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>
|
||||
{
|
||||
}
|
||||
|
||||
/// The float element type for the jit backend.
|
||||
pub trait FloatElement: MatmulElement + Float {}
|
||||
|
||||
/// The int element type for the jit backend.
|
||||
pub trait IntElement:
|
||||
MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>
|
||||
{
|
||||
}
|
||||
|
||||
/// The element type for booleans for the jit backend.
|
||||
pub trait BoolElement: CubeElement + Int {
|
||||
/// The true value for the boolean element.
|
||||
fn true_val() -> Self {
|
||||
Self::from_int(1)
|
||||
}
|
||||
|
||||
/// The false value for the boolean element.
|
||||
fn false_val() -> Self {
|
||||
Self::from_int(0)
|
||||
}
|
||||
|
||||
/// New bool element from Rust bool.
|
||||
fn new_bool(val: bool) -> Self {
|
||||
match val {
|
||||
true => Self::true_val(),
|
||||
false => Self::false_val(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CubeElement for u64 {}
|
||||
impl CubeElement for u32 {}
|
||||
impl CubeElement for u16 {}
|
||||
impl CubeElement for u8 {}
|
||||
impl CubeElement for i64 {}
|
||||
impl CubeElement for i32 {}
|
||||
impl CubeElement for i16 {}
|
||||
impl CubeElement for i8 {}
|
||||
impl CubeElement for f64 {}
|
||||
impl CubeElement for f32 {}
|
||||
impl CubeElement for flex32 {}
|
||||
impl CubeElement for f16 {}
|
||||
impl CubeElement for bf16 {}
|
||||
|
||||
impl FloatElement for f64 {}
|
||||
impl FloatElement for f32 {}
|
||||
impl FloatElement for flex32 {}
|
||||
impl FloatElement for bf16 {}
|
||||
impl FloatElement for f16 {}
|
||||
impl IntElement for i64 {}
|
||||
impl IntElement for i32 {}
|
||||
impl IntElement for i16 {}
|
||||
impl IntElement for i8 {}
|
||||
impl IntElement for u64 {}
|
||||
impl IntElement for u32 {}
|
||||
impl IntElement for u16 {}
|
||||
impl IntElement for u8 {}
|
||||
|
||||
impl BoolElement for u8 {}
|
||||
impl BoolElement for u32 {}
|
||||
|
||||
impl MatmulElement for f64 {}
|
||||
impl MatmulElement for f32 {}
|
||||
impl MatmulElement for flex32 {}
|
||||
impl MatmulElement for bf16 {}
|
||||
impl MatmulElement for f16 {}
|
||||
|
||||
impl MatmulElement for i64 {}
|
||||
impl MatmulElement for i32 {}
|
||||
impl MatmulElement for i16 {}
|
||||
impl MatmulElement for i8 {}
|
||||
impl MatmulElement for u64 {}
|
||||
impl MatmulElement for u32 {}
|
||||
impl MatmulElement for u16 {}
|
||||
impl MatmulElement for u8 {}
|
||||
@@ -0,0 +1,205 @@
|
||||
use crate::BoolElement;
|
||||
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
|
||||
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
|
||||
use burn_backend::{DType, Shape};
|
||||
use burn_cubecl_fusion::optim::reduce::ReduceSettings;
|
||||
use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser;
|
||||
use burn_cubecl_fusion::{
|
||||
CubeFusionHandle, FallbackOperation,
|
||||
optim::{
|
||||
CubeOptimization, CubeOptimizationState,
|
||||
elemwise::{ElementWiseFuser, ElemwiseOptimization},
|
||||
matmul::{MatmulFuser, MatmulOptimization},
|
||||
reduce::{ReduceFuser, ReduceOptimization},
|
||||
reduce_broadcasted::ReduceBroadcastedOptimization,
|
||||
},
|
||||
};
|
||||
use burn_fusion::{
|
||||
FusionBackend, FusionRuntime,
|
||||
stream::{Operation, OrderedExecution},
|
||||
};
|
||||
use burn_ir::{BackendIr, TensorHandle};
|
||||
use burn_std::Metadata;
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
BT: BoolElement,
|
||||
{
|
||||
fn execute(
|
||||
&mut self,
|
||||
context: &mut burn_fusion::stream::Context<
|
||||
'_,
|
||||
<FusionCubeRuntime<R, BT> as FusionRuntime>::FusionHandle,
|
||||
>,
|
||||
execution: &OrderedExecution<FusionCubeRuntime<R, BT>>,
|
||||
) {
|
||||
match self {
|
||||
Self::ElementWise(op) => op.execute::<BT>(context),
|
||||
Self::Matmul(op) => op.execute::<BT>(context, |index| {
|
||||
let operation = execution.operation_within_optimization(index);
|
||||
Box::new(FallbackOperationWrapper::new(operation))
|
||||
}),
|
||||
Self::Reduce(op) => op.execute::<BT>(context, |index| {
|
||||
let operation = execution.operation_within_optimization(index);
|
||||
Box::new(FallbackOperationWrapper::new(operation))
|
||||
}),
|
||||
Self::ReduceBroadcasted(op) => op.execute::<BT>(context, |index| {
|
||||
let operation = execution.operation_within_optimization(index);
|
||||
Box::new(FallbackOperationWrapper::new(operation))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_state(&self) -> CubeOptimizationState {
|
||||
self.to_opt_state()
|
||||
}
|
||||
|
||||
fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
|
||||
match state {
|
||||
CubeOptimizationState::ElementWise(state) => {
|
||||
Self::ElementWise(ElemwiseOptimization::from_state(device, state))
|
||||
}
|
||||
CubeOptimizationState::Matmul(state) => {
|
||||
Self::Matmul(MatmulOptimization::from_state(device, state))
|
||||
}
|
||||
CubeOptimizationState::Reduce(state) => {
|
||||
Self::Reduce(ReduceOptimization::from_state(device, state))
|
||||
}
|
||||
CubeOptimizationState::ReduceBroadcasted(state) => {
|
||||
Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FallbackOperationWrapper<O: Clone> {
|
||||
operation: O,
|
||||
}
|
||||
|
||||
impl<O: Clone> FallbackOperationWrapper<O> {
|
||||
fn new(op: O) -> Self {
|
||||
Self { operation: op }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
|
||||
for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
|
||||
{
|
||||
fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
|
||||
self.operation.as_ref().execute(context.handles);
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
|
||||
for CubeBackend<R, F, I, BT>
|
||||
{
|
||||
type Handle = CubeFusionHandle<R>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
into_tensor(handle.handle, handle.shape)
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
into_tensor(handle.handle, handle.shape)
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
into_tensor(handle.handle, handle.shape)
|
||||
}
|
||||
|
||||
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
|
||||
into_tensor(handle.handle, handle.shape)
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
|
||||
tensor.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
|
||||
type OptimizationState = CubeOptimizationState;
|
||||
type Optimization = CubeOptimization<R>;
|
||||
type FusionHandle = CubeFusionHandle<R>;
|
||||
type FusionDevice = R::CubeDevice;
|
||||
type BoolRepr = BT;
|
||||
|
||||
fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
|
||||
vec![
|
||||
Box::new(ElementWiseFuser::new(
|
||||
device.clone(),
|
||||
BT::as_type_native_unchecked().into(),
|
||||
)),
|
||||
Box::new(MatmulFuser::new(
|
||||
device.clone(),
|
||||
BT::as_type_native_unchecked().into(),
|
||||
)),
|
||||
Box::new(ReduceFuser::new(
|
||||
device.clone(),
|
||||
BT::as_type_native_unchecked().into(),
|
||||
ReduceSettings::Always,
|
||||
)),
|
||||
Box::new(ReduceBroadcastedFuser::new(
|
||||
device.clone(),
|
||||
BT::as_type_native_unchecked().into(),
|
||||
)),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Fusion runtime for JIT runtimes.
|
||||
#[derive(Debug)]
|
||||
pub struct FusionCubeRuntime<R: CubeRuntime, BT: BoolElement> {
|
||||
_b: PhantomData<R>,
|
||||
_bool: PhantomData<BT>,
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
|
||||
for CubeBackend<R, F, I, BT>
|
||||
{
|
||||
type FusionRuntime = FusionCubeRuntime<R, BT>;
|
||||
|
||||
type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
|
||||
|
||||
fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
|
||||
kernel::cast(tensor, dtype).into()
|
||||
}
|
||||
}
|
||||
|
||||
fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
|
||||
CubeTensor {
|
||||
client: handle.client,
|
||||
handle: handle.handle,
|
||||
device: handle.device,
|
||||
meta: Box::new(Metadata::new(shape, handle.strides)),
|
||||
dtype: handle.dtype,
|
||||
qparams: handle.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
|
||||
fn from(value: CubeTensor<R>) -> Self {
|
||||
Self {
|
||||
client: value.client,
|
||||
handle: value.handle,
|
||||
device: value.device,
|
||||
strides: value.meta.strides,
|
||||
dtype: value.dtype,
|
||||
qparams: value.qparams,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
use crate::{
|
||||
CubeBackend, CubeRuntime, kernel::attention::attention_autotune,
|
||||
ops::numeric::empty_device_dtype, tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{
|
||||
DType, Shape,
|
||||
ops::{AttentionModuleOptions, attention::attention_fallback},
|
||||
};
|
||||
use cubek::attention::definition::{
|
||||
AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError,
|
||||
};
|
||||
use cubek::attention::launch;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Strategy used to select which attention implementation to run.
|
||||
pub enum AttentionStrategy {
|
||||
/// Flash Attention using accelerated inner matmuls.
|
||||
FlashBlackboxAccelerated,
|
||||
|
||||
/// Flash Attention using unit inner matmuls.
|
||||
FlashUnit,
|
||||
|
||||
/// Fallback implementation using multiple separate kernels.
|
||||
Fallback,
|
||||
|
||||
/// Automatically benchmark and select the best strategy at runtime.
|
||||
#[cfg(feature = "autotune")]
|
||||
Autotune,
|
||||
}
|
||||
|
||||
impl Default for AttentionStrategy {
|
||||
fn default() -> Self {
|
||||
// if autotune is enabled, default to autotune
|
||||
#[cfg(feature = "autotune")]
|
||||
return AttentionStrategy::Autotune;
|
||||
|
||||
// if autotune is disabled, default to fallback to make sure it runs
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
AttentionStrategy::Fallback
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Launch an attention kernel with given strategy
|
||||
pub fn attention<R: CubeRuntime>(
|
||||
query: CubeTensor<R>,
|
||||
key: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
attn_bias: Option<CubeTensor<R>>,
|
||||
options: AttentionModuleOptions,
|
||||
strategy: &AttentionStrategy,
|
||||
out: Option<CubeTensor<R>>,
|
||||
) -> Result<CubeTensor<R>, AttentionSetupError> {
|
||||
let mut out = out.unwrap_or_else(|| init_attention_output(&query, &value));
|
||||
match strategy {
|
||||
AttentionStrategy::FlashBlackboxAccelerated => flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
attn_bias,
|
||||
options,
|
||||
out,
|
||||
launch::Strategy::BlackboxAccelerated(
|
||||
cubek::attention::launch::BlueprintStrategy::Inferred(()),
|
||||
),
|
||||
),
|
||||
AttentionStrategy::FlashUnit => flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
attn_bias,
|
||||
options,
|
||||
out,
|
||||
launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
|
||||
),
|
||||
AttentionStrategy::Fallback => {
|
||||
out = attention_fallback::<CubeBackend<R, f32, i32, u8>>(
|
||||
query, key, value, mask, attn_bias, options,
|
||||
);
|
||||
Ok(out)
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
AttentionStrategy::Autotune => {
|
||||
attention_autotune(query, key, value, mask, attn_bias, options, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Launch a flash attention kernel
|
||||
pub fn flash_attention<R: CubeRuntime>(
|
||||
query: CubeTensor<R>,
|
||||
key: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
_attn_bias: Option<CubeTensor<R>>,
|
||||
options: AttentionModuleOptions,
|
||||
out: CubeTensor<R>,
|
||||
strategy: launch::Strategy,
|
||||
) -> Result<CubeTensor<R>, AttentionSetupError> {
|
||||
let client = &query.client;
|
||||
|
||||
let dtypes = AttentionGlobalTypes {
|
||||
query: query.dtype.into(),
|
||||
key: key.dtype.into(),
|
||||
value: value.dtype.into(),
|
||||
mask: mask.as_ref().map(|m| m.dtype).unwrap_or(DType::U8).into(),
|
||||
out: out.dtype.into(),
|
||||
};
|
||||
|
||||
cubek::attention::launch::launch_ref::<R>(
|
||||
strategy,
|
||||
client,
|
||||
&query.as_handle_ref(),
|
||||
&key.as_handle_ref(),
|
||||
&value.as_handle_ref(),
|
||||
&mask.as_ref().map(|mask| mask.as_handle_ref()),
|
||||
&out.as_handle_ref(),
|
||||
&dtypes,
|
||||
AttentionOptions {
|
||||
causal: options.is_causal,
|
||||
accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar(
|
||||
cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32),
|
||||
)),
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
pub(crate) fn init_attention_output<R: CubeRuntime>(
|
||||
query: &CubeTensor<R>,
|
||||
value: &CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let num_batches = query.meta.shape[0];
|
||||
let num_heads = query.meta.shape[1];
|
||||
let seq_q = query.meta.shape[2];
|
||||
let val_dim = value.meta.shape[3];
|
||||
let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);
|
||||
|
||||
empty_device_dtype::<R>(
|
||||
query.client.clone(),
|
||||
query.device.clone(),
|
||||
out_shape,
|
||||
query.dtype,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod tune;
|
||||
|
||||
pub use base::*;
|
||||
pub use tune::*;
|
||||
@@ -0,0 +1,166 @@
|
||||
use crate::{
|
||||
CubeRuntime, CubeTuneId,
|
||||
kernel::attention::{AttentionStrategy, attention},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::ops::AttentionModuleOptions;
|
||||
use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};
|
||||
use cubek::attention::{definition::AttentionSetupError, launch::AttentionAutotuneKey};
|
||||
|
||||
/// Executes autotune on attention operations
|
||||
pub fn attention_autotune<R: CubeRuntime>(
|
||||
query: CubeTensor<R>,
|
||||
key: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
attn_bias: Option<CubeTensor<R>>,
|
||||
options: AttentionModuleOptions,
|
||||
out: CubeTensor<R>,
|
||||
) -> Result<CubeTensor<R>, AttentionSetupError> {
|
||||
let client = query.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<AttentionAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
const PRIORITY_MAX: i8 = 3;
|
||||
const PRIORITY_MIN: i8 = 0;
|
||||
|
||||
let flash_attention =
|
||||
TuneGroup::<AttentionAutotuneKey>::new("flash_attention", |_key| PRIORITY_MAX);
|
||||
let fallback = TuneGroup::<AttentionAutotuneKey>::new("fallback", |_key| PRIORITY_MIN);
|
||||
|
||||
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);
|
||||
|
||||
// First entry should always work, since it is considered the fallback.
|
||||
set = set.with(
|
||||
Tunable::new(
|
||||
"fallback",
|
||||
|query, key, value, mask, attn_bias, out, options| {
|
||||
attention::<R>(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
attn_bias,
|
||||
options,
|
||||
&AttentionStrategy::Fallback,
|
||||
Some(out),
|
||||
)
|
||||
.map_err(|err| std::format!("{err:?}"))
|
||||
},
|
||||
)
|
||||
.group(&fallback, |_key| PRIORITY_MAX),
|
||||
);
|
||||
|
||||
set = set.with(
|
||||
Tunable::new(
|
||||
"blackbox_accelerated",
|
||||
|query, key, value, mask, attn_bias, out, options| {
|
||||
attention::<R>(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
attn_bias,
|
||||
options,
|
||||
&AttentionStrategy::FlashBlackboxAccelerated,
|
||||
Some(out),
|
||||
)
|
||||
.map_err(|err| std::format!("{err:?}"))
|
||||
},
|
||||
)
|
||||
.group(&flash_attention, |_key| PRIORITY_MAX),
|
||||
);
|
||||
|
||||
set = set.with(
|
||||
Tunable::new(
|
||||
"unit",
|
||||
|query, key, value, mask, attn_bias, out, options| {
|
||||
attention::<R>(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
attn_bias,
|
||||
options,
|
||||
&AttentionStrategy::FlashUnit,
|
||||
Some(out),
|
||||
)
|
||||
.map_err(|err| std::format!("{err:?}"))
|
||||
},
|
||||
)
|
||||
.group(&flash_attention, |_key| PRIORITY_MIN),
|
||||
);
|
||||
|
||||
set
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&client, &query.device),
|
||||
&client,
|
||||
tunables,
|
||||
(query, key, value, mask, attn_bias, out.clone(), options),
|
||||
);
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn create_key<R: CubeRuntime>(
|
||||
query: &CubeTensor<R>,
|
||||
key: &CubeTensor<R>,
|
||||
value: &CubeTensor<R>,
|
||||
mask: &Option<CubeTensor<R>>,
|
||||
_attn_bias: &Option<CubeTensor<R>>,
|
||||
out: &CubeTensor<R>,
|
||||
_options: &AttentionModuleOptions,
|
||||
) -> AttentionAutotuneKey {
|
||||
let total_batches = query.meta.shape[0] * query.meta.shape[1];
|
||||
let seq_q = query.meta.shape[2];
|
||||
let head_dim = query.meta.shape[3];
|
||||
let seq_kv = value.meta.shape[2];
|
||||
let val_dim = value.meta.shape[3];
|
||||
|
||||
AttentionAutotuneKey::generate(
|
||||
query.dtype.into(),
|
||||
key.dtype.into(),
|
||||
value.dtype.into(),
|
||||
out.dtype.into(),
|
||||
total_batches,
|
||||
seq_q,
|
||||
head_dim,
|
||||
seq_kv,
|
||||
val_dim,
|
||||
mask.is_some(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn input_gen<R: CubeRuntime>(
|
||||
_key: &AttentionAutotuneKey,
|
||||
query: &CubeTensor<R>,
|
||||
key: &CubeTensor<R>,
|
||||
value: &CubeTensor<R>,
|
||||
mask: &Option<CubeTensor<R>>,
|
||||
attn_bias: &Option<CubeTensor<R>>,
|
||||
out: &CubeTensor<R>,
|
||||
options: &AttentionModuleOptions,
|
||||
) -> (
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
Option<CubeTensor<R>>,
|
||||
Option<CubeTensor<R>>,
|
||||
CubeTensor<R>,
|
||||
AttentionModuleOptions,
|
||||
) {
|
||||
(
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
mask.clone(),
|
||||
attn_bias.clone(),
|
||||
out.copy(),
|
||||
*options,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{
|
||||
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{TensorMetadata, bf16, f16};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView,
|
||||
};
|
||||
|
||||
pub(crate) trait BinaryOpFamily: Send + Sync + 'static {
|
||||
type BinaryOp<C: Numeric>: BinaryOp<C>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
|
||||
/// Execute a binary operation.
|
||||
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
|
||||
}
|
||||
|
||||
pub(crate) struct AddOp;
|
||||
pub(crate) struct SubOp;
|
||||
pub(crate) struct MulOp;
|
||||
pub(crate) struct DivOp;
|
||||
pub(crate) struct RemainderOp;
|
||||
pub(crate) struct AndOp;
|
||||
pub(crate) struct OrOp;
|
||||
pub(crate) struct PowOp;
|
||||
|
||||
impl BinaryOpFamily for AddOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for SubOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for MulOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for DivOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for RemainderOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for PowOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for AndOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpFamily for OrOp {
|
||||
type BinaryOp<C: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for AddOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs + rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for SubOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs - rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for MulOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs * rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for DivOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs / rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for RemainderOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
Line::rem(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for PowOp {
|
||||
#[allow(unused)]
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
intrinsic!(|scope| {
|
||||
let elem = N::as_type(scope).elem_type();
|
||||
|
||||
if let cubecl::ir::ElemType::Float(kind) = elem {
|
||||
match kind {
|
||||
cubecl::ir::FloatKind::F16 => {
|
||||
let lhs = <Line<f16> as Cast>::__expand_cast_from(scope, lhs);
|
||||
let rhs = <Line<f16> as Cast>::__expand_cast_from(scope, rhs);
|
||||
let out = Line::__expand_powf(scope, lhs, rhs);
|
||||
return <Line<N> as Cast>::__expand_cast_from(scope, out);
|
||||
}
|
||||
cubecl::ir::FloatKind::BF16 => {
|
||||
let lhs = <Line<bf16> as Cast>::__expand_cast_from(scope, lhs);
|
||||
let rhs = <Line<bf16> as Cast>::__expand_cast_from(scope, rhs);
|
||||
let out = Line::__expand_powf(scope, lhs, rhs);
|
||||
return <Line<N> as Cast>::__expand_cast_from(scope, out);
|
||||
}
|
||||
cubecl::ir::FloatKind::F64 => {
|
||||
let lhs = <Line<f64> as Cast>::__expand_cast_from(scope, lhs);
|
||||
let rhs = <Line<f64> as Cast>::__expand_cast_from(scope, rhs);
|
||||
let out = Line::__expand_powf(scope, lhs, rhs);
|
||||
return <Line<N> as Cast>::__expand_cast_from(scope, out);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
};
|
||||
|
||||
let lhs = <Line<f32> as Cast>::__expand_cast_from(scope, lhs);
|
||||
let rhs = <Line<f32> as Cast>::__expand_cast_from(scope, rhs);
|
||||
let out = Line::__expand_powf(scope, lhs, rhs);
|
||||
return <Line<N> as Cast>::__expand_cast_from(scope, out);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for AndOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
Line::cast_from(Line::<bool>::cast_from(lhs).and(Line::<bool>::cast_from(rhs)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> BinaryOp<N> for OrOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
Line::cast_from(Line::<bool>::cast_from(lhs).or(Line::<bool>::cast_from(rhs)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
|
||||
input: &LinearView<Line<C>>,
|
||||
scalar: InputScalar,
|
||||
output: &mut LinearView<Line<C>, ReadWrite>,
|
||||
#[define(C)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] =
|
||||
O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar.get::<C>()));
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
|
||||
lhs: &LinearView<Line<C>>,
|
||||
rhs: &LinearView<Line<C>>,
|
||||
out: &mut LinearView<Line<C>, ReadWrite>,
|
||||
#[define(C)] _dtype: StorageType,
|
||||
) {
|
||||
if !out.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
out[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
|
||||
}
|
||||
|
||||
pub(crate) fn launch_binop<R: CubeRuntime, O: BinaryOpFamily>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size_lhs = max_line_size(&lhs);
|
||||
let line_size_rhs = max_line_size(&rhs);
|
||||
let line_size = Ord::min(line_size_lhs, line_size_rhs);
|
||||
|
||||
let shape_out = broadcast_shape(&[&lhs, &rhs]);
|
||||
let dtype = lhs.dtype;
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let num_elems = shape_out.num_elements();
|
||||
let working_units = num_elems / line_size as usize;
|
||||
|
||||
let cube_dim = CubeDim::new(&lhs.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
kernel_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view(&lhs, line_size),
|
||||
linear_view_ref(&rhs, &lhs, line_size),
|
||||
linear_view_alias(&lhs, line_size, 0),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
lhs
|
||||
} else if rhs.can_mut_broadcast(&lhs) {
|
||||
kernel_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view_ref(&lhs, &rhs, line_size),
|
||||
linear_view(&rhs, line_size),
|
||||
linear_view_alias(&rhs, line_size, 1),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
rhs
|
||||
} else {
|
||||
let output =
|
||||
empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype);
|
||||
|
||||
kernel_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs, output),
|
||||
linear_view_ref(&lhs, &output, line_size),
|
||||
linear_view_ref(&rhs, &output, line_size),
|
||||
linear_view(&output, line_size),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_scalar_binop<R: CubeRuntime, O: BinaryOpFamily>(
|
||||
tensor: CubeTensor<R>,
|
||||
scalar: InputScalar,
|
||||
) -> CubeTensor<R> {
|
||||
// Vectorization is only enabled when the last dimension is contiguous.
|
||||
let line_size = max_line_size(&tensor);
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
let dtype = tensor.dtype;
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
kernel_scalar_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor),
|
||||
linear_view(&tensor, line_size),
|
||||
scalar,
|
||||
linear_view_alias(&tensor, line_size, 0),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
tensor
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
dtype,
|
||||
);
|
||||
|
||||
kernel_scalar_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
scalar,
|
||||
linear_view(&output, line_size),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{
|
||||
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
pub(crate) trait BinaryOpFloatFamily: Send + Sync + 'static {
|
||||
type BinaryOp<C: Float>: BinaryOpFloat<C>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait BinaryOpFloat<C: Float>: 'static + Send + Sync {
|
||||
/// Execute a binary operation.
|
||||
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
|
||||
}
|
||||
|
||||
pub(crate) struct ArcTan2Op;
|
||||
|
||||
impl BinaryOpFloatFamily for ArcTan2Op {
|
||||
type BinaryOp<C: Float> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Float> BinaryOpFloat<N> for ArcTan2Op {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
Line::atan2(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_binop<C: Float, O: BinaryOpFloatFamily>(
|
||||
lhs: &LinearView<Line<C>>,
|
||||
rhs: &LinearView<Line<C>>,
|
||||
out: &mut LinearView<Line<C>, ReadWrite>,
|
||||
#[define(C)] _dtype: StorageType,
|
||||
) {
|
||||
if !out.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
out[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
|
||||
}
|
||||
|
||||
pub(crate) fn launch_binop_float<R: CubeRuntime, O: BinaryOpFloatFamily>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size_lhs = max_line_size(&lhs);
|
||||
let line_size_rhs = max_line_size(&rhs);
|
||||
let line_size = Ord::min(line_size_lhs, line_size_rhs);
|
||||
|
||||
let shape_out = broadcast_shape(&[&lhs, &rhs]);
|
||||
let dtype = lhs.dtype;
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let num_elems = shape_out.num_elements();
|
||||
let working_units = num_elems / line_size as usize;
|
||||
|
||||
let cube_dim = CubeDim::new(&lhs.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
kernel_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view(&lhs, line_size),
|
||||
linear_view_ref(&rhs, &lhs, line_size),
|
||||
linear_view_alias(&lhs, line_size, 0),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
lhs
|
||||
} else if rhs.can_mut_broadcast(&lhs) {
|
||||
kernel_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view_ref(&lhs, &rhs, line_size),
|
||||
linear_view(&rhs, line_size),
|
||||
linear_view_alias(&rhs, line_size, 1),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
rhs
|
||||
} else {
|
||||
let output =
|
||||
empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype);
|
||||
|
||||
kernel_binop::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs, output),
|
||||
linear_view_ref(&lhs, &output, line_size),
|
||||
linear_view_ref(&rhs, &output, line_size),
|
||||
linear_view(&output, line_size),
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the four-quadrant inverse tangent of `lhs / rhs`.
|
||||
pub fn atan2<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
|
||||
launch_binop_float::<R, ArcTan2Op>(lhs, rhs)
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{
|
||||
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static {
|
||||
type BinaryOp<C: Int>: BinaryOpInt<C>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait BinaryOpInt<C: Int>: 'static + Send + Sync {
|
||||
/// Execute a binary operation.
|
||||
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
|
||||
}
|
||||
|
||||
pub(crate) struct BitwiseAndOp;
|
||||
pub(crate) struct BitwiseOrOp;
|
||||
pub(crate) struct BitwiseXorOp;
|
||||
pub(crate) struct BitwiseShrOp;
|
||||
pub(crate) struct BitwiseShlOp;
|
||||
|
||||
impl BinaryOpIntFamily for BitwiseAndOp {
|
||||
type BinaryOp<C: Int> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpIntFamily for BitwiseOrOp {
|
||||
type BinaryOp<C: Int> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpIntFamily for BitwiseXorOp {
|
||||
type BinaryOp<C: Int> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpIntFamily for BitwiseShrOp {
|
||||
type BinaryOp<C: Int> = Self;
|
||||
}
|
||||
|
||||
impl BinaryOpIntFamily for BitwiseShlOp {
|
||||
type BinaryOp<C: Int> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Int> BinaryOpInt<N> for BitwiseAndOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs & rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Int> BinaryOpInt<N> for BitwiseOrOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs | rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Int> BinaryOpInt<N> for BitwiseXorOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs ^ rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Int> BinaryOpInt<N> for BitwiseShrOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs >> rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Int> BinaryOpInt<N> for BitwiseShlOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
|
||||
lhs << rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_scalar_binop_int<C: Int, O: BinaryOpIntFamily>(
|
||||
input: &LinearView<Line<C>>,
|
||||
scalar: InputScalar,
|
||||
output: &mut LinearView<Line<C>, ReadWrite>,
|
||||
#[define(C)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] =
|
||||
O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar.get::<C>()));
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_binop_int<C: Int, O: BinaryOpIntFamily>(
|
||||
lhs: &LinearView<Line<C>>,
|
||||
rhs: &LinearView<Line<C>>,
|
||||
out: &mut LinearView<Line<C>, ReadWrite>,
|
||||
#[define(C)] _dtype: StorageType,
|
||||
) {
|
||||
if !out.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
out[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
|
||||
}
|
||||
|
||||
pub(crate) fn launch_binop_int<R: CubeRuntime, O: BinaryOpIntFamily>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size_lhs = max_line_size(&lhs);
|
||||
let line_size_rhs = max_line_size(&rhs);
|
||||
let line_size = Ord::min(line_size_lhs, line_size_rhs);
|
||||
|
||||
let shape_out = broadcast_shape(&[&lhs, &rhs]);
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&lhs.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
kernel_binop_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view(&lhs, line_size),
|
||||
linear_view_ref(&rhs, &lhs, line_size),
|
||||
linear_view_alias(&lhs, line_size, 0),
|
||||
lhs.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
lhs
|
||||
} else if rhs.can_mut_broadcast(&lhs) {
|
||||
kernel_binop_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view_ref(&lhs, &rhs, line_size),
|
||||
linear_view(&rhs, line_size),
|
||||
linear_view_alias(&rhs, line_size, 1),
|
||||
lhs.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
rhs
|
||||
} else {
|
||||
let output =
|
||||
empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, lhs.dtype);
|
||||
|
||||
kernel_binop_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs, output),
|
||||
linear_view_ref(&lhs, &output, line_size),
|
||||
linear_view_ref(&rhs, &output, line_size),
|
||||
linear_view(&output, line_size),
|
||||
lhs.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_scalar_binop_int<R: CubeRuntime, O: BinaryOpIntFamily>(
|
||||
tensor: CubeTensor<R>,
|
||||
scalar: InputScalar,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size(&tensor);
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.shape.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
kernel_scalar_binop_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor),
|
||||
linear_view(&tensor, line_size),
|
||||
scalar,
|
||||
linear_view_alias(&tensor, line_size, 0),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
tensor
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
kernel_scalar_binop_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
scalar,
|
||||
linear_view(&output, line_size),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{DType, TensorMetadata};
|
||||
use cubecl::std::tensor::layout::linear::LinearView;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
pub(crate) fn cast_element<I: Numeric, O: Numeric>(
|
||||
input: &LinearView<Line<I>>,
|
||||
output: &mut LinearView<Line<O>, ReadWrite>,
|
||||
#[define(I, O)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = Line::cast_from(input[ABSOLUTE_POS]);
|
||||
}
|
||||
|
||||
/// Cast a tensor to the given element type.
|
||||
///
|
||||
/// Note: When input element is semantically a boolean, prefer bool_cast function.
|
||||
pub fn cast<R: CubeRuntime>(input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {
|
||||
let dtype_output = match dtype {
|
||||
DType::Flex32 => DType::F32,
|
||||
_ => dtype,
|
||||
};
|
||||
let dtype_input = match input.dtype {
|
||||
DType::Flex32 => DType::F32,
|
||||
_ => input.dtype,
|
||||
};
|
||||
|
||||
if dtype_input == dtype_output {
|
||||
return input;
|
||||
}
|
||||
|
||||
let client = input.client.clone();
|
||||
|
||||
let line_size = max_line_size(&input);
|
||||
|
||||
let num_elems: usize = input.meta.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
|
||||
|
||||
let output = empty_device_dtype(
|
||||
client.clone(),
|
||||
input.device.clone(),
|
||||
input.shape(),
|
||||
dtype, // We take the same dtype as passed as input (Flex32 not F32)
|
||||
);
|
||||
|
||||
cast_element::launch(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, output),
|
||||
linear_view(&input, line_size),
|
||||
linear_view(&output, line_size),
|
||||
[dtype_input.into(), dtype_output.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
use crate::{
|
||||
CubeElement, CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view},
|
||||
ops::{max_line_size, numeric::empty_device},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::{
|
||||
CubeDim, calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn bool_cast_kernel<B: Int, T: Numeric>(
|
||||
input: &LinearView<Line<B>>,
|
||||
output: &mut LinearView<Line<T>, ReadWrite>,
|
||||
#[define(B)] _input_ty: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = Line::cast_from(input[ABSOLUTE_POS] & Line::cast_from(1u32));
|
||||
}
|
||||
|
||||
/// Cast a bool tensor to the given element type.
|
||||
///
|
||||
/// This alternative to cast is necessary because bool are represented as u32 or u8
|
||||
/// where any non-zero value means true. Depending how it was created
|
||||
/// it may hold an uncanny bit combination. Naively casting it would not
|
||||
/// necessarily yield 0 or 1.
|
||||
pub fn bool_cast<R: CubeRuntime, EO: CubeElement>(tensor: CubeTensor<R>) -> CubeTensor<R> {
|
||||
let output =
|
||||
empty_device::<R, EO>(tensor.client.clone(), tensor.device.clone(), tensor.shape());
|
||||
|
||||
let line_size = max_line_size(&tensor);
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
bool_cast_kernel::launch_unchecked::<EO, R>(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view(&output, line_size),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod bool_cast;
|
||||
|
||||
pub use base::*;
|
||||
pub use bool_cast::*;
|
||||
@@ -0,0 +1,42 @@
|
||||
use cubecl::prelude::*;
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct Options {
|
||||
min_value: InputScalar,
|
||||
max_value: InputScalar,
|
||||
}
|
||||
|
||||
pub(crate) fn clamp<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
min_value: InputScalar,
|
||||
max_value: InputScalar,
|
||||
) -> CubeTensor<R> {
|
||||
struct ClampOp;
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> NumericUnaryOp<N> for ClampOp {
|
||||
type Options = Options;
|
||||
|
||||
fn execute(input: Line<N>, options: &Self::Options) -> Line<N> {
|
||||
let line_size = input.size();
|
||||
cubecl::prelude::clamp(
|
||||
input,
|
||||
Line::empty(line_size).fill(options.min_value.get::<N>()),
|
||||
Line::empty(line_size).fill(options.max_value.get::<N>()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NumericUnaryOpFamily for ClampOp {
|
||||
type Options = Options;
|
||||
type Unary<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
launch_unary_numeric::<R, ClampOp, _>(input, |_| OptionsLaunch::new(min_value, max_value))
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{
|
||||
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{DType, TensorMetadata};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait ComparisonOpFamily: 'static + Send + Sync {
|
||||
type Operation<N: Numeric>: ComparisonOp<N>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait ComparisonOp<C: Numeric>: 'static + Send + Sync {
|
||||
/// Execute a comparison operation.
|
||||
fn execute(lhs: Line<C>, rhs: Line<C>) -> bool;
|
||||
}
|
||||
|
||||
struct EqualOp;
|
||||
struct GreaterEqualOp;
|
||||
struct LowerEqualOp;
|
||||
struct GreaterOp;
|
||||
struct LowerOp;
|
||||
|
||||
impl ComparisonOpFamily for EqualOp {
|
||||
type Operation<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for EqualOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
|
||||
lhs == rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl ComparisonOpFamily for GreaterEqualOp {
|
||||
type Operation<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for GreaterEqualOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
|
||||
lhs >= rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl ComparisonOpFamily for LowerEqualOp {
|
||||
type Operation<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for LowerEqualOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
|
||||
lhs <= rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl ComparisonOpFamily for GreaterOp {
|
||||
type Operation<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for GreaterOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
|
||||
lhs > rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl ComparisonOpFamily for LowerOp {
|
||||
type Operation<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for LowerOp {
|
||||
fn execute(lhs: Line<N>, rhs: Line<N>) -> bool {
|
||||
lhs < rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_scalar_cmp<N: Numeric, Bool: Numeric, O: ComparisonOpFamily>(
|
||||
input: &LinearView<Line<N>>,
|
||||
scalar: InputScalar,
|
||||
output: &mut LinearView<Line<Bool>, ReadWrite>,
|
||||
#[define(N, Bool)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = Line::cast_from(O::Operation::<N>::execute(
|
||||
input[ABSOLUTE_POS],
|
||||
Line::new(scalar.get::<N>()),
|
||||
));
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_cmp<N: Numeric, Bool: Numeric, O: ComparisonOpFamily>(
|
||||
lhs: &LinearView<Line<N>>,
|
||||
rhs: &LinearView<Line<N>>,
|
||||
out: &mut LinearView<Line<Bool>, ReadWrite>,
|
||||
#[define(N, Bool)] _dtype: [StorageType; 2],
|
||||
) {
|
||||
if !out.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
out[ABSOLUTE_POS] = Line::cast_from(O::Operation::<N>::execute(
|
||||
lhs[ABSOLUTE_POS],
|
||||
rhs[ABSOLUTE_POS],
|
||||
));
|
||||
}
|
||||
|
||||
pub(crate) fn launch_cmp<R: CubeRuntime, O: ComparisonOpFamily>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size_lhs = max_line_size(&lhs);
|
||||
let line_size_rhs = max_line_size(&rhs);
|
||||
|
||||
let line_size = Ord::min(line_size_lhs, line_size_rhs);
|
||||
|
||||
let shape_out = broadcast_shape(&[&lhs, &rhs]);
|
||||
let client = lhs.client.clone();
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&lhs.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);
|
||||
|
||||
let dtypes = [lhs.dtype.into(), dtype_bool.into()];
|
||||
let same_tensor_type = dtypes[0] == dtypes[1];
|
||||
if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
|
||||
unsafe {
|
||||
kernel_cmp::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view(&lhs, line_size),
|
||||
linear_view_ref(&rhs, &lhs, line_size),
|
||||
linear_view_alias(&lhs, line_size, 0),
|
||||
dtypes,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
CubeTensor::new(lhs.client, lhs.handle, *lhs.meta, lhs.device, dtype_bool)
|
||||
} else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {
|
||||
unsafe {
|
||||
kernel_cmp::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs),
|
||||
linear_view_ref(&lhs, &rhs, line_size),
|
||||
linear_view(&rhs, line_size),
|
||||
linear_view_alias(&rhs, line_size, 1),
|
||||
dtypes,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
CubeTensor::new(rhs.client, rhs.handle, *rhs.meta, rhs.device, dtype_bool)
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
lhs.client.clone(),
|
||||
lhs.device.clone(),
|
||||
shape_out,
|
||||
dtype_bool,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
kernel_cmp::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs, output),
|
||||
linear_view_ref(&lhs, &output, line_size),
|
||||
linear_view_ref(&rhs, &output, line_size),
|
||||
linear_view(&output, line_size),
|
||||
dtypes,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_scalar_cmp<R: CubeRuntime, O: ComparisonOpFamily>(
|
||||
tensor: CubeTensor<R>,
|
||||
scalar: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size(&tensor);
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
let dtypes = [tensor.dtype.into(), dtype_bool.into()];
|
||||
let same_tensor_type = dtypes[0] == dtypes[1];
|
||||
|
||||
if same_tensor_type && tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
unsafe {
|
||||
kernel_scalar_cmp::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor),
|
||||
linear_view(&tensor, line_size),
|
||||
scalar,
|
||||
linear_view_alias(&tensor, line_size, 0),
|
||||
dtypes,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
CubeTensor::new(
|
||||
tensor.client,
|
||||
tensor.handle,
|
||||
*tensor.meta,
|
||||
tensor.device,
|
||||
dtype_bool,
|
||||
)
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
dtype_bool,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
kernel_scalar_cmp::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
scalar,
|
||||
linear_view(&output, line_size),
|
||||
dtypes,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
pub fn equal<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_cmp::<R, EqualOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn greater<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_cmp::<R, GreaterOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn greater_equal<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_cmp::<R, GreaterEqualOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn lower<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_cmp::<R, LowerOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn lower_equal<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_cmp::<R, LowerEqualOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn equal_elem<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_scalar_cmp::<R, EqualOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn greater_elem<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_scalar_cmp::<R, GreaterOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn lower_elem<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_scalar_cmp::<R, LowerOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn greater_equal_elem<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_scalar_cmp::<R, GreaterEqualOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn lower_equal_elem<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
launch_scalar_cmp::<R, LowerEqualOp>(lhs, rhs, dtype_bool)
|
||||
}
|
||||
|
||||
// Unary comparison / predicate / relational ops
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait PredicateOp<F: Float>: 'static + Send + Sync {
|
||||
/// Execute a predicate operation.
|
||||
fn execute(input: Line<F>) -> bool;
|
||||
}
|
||||
|
||||
pub(crate) trait PredicateOpFamily: 'static + Send + Sync {
|
||||
type Operation<F: Float>: PredicateOp<F>;
|
||||
}
|
||||
|
||||
struct IsNanOp;
|
||||
struct IsInfOp;
|
||||
|
||||
impl PredicateOpFamily for IsNanOp {
|
||||
type Operation<F: Float> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<F: Float> PredicateOp<F> for IsNanOp {
|
||||
fn execute(input: Line<F>) -> bool {
|
||||
Line::is_nan(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl PredicateOpFamily for IsInfOp {
|
||||
type Operation<F: Float> = Self;
|
||||
}
|
||||
#[cube]
|
||||
impl<F: Float> PredicateOp<F> for IsInfOp {
|
||||
fn execute(input: Line<F>) -> bool {
|
||||
Line::is_inf(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn kernel_predicate<F: Float, Bool: Numeric, O: PredicateOpFamily>(
|
||||
input: &LinearView<Line<F>>,
|
||||
output: &mut LinearView<Line<Bool>, ReadWrite>,
|
||||
#[define(F, Bool)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = Line::cast_from(O::Operation::<F>::execute(input[ABSOLUTE_POS]));
|
||||
}
|
||||
|
||||
pub(crate) fn launch_predicate<R: CubeRuntime, O: PredicateOpFamily>(
|
||||
tensor: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size(&tensor);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
|
||||
let dtypes = [tensor.dtype.into(), dtype_bool.into()];
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
dtype_bool,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
kernel_predicate::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view_ref(&tensor, &output, line_size),
|
||||
linear_view(&output, line_size),
|
||||
dtypes,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn is_nan<R: CubeRuntime>(tensor: CubeTensor<R>, dtype_bool: DType) -> CubeTensor<R> {
|
||||
launch_predicate::<R, IsNanOp>(tensor, dtype_bool)
|
||||
}
|
||||
|
||||
pub fn is_inf<R: CubeRuntime>(tensor: CubeTensor<R>, dtype_bool: DType) -> CubeTensor<R> {
|
||||
launch_predicate::<R, IsInfOp>(tensor, dtype_bool)
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
use burn_backend::{DType, QTensorPrimitive, TensorMetadata};
|
||||
use cubecl::quant::scheme::{QuantStore, QuantValue};
|
||||
use cubecl::server::AllocationKind;
|
||||
|
||||
use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};
|
||||
|
||||
/// Make a jit tensor contiguous.
|
||||
pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
|
||||
if tensor.is_contiguous() {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
if tensor.qparams.is_some() {
|
||||
return into_contiguous_quantized(tensor, AllocationKind::Contiguous);
|
||||
}
|
||||
|
||||
let output = cubecl::std::tensor::into_contiguous_ref(
|
||||
&tensor.client,
|
||||
&tensor.as_handle_ref(),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
CubeTensor::new(
|
||||
tensor.client,
|
||||
output.handle,
|
||||
*output.metadata,
|
||||
tensor.device,
|
||||
tensor.dtype,
|
||||
)
|
||||
}
|
||||
|
||||
/// Make a jit tensor contiguous with an aligned last stride. Tensor is considered already contiguous
|
||||
/// if runtime can read it as is. This is equivalent in practice.
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensor))
|
||||
)]
|
||||
pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
|
||||
if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
if tensor.qparams.is_some() {
|
||||
return into_contiguous_quantized(tensor, AllocationKind::Optimized);
|
||||
}
|
||||
|
||||
let output = cubecl::std::tensor::into_contiguous_pitched_ref(
|
||||
&tensor.client,
|
||||
&tensor.as_handle_ref(),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
CubeTensor::new(
|
||||
tensor.client,
|
||||
output.handle,
|
||||
*output.metadata,
|
||||
tensor.device,
|
||||
tensor.dtype,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensor))
|
||||
)]
|
||||
fn into_contiguous_quantized<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
kind: AllocationKind,
|
||||
) -> CubeTensor<R> {
|
||||
let scheme = tensor.scheme();
|
||||
let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, kind);
|
||||
let (values, scales) = tensor.quantized_handles().unwrap();
|
||||
let (out_values, out_scales) = output.quantized_handles().unwrap();
|
||||
|
||||
match scheme.store {
|
||||
QuantStore::PackedU32(packed_dim) => {
|
||||
cubecl::std::tensor::into_contiguous_packed_ref(
|
||||
&values.client,
|
||||
&values.as_handle_ref(),
|
||||
&out_values.as_handle_ref(),
|
||||
packed_dim,
|
||||
tensor.meta.shape(),
|
||||
scheme.num_quants(),
|
||||
DType::U32.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
// e2m1 is special because it has a native packed representation, `e2m1x2`.
|
||||
// It's internally stored as `u8` with a packing factor of 2.
|
||||
QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
|
||||
cubecl::std::tensor::into_contiguous_packed_ref(
|
||||
&values.client,
|
||||
&values.as_handle_ref(),
|
||||
&out_values.as_handle_ref(),
|
||||
packed_dim,
|
||||
tensor.meta.shape(),
|
||||
scheme.num_quants(),
|
||||
DType::U8.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
_ => {
|
||||
cubecl::std::tensor::copy_into(
|
||||
&values.client,
|
||||
&values.as_handle_ref(),
|
||||
&out_values.as_handle_ref(),
|
||||
values.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
}
|
||||
|
||||
cubecl::std::tensor::copy_into(
|
||||
&scales.client,
|
||||
&scales.as_handle_ref(),
|
||||
&out_scales.as_handle_ref(),
|
||||
scales.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
use burn_backend::{
|
||||
TensorMetadata,
|
||||
ops::{ConvOptions, ConvTransposeOptions},
|
||||
};
|
||||
use burn_std::Shape;
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::conv::{conv_transpose2d, conv_transpose3d},
|
||||
ops::{permute_nchw_to_nhwc, permute_nhwc_to_nchw, reshape},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
pub(crate) fn conv_data_backward_fallback<R: CubeRuntime, const N_DIM: usize>(
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
in_shape: Shape,
|
||||
options: ConvOptions<N_DIM>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let dim_c = out_grad.rank();
|
||||
|
||||
let kernel_size = &weights.meta.shape()[1..dim_c];
|
||||
let in_shape = &in_shape[1..dim_c];
|
||||
let out_shape = &out_grad.meta.shape()[1..dim_c];
|
||||
|
||||
let mut padding_out = [0; N_DIM];
|
||||
|
||||
for i in 0..N_DIM {
|
||||
padding_out[i] = calculate_padding_out(
|
||||
kernel_size[i],
|
||||
options.stride[i],
|
||||
options.padding[i],
|
||||
options.dilation[i],
|
||||
in_shape[i],
|
||||
out_shape[i],
|
||||
);
|
||||
}
|
||||
|
||||
// We don't yet have NHWC kernels for conv_transpose so need to do this.
|
||||
// Should eventually use NHWC kernels instead
|
||||
let out_grad = permute_nhwc_to_nchw(out_grad);
|
||||
let weights = permute_nhwc_to_nchw(weights);
|
||||
|
||||
let in_grad = match N_DIM {
|
||||
1 => conv_transpose1d_from_conv_transpose2d(
|
||||
out_grad,
|
||||
weights,
|
||||
ConvTransposeOptions::new(
|
||||
[options.stride[0]],
|
||||
[options.padding[0]],
|
||||
[padding_out[0]],
|
||||
[options.dilation[0]],
|
||||
options.groups,
|
||||
),
|
||||
),
|
||||
2 => conv_transpose2d(
|
||||
out_grad,
|
||||
weights,
|
||||
None,
|
||||
ConvTransposeOptions::new(
|
||||
[options.stride[0], options.stride[1]],
|
||||
[options.padding[0], options.padding[1]],
|
||||
[padding_out[0], padding_out[1]],
|
||||
[options.dilation[0], options.dilation[1]],
|
||||
options.groups,
|
||||
),
|
||||
Default::default(),
|
||||
),
|
||||
3 => Ok(conv_transpose3d(
|
||||
out_grad,
|
||||
weights,
|
||||
None,
|
||||
ConvTransposeOptions::new(
|
||||
[options.stride[0], options.stride[1], options.stride[2]],
|
||||
[options.padding[0], options.padding[1], options.padding[2]],
|
||||
[padding_out[0], padding_out[1], padding_out[2]],
|
||||
[
|
||||
options.dilation[0],
|
||||
options.dilation[1],
|
||||
options.dilation[2],
|
||||
],
|
||||
options.groups,
|
||||
),
|
||||
)
|
||||
.unwrap()),
|
||||
_ => unimplemented!("Invalid dimensionality"),
|
||||
}?;
|
||||
Ok(permute_nchw_to_nhwc(in_grad))
|
||||
}
|
||||
|
||||
fn calculate_padding_out(
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
size_in: usize,
|
||||
size_out: usize,
|
||||
) -> usize {
|
||||
if stride <= 1 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let out = 1
|
||||
+ ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64).ceil()
|
||||
as usize;
|
||||
i64::max(0, out as i64 - size_out as i64) as usize
|
||||
}
|
||||
|
||||
fn conv_transpose1d_from_conv_transpose2d<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let [channels_in, channels_out, kernel_size] = weight.shape().dims();
|
||||
let [batch_size, _channels_in, length_in] = x.shape().dims();
|
||||
|
||||
let weight = reshape(
|
||||
weight,
|
||||
Shape::new([channels_in, channels_out, kernel_size, 1]),
|
||||
);
|
||||
let x = reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
|
||||
|
||||
let tensor = conv_transpose2d(
|
||||
x,
|
||||
weight,
|
||||
None,
|
||||
ConvTransposeOptions::new(
|
||||
[options.stride[0], 1],
|
||||
[options.padding[0], 0],
|
||||
[options.padding_out[0], 0],
|
||||
[options.dilation[0], 1],
|
||||
options.groups,
|
||||
),
|
||||
Default::default(),
|
||||
)?;
|
||||
let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims();
|
||||
Ok(reshape(
|
||||
tensor,
|
||||
Shape::from([batch_size, channels_out, height_out]),
|
||||
))
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
use burn_backend::ops::ConvOptions;
|
||||
use burn_std::Shape;
|
||||
use cubek::{
|
||||
convolution::{
|
||||
AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_data,
|
||||
components::ConvSetupError,
|
||||
},
|
||||
matmul::{
|
||||
definition::{MatmulElems, MatmulGlobalElems},
|
||||
launch::MatmulInputHandleRef,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
|
||||
pub fn dgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
input_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let read_strategy = match tile_kind {
|
||||
AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,
|
||||
AcceleratedTileKind::Mma => ReadingStrategy::Strided,
|
||||
};
|
||||
launch_backwards_data::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy,
|
||||
tile_kind,
|
||||
},
|
||||
out_grad,
|
||||
weights,
|
||||
input_shape,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn dgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
input_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let read_strategy = match tile_kind {
|
||||
AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,
|
||||
AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,
|
||||
};
|
||||
launch_backwards_data::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy,
|
||||
tile_kind,
|
||||
},
|
||||
out_grad,
|
||||
weights,
|
||||
input_shape,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn dgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
input_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
launch_backwards_data::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy: ReadingStrategy::Tma,
|
||||
tile_kind,
|
||||
},
|
||||
out_grad,
|
||||
weights,
|
||||
input_shape,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
/// Perform a convolution backwards data pass using the implicit GEMM (im2col) algorithm, using
|
||||
/// cubecl tiling matmul components.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `out_grad` - The output gradients
|
||||
/// * `weight_shape` - The shape of the weights/weight gradients
|
||||
/// * `options` - The options to use for the convolution
|
||||
pub fn launch_backwards_data<R: CubeRuntime, const N: usize>(
|
||||
strategy: &Strategy,
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
input_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
|
||||
return Err(ConvSetupError::Groups(options.groups));
|
||||
}
|
||||
|
||||
let out_dtype = out_grad.dtype;
|
||||
|
||||
let in_grad = empty_device_dtype(
|
||||
out_grad.client.clone(),
|
||||
out_grad.device.clone(),
|
||||
input_shape,
|
||||
out_dtype,
|
||||
);
|
||||
|
||||
let client = out_grad.client.clone();
|
||||
let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
|
||||
lhs: out_grad.dtype.into(),
|
||||
rhs: weights.dtype.into(),
|
||||
out: out_dtype.into(),
|
||||
});
|
||||
let out_grad = MatmulInputHandleRef::new(out_grad.as_handle_ref(), out_grad.dtype.into());
|
||||
let weights = MatmulInputHandleRef::new(weights.as_handle_ref(), weights.dtype.into());
|
||||
|
||||
backward_data::launch_ref::<R, N>(
|
||||
strategy,
|
||||
&client,
|
||||
&out_grad,
|
||||
&weights,
|
||||
&in_grad.as_handle_ref(),
|
||||
ConvolutionArgs {
|
||||
stride: options.stride,
|
||||
padding: options.padding,
|
||||
dilation: options.dilation,
|
||||
},
|
||||
dtypes,
|
||||
)?;
|
||||
|
||||
Ok(in_grad)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub mod launch;
|
||||
pub use launch::*;
|
||||
@@ -0,0 +1,8 @@
|
||||
pub mod fallback;
|
||||
pub mod implicit_gemm;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub mod tune;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) use tune::*;
|
||||
@@ -0,0 +1,172 @@
|
||||
use burn_backend::ops::ConvOptions;
|
||||
use burn_std::Shape;
|
||||
use cubecl::{
|
||||
ir::StorageType,
|
||||
tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},
|
||||
};
|
||||
use cubek::convolution::AcceleratedTileKind;
|
||||
|
||||
use crate::{
|
||||
CubeAutotuneKey, CubeRuntime, CubeTuneId,
|
||||
kernel::conv::{
|
||||
ConvAutotuneKey,
|
||||
backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
|
||||
},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
/// Executes autotune on conv2d operations
|
||||
pub fn dgrad_autotune<R: CubeRuntime, const N: usize>(
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
input_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
) -> CubeTensor<R> {
|
||||
let client = out_grad.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
// Note: TMA isn't currently implemented properly, and will always error.
|
||||
// It's kept here so it gets automatically enabled as soon as cubek updates.
|
||||
// No CMMA for TMA because swizzling will be mandatory for good performance on dgrad.
|
||||
let tunables = TUNER.init(|| {
|
||||
TunableSet::new(create_key::<R, N>, create_wgrad_input::<R, N>)
|
||||
.with(Tunable::new(
|
||||
"wgrad_fallback",
|
||||
conv_data_backward_fallback::<R, N>,
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_sync_cmma",
|
||||
|input, grad, shape, options| {
|
||||
dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_sync_mma",
|
||||
|input, grad, shape, options| {
|
||||
dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_async_cmma",
|
||||
|input, grad, shape, options| {
|
||||
dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_async_mma",
|
||||
|input, grad, shape, options| {
|
||||
dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_tma_mma",
|
||||
|input, grad, shape, options| {
|
||||
dgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&out_grad.client, &out_grad.device),
|
||||
&client,
|
||||
tunables,
|
||||
(out_grad, weights, input_shape, options),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_wgrad_input<R: CubeRuntime, const N: usize>(
|
||||
_key: &CubeAutotuneKey,
|
||||
out_grad: &CubeTensor<R>,
|
||||
weights: &CubeTensor<R>,
|
||||
input_shape: &Shape,
|
||||
options: &ConvOptions<N>,
|
||||
) -> (CubeTensor<R>, CubeTensor<R>, Shape, ConvOptions<N>) {
|
||||
(
|
||||
out_grad.clone(),
|
||||
weights.clone(),
|
||||
input_shape.clone(),
|
||||
options.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_key<R: CubeRuntime, const N: usize>(
|
||||
out_grad: &CubeTensor<R>,
|
||||
weights: &CubeTensor<R>,
|
||||
input_shape: &Shape,
|
||||
options: &ConvOptions<N>,
|
||||
) -> CubeAutotuneKey {
|
||||
let dtype = out_grad.dtype;
|
||||
let rank = out_grad.meta.num_dims();
|
||||
let dim_c = rank - 1;
|
||||
|
||||
let batch_size = out_grad.meta.shape()[0];
|
||||
let in_channels = input_shape[dim_c];
|
||||
let out_channels = out_grad.meta.shape()[dim_c];
|
||||
|
||||
let kernel_size = weights.meta.shape()[1..dim_c].to_vec();
|
||||
let in_shape = input_shape[1..dim_c]
|
||||
.iter()
|
||||
.map(|shape| anchor(*shape, None, None, None))
|
||||
.collect();
|
||||
|
||||
let ConvOptions {
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
} = options.clone();
|
||||
|
||||
let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 {
|
||||
stride_align(out_grad.meta.strides(), out_grad.dtype.into())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align);
|
||||
let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 {
|
||||
stride_align(weights.meta.strides(), weights.dtype.into())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);
|
||||
|
||||
CubeAutotuneKey::Conv(ConvAutotuneKey::new(
|
||||
kernel_size,
|
||||
stride.to_vec(),
|
||||
padding.to_vec(),
|
||||
dilation.to_vec(),
|
||||
groups,
|
||||
in_channels,
|
||||
out_channels,
|
||||
in_shape,
|
||||
batch_size,
|
||||
false,
|
||||
dtype,
|
||||
lhs_shape_align,
|
||||
lhs_stride_align,
|
||||
rhs_shape_align,
|
||||
rhs_stride_align,
|
||||
))
|
||||
}
|
||||
|
||||
/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's
|
||||
/// repeat number, so it's the largest align that can have performance impacts.
|
||||
const MAX_STRIDE_FACTOR: u32 = 10;
|
||||
|
||||
/// Defines the non-contiguous stride alignment in terms of powers of two
|
||||
fn stride_align(strides: &[usize], elem: StorageType) -> u8 {
|
||||
let max = MAX_STRIDE_FACTOR;
|
||||
let dim_c = strides.len() - 1;
|
||||
let factor = strides[..dim_c]
|
||||
.iter()
|
||||
.map(|it| (*it * elem.size_bits()) / 8)
|
||||
.map(|it| it.trailing_zeros())
|
||||
.min()
|
||||
.unwrap_or(max);
|
||||
factor.min(max) as u8
|
||||
}
|
||||
|
||||
/// Defines the potential vectorization.
|
||||
fn pow2_factor(axis: usize) -> u8 {
|
||||
axis.trailing_zeros().min(4) as u8
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
use burn_backend::{TensorMetadata, ops::ConvOptions};
|
||||
use burn_std::{Shape, Slice};
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{conv::base::conv_forward_nhwc, slice, slice_assign},
|
||||
ops::{numeric::empty_device_dtype, swap_dims},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
/// Calculate the convolution backward pass with regard to the weight gradients.
|
||||
pub fn conv_weight_backward_fallback<R: CubeRuntime, const N_DIM: usize>(
|
||||
input: CubeTensor<R>,
|
||||
output_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N_DIM>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
match options.groups == 1 {
|
||||
true => conv_weight_grad_no_groups::<R, N_DIM>(input, output_grad, weight_shape, options),
|
||||
false => conv_weight_grad_groups::<R, N_DIM>(input, output_grad, weight_shape, options),
|
||||
}
|
||||
}
|
||||
|
||||
fn conv_weight_grad_no_groups<R: CubeRuntime, const N_DIM: usize>(
|
||||
input: CubeTensor<R>,
|
||||
output_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N_DIM>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let dim_c = input.rank() - 1;
|
||||
|
||||
let input_swapped = swap_dims(input, 0, dim_c);
|
||||
let out_grad_swapped = swap_dims(output_grad, 0, dim_c);
|
||||
let weight_grad_swapped = conv_forward_nhwc(
|
||||
input_swapped,
|
||||
out_grad_swapped,
|
||||
None,
|
||||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
Default::default(),
|
||||
)?;
|
||||
let mut weight_grad = swap_dims(weight_grad_swapped, 0, dim_c);
|
||||
if weight_grad.shape() != weight_shape {
|
||||
let ranges = weight_shape.iter().map(|&s| 0..s).collect::<Vec<_>>();
|
||||
weight_grad = slice(weight_grad, &ranges);
|
||||
}
|
||||
|
||||
Ok(weight_grad)
|
||||
}
|
||||
|
||||
#[allow(clippy::single_range_in_vec_init, reason = "False positive")]
|
||||
fn conv_weight_grad_groups<R: CubeRuntime, const N_DIM: usize>(
|
||||
input: CubeTensor<R>,
|
||||
output_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N_DIM>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let mut weight_grad = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
weight_shape.clone(),
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let dim_c = input.rank() - 1;
|
||||
|
||||
let channels_out = weight_shape[0];
|
||||
let increment_co = channels_out / options.groups;
|
||||
|
||||
let input_swapped = swap_dims(input, 0, dim_c);
|
||||
let output_grad_swapped = swap_dims(output_grad, 0, dim_c);
|
||||
|
||||
let kernel_size = &weight_shape[1..dim_c];
|
||||
let kernel_size_slice = kernel_size.iter().map(|&s| 0..s).collect::<Vec<_>>();
|
||||
let increment_ci = weight_grad.meta.shape()[dim_c];
|
||||
|
||||
for g in 0..options.groups {
|
||||
let start_idx_ci = g * increment_ci;
|
||||
let end_idx_ci = (g + 1) * increment_ci;
|
||||
let start_idx_co = g * increment_co;
|
||||
let end_idx_co = (g + 1) * increment_co;
|
||||
|
||||
let input = slice(input_swapped.clone(), &[start_idx_ci..end_idx_ci]);
|
||||
let grad = slice(output_grad_swapped.clone(), &[start_idx_co..end_idx_co]);
|
||||
|
||||
let weight_grad_tmp = conv_forward_nhwc(
|
||||
input,
|
||||
grad,
|
||||
None,
|
||||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
Default::default(),
|
||||
)?;
|
||||
let mut weight_grad_tmp = swap_dims(weight_grad_tmp, 0, dim_c);
|
||||
let kernel_size_tmp = &weight_grad_tmp.meta.shape()[1..dim_c];
|
||||
|
||||
if kernel_size != kernel_size_tmp {
|
||||
let mut slices = vec![0..increment_co];
|
||||
slices.extend(kernel_size_slice.clone());
|
||||
slices.push(0..increment_ci);
|
||||
weight_grad_tmp = slice(weight_grad_tmp, &slices);
|
||||
}
|
||||
|
||||
let mut slices = vec![start_idx_co..end_idx_co];
|
||||
slices.extend(kernel_size_slice.clone());
|
||||
slices.push(0..increment_ci);
|
||||
let slices = slices.into_iter().map(Slice::from).collect::<Vec<_>>();
|
||||
|
||||
weight_grad = slice_assign(weight_grad, &slices, weight_grad_tmp);
|
||||
}
|
||||
|
||||
Ok(weight_grad)
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
use burn_backend::ops::ConvOptions;
|
||||
use burn_std::Shape;
|
||||
use cubek::{
|
||||
convolution::{
|
||||
AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_weight,
|
||||
components::ConvSetupError,
|
||||
},
|
||||
matmul::{
|
||||
definition::{MatmulElems, MatmulGlobalElems},
|
||||
launch::MatmulInputHandleRef,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
|
||||
pub(crate) fn wgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let read_strategy = match tile_kind {
|
||||
AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,
|
||||
AcceleratedTileKind::Mma => ReadingStrategy::Strided,
|
||||
};
|
||||
launch_backwards_weight::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy,
|
||||
tile_kind,
|
||||
},
|
||||
input,
|
||||
out_grad,
|
||||
weight_shape,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn wgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let read_strategy = match tile_kind {
|
||||
AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,
|
||||
AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,
|
||||
};
|
||||
launch_backwards_weight::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy,
|
||||
tile_kind,
|
||||
},
|
||||
input,
|
||||
out_grad,
|
||||
weight_shape,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn wgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
launch_backwards_weight::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy: ReadingStrategy::Tma,
|
||||
tile_kind,
|
||||
},
|
||||
input,
|
||||
out_grad,
|
||||
weight_shape,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
/// Perform a convolution backwards weight pass using the implicit GEMM (im2col) algorithm, using
|
||||
/// cubecl tiling matmul components.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `out_grad` - The output gradients
|
||||
/// * `weight_shape` - The shape of the weights/weight gradients
|
||||
/// * `options` - The options to use for the convolution
|
||||
pub fn launch_backwards_weight<R: CubeRuntime, const N: usize>(
|
||||
strategy: &Strategy,
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
if options.groups != 1 {
|
||||
return Err(ConvSetupError::Groups(options.groups));
|
||||
}
|
||||
|
||||
let out_dtype = out_grad.dtype;
|
||||
|
||||
let weight_grad = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
weight_shape,
|
||||
out_dtype,
|
||||
);
|
||||
|
||||
let client = input.client.clone();
|
||||
let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
|
||||
lhs: input.dtype.into(),
|
||||
rhs: out_grad.dtype.into(),
|
||||
out: out_dtype.into(),
|
||||
});
|
||||
let input = MatmulInputHandleRef::new(input.as_handle_ref(), input.dtype.into());
|
||||
let out_grad = MatmulInputHandleRef::new(out_grad.as_handle_ref(), out_grad.dtype.into());
|
||||
|
||||
backward_weight::launch_ref::<R, N>(
|
||||
strategy,
|
||||
&client,
|
||||
&input,
|
||||
&out_grad,
|
||||
&weight_grad.as_handle_ref(),
|
||||
ConvolutionArgs {
|
||||
stride: options.stride,
|
||||
padding: options.padding,
|
||||
dilation: options.dilation,
|
||||
},
|
||||
dtypes,
|
||||
)?;
|
||||
|
||||
Ok(weight_grad)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub mod launch;
|
||||
pub use launch::*;
|
||||
@@ -0,0 +1,8 @@
|
||||
pub mod fallback;
|
||||
pub mod implicit_gemm;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub mod tune;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) use tune::*;
|
||||
@@ -0,0 +1,175 @@
|
||||
use burn_backend::ops::ConvOptions;
|
||||
use burn_std::Shape;
|
||||
use cubecl::{
|
||||
ir::StorageType,
|
||||
tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},
|
||||
};
|
||||
use cubek::convolution::AcceleratedTileKind;
|
||||
|
||||
use crate::{
|
||||
CubeAutotuneKey, CubeRuntime, CubeTuneId,
|
||||
kernel::conv::{
|
||||
ConvAutotuneKey,
|
||||
backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
|
||||
},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
/// Executes autotune on the weight gradients pass for convolution
|
||||
pub fn wgrad_autotune<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
) -> CubeTensor<R> {
|
||||
let client = input.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
TunableSet::new(create_key::<R, N>, create_wgrad_input::<R, N>)
|
||||
.with(Tunable::new(
|
||||
"wgrad_fallback",
|
||||
conv_weight_backward_fallback::<R, N>,
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_sync_cmma",
|
||||
|input, grad, shape, options| {
|
||||
wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_sync_mma",
|
||||
|input, grad, shape, options| {
|
||||
wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_async_cmma",
|
||||
|input, grad, shape, options| {
|
||||
wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_async_mma",
|
||||
|input, grad, shape, options| {
|
||||
wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_tma_cmma",
|
||||
|input, grad, shape, options| {
|
||||
wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_tma_mma",
|
||||
|input, grad, shape, options| {
|
||||
wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&input.client, &input.device),
|
||||
&client,
|
||||
tunables,
|
||||
(input, out_grad, weight_shape, options),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_wgrad_input<R: CubeRuntime, const N: usize>(
|
||||
_key: &CubeAutotuneKey,
|
||||
input: &CubeTensor<R>,
|
||||
out_grad: &CubeTensor<R>,
|
||||
weight_shape: &Shape,
|
||||
options: &ConvOptions<N>,
|
||||
) -> (CubeTensor<R>, CubeTensor<R>, Shape, ConvOptions<N>) {
|
||||
(
|
||||
input.clone(),
|
||||
out_grad.clone(),
|
||||
weight_shape.clone(),
|
||||
options.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_key<R: CubeRuntime, const N: usize>(
|
||||
input: &CubeTensor<R>,
|
||||
out_grad: &CubeTensor<R>,
|
||||
weight_shape: &Shape,
|
||||
options: &ConvOptions<N>,
|
||||
) -> CubeAutotuneKey {
|
||||
let dtype = input.dtype;
|
||||
let rank = input.meta.num_dims();
|
||||
let dim_c = rank - 1;
|
||||
|
||||
let batch_size = input.meta.shape()[0];
|
||||
let in_channels = input.meta.shape()[dim_c];
|
||||
let out_channels = weight_shape[0];
|
||||
|
||||
let kernel_size = weight_shape[1..dim_c].to_vec();
|
||||
let in_shape = input.meta.shape()[1..dim_c]
|
||||
.iter()
|
||||
.map(|shape| anchor(*shape, None, None, None))
|
||||
.collect();
|
||||
|
||||
let ConvOptions {
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
} = options.clone();
|
||||
|
||||
let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 {
|
||||
stride_align(out_grad.meta.strides(), out_grad.dtype.into())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align);
|
||||
let rhs_stride_align = if input.meta.strides()[dim_c] == 1 {
|
||||
stride_align(input.meta.strides(), input.dtype.into())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);
|
||||
|
||||
CubeAutotuneKey::Conv(ConvAutotuneKey::new(
|
||||
kernel_size,
|
||||
stride.to_vec(),
|
||||
padding.to_vec(),
|
||||
dilation.to_vec(),
|
||||
groups,
|
||||
in_channels,
|
||||
out_channels,
|
||||
in_shape,
|
||||
batch_size,
|
||||
false,
|
||||
dtype,
|
||||
lhs_shape_align,
|
||||
lhs_stride_align,
|
||||
rhs_shape_align,
|
||||
rhs_stride_align,
|
||||
))
|
||||
}
|
||||
|
||||
/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's
|
||||
/// repeat number, so it's the largest align that can have performance impacts.
|
||||
const MAX_STRIDE_FACTOR: u32 = 10;
|
||||
|
||||
/// Defines the non-contiguous stride alignment in terms of powers of two
|
||||
fn stride_align(strides: &[usize], elem: StorageType) -> u8 {
|
||||
let max = MAX_STRIDE_FACTOR;
|
||||
let dim_c = strides.len() - 1;
|
||||
let factor = strides[..dim_c]
|
||||
.iter()
|
||||
.map(|it| (*it * elem.size_bits()) / 8)
|
||||
.map(|it| it.trailing_zeros())
|
||||
.min()
|
||||
.unwrap_or(max);
|
||||
factor.min(max) as u8
|
||||
}
|
||||
|
||||
/// Defines the potential vectorization.
|
||||
fn pow2_factor(axis: usize) -> u8 {
|
||||
axis.trailing_zeros().min(4) as u8
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
use burn_backend::ops::ConvOptions;
|
||||
use burn_std::Shape;
|
||||
use cubek::convolution::{AcceleratedTileKind, components::ConvSetupError};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
use crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune};
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::conv::{
|
||||
backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
|
||||
backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
|
||||
forward::implicit_gemm::conv_gemm_simple_sync,
|
||||
},
|
||||
ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
use super::conv_direct;
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::forward::conv_autotune;
|
||||
|
||||
/// The strategy to be used when launching a convolution kernel.
|
||||
pub enum ConvStrategy {
|
||||
/// A simple direct convolution.
|
||||
Direct,
|
||||
#[cfg(feature = "autotune")]
|
||||
/// Using autotune to choose the best kernel based on runtime information.
|
||||
Autotune,
|
||||
/// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
|
||||
/// has constraints on tensor shape.
|
||||
ImplicitGemm,
|
||||
}
|
||||
|
||||
impl Default for ConvStrategy {
|
||||
fn default() -> Self {
|
||||
// if autotune is enabled, default to autotune
|
||||
#[cfg(feature = "autotune")]
|
||||
return ConvStrategy::Autotune;
|
||||
|
||||
// if autotune is disabled, default to the more memory-conservative algorithm
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
ConvStrategy::Direct
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs an N-dimensional convolution with the given strategy
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
pub fn conv_forward<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
strategy: ConvStrategy,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let input = permute_nchw_to_nhwc(input);
|
||||
let weight = permute_nchw_to_nhwc(weight);
|
||||
|
||||
let out = conv_forward_nhwc(input, weight, bias, options, strategy)?;
|
||||
|
||||
Ok(permute_nhwc_to_nchw(out))
|
||||
}
|
||||
|
||||
/// Performs an N-dimensional convolution with the given strategy on NHWC inputs/outputs
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
pub fn conv_forward_nhwc<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
strategy: ConvStrategy,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
match strategy {
|
||||
ConvStrategy::Direct => conv_direct::<R, N>(input, weight, bias, options),
|
||||
#[cfg(feature = "autotune")]
|
||||
ConvStrategy::Autotune => Ok(conv_autotune::<R, N>(input, weight, bias, options)),
|
||||
ConvStrategy::ImplicitGemm => {
|
||||
if options.groups != 1 {
|
||||
conv_direct::<R, N>(input, weight, bias, options)
|
||||
} else {
|
||||
conv_gemm_simple_sync::<R, N>(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
options,
|
||||
AcceleratedTileKind::Cmma,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs an N-dimensional convolution backwards pass with regard to weight, with the given strategy
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `out_grad` - The output gradients
|
||||
/// * `weight_shape` - The shape of the weights/weight gradients
|
||||
/// * `options` - The options used for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
pub fn conv_weight_backward<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
weight_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
strategy: ConvStrategy,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let input = permute_nchw_to_nhwc(input);
|
||||
let out_grad = permute_nchw_to_nhwc(out_grad);
|
||||
let weight_shape = permute_nchw_to_nhwc_shape(weight_shape);
|
||||
|
||||
let weight_grad = match strategy {
|
||||
ConvStrategy::Direct => {
|
||||
conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
ConvStrategy::Autotune => Ok(wgrad_autotune::<R, N>(
|
||||
input,
|
||||
out_grad,
|
||||
weight_shape,
|
||||
options,
|
||||
)),
|
||||
ConvStrategy::ImplicitGemm => {
|
||||
if options.groups != 1 {
|
||||
conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
|
||||
} else {
|
||||
wgrad_gemm_simple_sync::<R, N>(
|
||||
input,
|
||||
out_grad,
|
||||
weight_shape,
|
||||
options,
|
||||
AcceleratedTileKind::Cmma,
|
||||
)
|
||||
}
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(permute_nhwc_to_nchw(weight_grad))
|
||||
}
|
||||
|
||||
/// Performs an N-dimensional convolution backwards data pass with the given strategy
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `in_shape` - The shape of the input to the layer
|
||||
/// * `options` - The options to use for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
pub fn conv_data_backward<R: CubeRuntime, const N: usize>(
|
||||
out_grad: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
in_shape: Shape,
|
||||
options: ConvOptions<N>,
|
||||
strategy: ConvStrategy,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let out_grad = permute_nchw_to_nhwc(out_grad);
|
||||
let weights = permute_nchw_to_nhwc(weights);
|
||||
let in_shape = permute_nchw_to_nhwc_shape(in_shape);
|
||||
|
||||
let weight_grad = match strategy {
|
||||
ConvStrategy::Direct => {
|
||||
conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
ConvStrategy::Autotune => dgrad_autotune::<R, N>(out_grad, weights, in_shape, options),
|
||||
ConvStrategy::ImplicitGemm => {
|
||||
if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
|
||||
conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
|
||||
} else {
|
||||
dgrad_gemm_simple_sync::<R, N>(
|
||||
out_grad,
|
||||
weights,
|
||||
in_shape,
|
||||
options,
|
||||
AcceleratedTileKind::Cmma,
|
||||
)?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(permute_nhwc_to_nchw(weight_grad))
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
use crate::{CubeRuntime, tensor::CubeTensor};
|
||||
use burn_backend::ops::ConvTransposeOptions;
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::conv_transpose2d_autotune;
|
||||
use super::{conv_transpose2d_col2im, conv_transpose2d_direct};
|
||||
|
||||
/// The strategy to be used when launching a conv_transpose kernel.
|
||||
pub enum ConvTranspose2dStrategy {
|
||||
/// A simple direct convolution.
|
||||
Direct,
|
||||
#[cfg(feature = "autotune")]
|
||||
/// Using autotune to choose the best kernel based on runtime information.
|
||||
Autotune,
|
||||
/// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
|
||||
Gemm,
|
||||
}
|
||||
|
||||
impl Default for ConvTranspose2dStrategy {
|
||||
fn default() -> Self {
|
||||
// if autotune is enabled, default to autotune
|
||||
#[cfg(feature = "autotune")]
|
||||
return ConvTranspose2dStrategy::Autotune;
|
||||
|
||||
// if autotune is disabled, default to the more memory-conservative algorithm
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
ConvTranspose2dStrategy::Direct
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs a 2D convolution with the given strategy
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
pub fn conv_transpose2d<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
strategy: ConvTranspose2dStrategy,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
match strategy {
|
||||
ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options),
|
||||
#[cfg(feature = "autotune")]
|
||||
ConvTranspose2dStrategy::Autotune => {
|
||||
Ok(conv_transpose2d_autotune(input, weight, bias, options))
|
||||
}
|
||||
ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
conv::batches_per_run,
|
||||
into_contiguous_aligned,
|
||||
matmul::{MatmulStrategy, matmul},
|
||||
slice,
|
||||
utils::{address_type, decompose_linear, linear_view, shape_divmod},
|
||||
},
|
||||
ops::{numeric::empty_device_dtype, reshape, swap_dims},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{
|
||||
Shape,
|
||||
ops::{ConvTransposeOptions, conv::calculate_conv_transpose_output_size},
|
||||
};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::layout::linear::LinearView},
|
||||
};
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
/// Perform a 2D convolution transposition using the GEMM (col2im) algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
pub fn conv_transpose2d_col2im<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();
|
||||
let [batch_size, _, input_h, input_w] = input.meta.shape().dims();
|
||||
let groups = options.groups;
|
||||
let input_ch_per_group = input_channels / groups;
|
||||
let ConvTransposeOptions {
|
||||
padding: [padding_h, padding_w],
|
||||
padding_out: [padding_out_h, padding_out_w],
|
||||
dilation: [dilation_h, dilation_w],
|
||||
stride: [stride_h, stride_w],
|
||||
..
|
||||
} = options.clone();
|
||||
|
||||
let im_h = calculate_conv_transpose_output_size(
|
||||
kernel_h,
|
||||
stride_h,
|
||||
padding_h,
|
||||
padding_out_h,
|
||||
dilation_h,
|
||||
input_h,
|
||||
);
|
||||
let im_w = calculate_conv_transpose_output_size(
|
||||
kernel_w,
|
||||
stride_w,
|
||||
padding_w,
|
||||
padding_out_w,
|
||||
dilation_w,
|
||||
input_w,
|
||||
);
|
||||
let im_channels = im_ch_per_group * groups;
|
||||
|
||||
let batches_per_run = batches_per_run(
|
||||
batch_size,
|
||||
input_h * input_w,
|
||||
input.client.properties().hardware.plane_size_max as usize,
|
||||
)?;
|
||||
let col_shape_0 = im_ch_per_group * kernel_h * kernel_w;
|
||||
|
||||
let weight = reshape(
|
||||
weight.clone(),
|
||||
Shape::new([groups, input_ch_per_group, col_shape_0]),
|
||||
);
|
||||
let weight = into_contiguous_aligned(swap_dims(weight, 1, 2));
|
||||
|
||||
if batches_per_run != batch_size {
|
||||
let runs = batch_size / batches_per_run;
|
||||
|
||||
let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]);
|
||||
let image = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
im_shape,
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]);
|
||||
let input = reshape(input, input_shape);
|
||||
let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]);
|
||||
|
||||
for run in 0..runs {
|
||||
let input = index(input.clone(), run);
|
||||
let input = reshape(input, input_shape_run.clone());
|
||||
let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);
|
||||
let image_slice = index(image.clone(), run);
|
||||
let image_slice = reshape(image_slice, im_shape);
|
||||
execute(
|
||||
input,
|
||||
weight.clone(),
|
||||
bias.clone(),
|
||||
image_slice,
|
||||
options.clone(),
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
)?;
|
||||
}
|
||||
Ok(reshape(
|
||||
image,
|
||||
Shape::new([batch_size, im_channels, im_h, im_w]),
|
||||
))
|
||||
} else {
|
||||
let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);
|
||||
let image = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
im_shape,
|
||||
input.dtype,
|
||||
);
|
||||
execute(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
image.clone(),
|
||||
options,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
)?;
|
||||
Ok(image)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn index<R: CubeRuntime>(tensor: CubeTensor<R>, i: usize) -> CubeTensor<R> {
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
let mut indices = vec![i..i + 1];
|
||||
for dim in tensor.meta.shape()[1..].iter() {
|
||||
indices.push(0..*dim);
|
||||
}
|
||||
let mut tensor = slice(tensor, &indices);
|
||||
tensor.meta.remove(0);
|
||||
tensor
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn execute<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
image: CubeTensor<R>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> Result<(), ConvSetupError> {
|
||||
let [batch_size, _, input_h, input_w] = input.meta.shape().dims();
|
||||
let [groups, col_shape_0, input_ch_per_group] = weight.meta.shape().dims();
|
||||
|
||||
let col_shape_1 = batch_size * input_h * input_w;
|
||||
|
||||
let input = swap_dims(input, 0, 1);
|
||||
let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);
|
||||
let input = reshape(input, input_shape);
|
||||
|
||||
let dtype = input.dtype;
|
||||
let columns = matmul(weight, input, None, MatmulStrategy::default(), dtype)?;
|
||||
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
|
||||
|
||||
col2im(
|
||||
columns, bias, image, kernel_h, kernel_w, input_h, input_w, options,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn col2im<R: CubeRuntime>(
|
||||
columns: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
out: CubeTensor<R>,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Result<(), LaunchError> {
|
||||
let dtype = columns.dtype;
|
||||
|
||||
let columns = into_contiguous_aligned(columns);
|
||||
let bias = bias.map(into_contiguous_aligned);
|
||||
|
||||
let num_elems = out.meta.num_elements();
|
||||
|
||||
let cube_dim = CubeDim::new(&columns.client, num_elems);
|
||||
let cube_count = calculate_cube_count_elemwise(&columns.client, num_elems, cube_dim);
|
||||
|
||||
unsafe {
|
||||
col2im_kernel::launch_unchecked(
|
||||
&columns.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(columns, bias, out),
|
||||
columns.as_tensor_arg(1),
|
||||
bias.as_ref().map(|bias| bias.as_tensor_arg(1)).into(),
|
||||
linear_view(&out, 1),
|
||||
shape_divmod(&out),
|
||||
Col2ImArgsLaunch::new(
|
||||
ScalarArg::new(out_h),
|
||||
ScalarArg::new(out_w),
|
||||
ScalarArg::new(kernel_h),
|
||||
ScalarArg::new(kernel_w),
|
||||
ScalarArg::new(options.padding[0]),
|
||||
ScalarArg::new(options.padding[1]),
|
||||
ScalarArg::new(options.dilation[0]),
|
||||
ScalarArg::new(options.dilation[1]),
|
||||
ScalarArg::new(options.stride[0]),
|
||||
ScalarArg::new(options.stride[1]),
|
||||
),
|
||||
dtype.into(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct Col2ImArgs {
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
|
||||
pad_h: usize,
|
||||
pad_w: usize,
|
||||
dilation_h: usize,
|
||||
dilation_w: usize,
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn col2im_kernel<E: Numeric>(
|
||||
columns: &Tensor<E>,
|
||||
bias: &Option<Tensor<E>>,
|
||||
image: &mut LinearView<E, ReadWrite>,
|
||||
image_shape: Sequence<FastDivmod<usize>>,
|
||||
args: &Col2ImArgs,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= image.shape() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS, &image_shape);
|
||||
let [batch, ch_im, im_y, im_x] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let im_x = im_x + args.pad_w;
|
||||
let im_y = im_y + args.pad_h;
|
||||
|
||||
let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1;
|
||||
let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1;
|
||||
|
||||
let mut val = E::from_int(0);
|
||||
|
||||
let x_col_start = if im_x >= kernel_extent_w {
|
||||
(im_x - kernel_extent_w) / args.stride_w + 1
|
||||
} else {
|
||||
0usize.runtime()
|
||||
};
|
||||
let x_col_end = clamp_max(im_x / args.stride_w + 1, args.out_w);
|
||||
let y_col_start = if im_y >= kernel_extent_h {
|
||||
(im_y - kernel_extent_h) / args.stride_h + 1
|
||||
} else {
|
||||
0usize.runtime()
|
||||
};
|
||||
let y_col_end = clamp_max(im_y / args.stride_h + 1, args.out_h);
|
||||
|
||||
for col_y in y_col_start..y_col_end {
|
||||
let kernel_y = im_y - col_y * args.stride_h;
|
||||
for col_x in x_col_start..x_col_end {
|
||||
let kernel_x = im_x - col_x * args.stride_w;
|
||||
|
||||
if kernel_y.is_multiple_of(args.dilation_h) && kernel_x.is_multiple_of(args.dilation_w)
|
||||
{
|
||||
let kernel_y = kernel_y / args.dilation_h;
|
||||
let kernel_x = kernel_x / args.dilation_w;
|
||||
|
||||
let col_k =
|
||||
ch_im * args.kernel_h * args.kernel_w + kernel_y * args.kernel_w + kernel_x;
|
||||
let col_n = batch * args.out_h * args.out_w + col_y * args.out_w + col_x;
|
||||
let col_pos = col_k * columns.stride(0) + col_n * columns.stride(1);
|
||||
val += columns[col_pos];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match bias {
|
||||
Some(bias) => image[ABSOLUTE_POS] = val + bias[ch_im],
|
||||
None => image[ABSOLUTE_POS] = val,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
mod base;
|
||||
mod col2im;
|
||||
|
||||
mod transpose_direct;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
mod tune;
|
||||
|
||||
pub use base::*;
|
||||
pub use col2im::*;
|
||||
|
||||
pub use transpose_direct::*;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use tune::*;
|
||||
@@ -0,0 +1,185 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, decompose_linear, linear_view, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{Shape, ops::ConvTransposeOptions};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::layout::linear::LinearView},
|
||||
};
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct ConvArgs {
|
||||
conv_stride_0: usize,
|
||||
conv_stride_1: usize,
|
||||
dilation_0: usize,
|
||||
dilation_1: usize,
|
||||
padding_0: usize,
|
||||
padding_1: usize,
|
||||
groups: usize,
|
||||
}
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn conv_transpose2d_direct_kernel<E: Numeric>(
|
||||
input: &Tensor<E>,
|
||||
weight: &Tensor<E>,
|
||||
bias: &Option<Tensor<E>>,
|
||||
output: &mut LinearView<E, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
args: ConvArgs,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.shape() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let in_c_per_group = weight.shape(0) / args.groups;
|
||||
let out_c_per_group = weight.shape(1);
|
||||
let kernel_h = weight.shape(2);
|
||||
let kernel_w = weight.shape(3);
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape);
|
||||
let [batch, oc_out, out_y, out_x] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let k = oc_out / out_c_per_group;
|
||||
let group = k % args.groups;
|
||||
let out_c = oc_out - out_c_per_group * group;
|
||||
|
||||
let in_c_start = group * in_c_per_group;
|
||||
let in_c_end = in_c_start + in_c_per_group;
|
||||
|
||||
let stride_0_i = args.conv_stride_0 as i32;
|
||||
let stride_1_i = args.conv_stride_1 as i32;
|
||||
|
||||
let kms_h = (kernel_h * args.dilation_0) as i32 - stride_0_i;
|
||||
let kms_w = (kernel_w * args.dilation_1) as i32 - stride_1_i;
|
||||
|
||||
let y_start = ((out_y + args.padding_0) as i32 - kms_h) / stride_0_i;
|
||||
let x_start = ((out_x + args.padding_1) as i32 - kms_w) / stride_1_i;
|
||||
|
||||
let y_end = clamp(kms_h + y_start + 1, 0, input.shape(2) as i32) as usize;
|
||||
let x_end = clamp(kms_w + x_start + 1, 0, input.shape(3) as i32) as usize;
|
||||
let y_start = clamp_min(y_start, 0) as usize;
|
||||
let x_start = clamp_min(x_start, 0) as usize;
|
||||
|
||||
let idx_input_batch = batch * input.stride(0);
|
||||
let idx_weight_oc = out_c * weight.stride(1);
|
||||
|
||||
let bias: Option<E> = bias.map(|bias| bias[oc_out]);
|
||||
let mut sum = bias.unwrap_or_default();
|
||||
|
||||
let numerator_h_base = out_y + args.padding_0;
|
||||
let numerator_w_base = out_x + args.padding_1;
|
||||
|
||||
for in_c in in_c_start..in_c_end {
|
||||
let idx_input_ic = in_c * input.stride(1);
|
||||
let idx_weight_ic = in_c * weight.stride(0);
|
||||
|
||||
for in_y in y_start..y_end {
|
||||
let numerator_tmp = in_y * args.conv_stride_0;
|
||||
let numerator_h = numerator_h_base - numerator_tmp;
|
||||
|
||||
if numerator_h_base >= numerator_tmp && numerator_h.is_multiple_of(args.dilation_0) {
|
||||
let kernel_y = numerator_h / args.dilation_0;
|
||||
let idx_input_y = in_y * input.stride(2);
|
||||
let idx_weight_ky = kernel_y * weight.stride(2);
|
||||
|
||||
for in_x in x_start..x_end {
|
||||
let numerator_tmp = in_x * args.conv_stride_1;
|
||||
let numerator_w = numerator_w_base - numerator_tmp;
|
||||
|
||||
if numerator_w_base >= numerator_tmp
|
||||
&& numerator_w.is_multiple_of(args.dilation_1)
|
||||
{
|
||||
let kernel_x = numerator_w / args.dilation_1;
|
||||
let idx_input_x = in_x * input.stride(3);
|
||||
let idx_weight_kx = kernel_x * weight.stride(3);
|
||||
|
||||
let index_input =
|
||||
idx_input_batch + idx_input_ic + idx_input_y + idx_input_x;
|
||||
let index_weight =
|
||||
idx_weight_ic + idx_weight_oc + idx_weight_ky + idx_weight_kx;
|
||||
|
||||
let value = input[index_input];
|
||||
let weight = weight[index_weight];
|
||||
|
||||
sum += value * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = sum;
|
||||
}
|
||||
|
||||
/// Perform a 2D convolution transposition using the direct algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
pub fn conv_transpose2d_direct<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let [batch_size, _, in_height, in_width] = input.meta.shape().dims();
|
||||
let [_, out_channels, kernel_0, kernel_1] = weight.meta.shape().dims();
|
||||
|
||||
let out_0 = (in_height - 1) * options.stride[0]
|
||||
+ options.dilation[0] * (kernel_0 - 1)
|
||||
+ options.padding_out[0]
|
||||
- 2 * options.padding[0]
|
||||
+ 1;
|
||||
let out_1 = (in_width - 1) * options.stride[1]
|
||||
+ options.dilation[1] * (kernel_1 - 1)
|
||||
+ options.padding_out[1]
|
||||
- 2 * options.padding[1]
|
||||
+ 1;
|
||||
|
||||
let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
|
||||
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape_out.clone(),
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let num_elems = output.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&input.client, num_elems);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim);
|
||||
|
||||
conv_transpose2d_direct_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, weight, bias, output),
|
||||
input.as_tensor_arg(1),
|
||||
weight.as_tensor_arg(1),
|
||||
bias.as_ref().map(|bias| bias.as_tensor_arg(1)).into(),
|
||||
linear_view(&output, 1),
|
||||
shape_divmod(&output),
|
||||
ConvArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0]),
|
||||
ScalarArg::new(options.stride[1]),
|
||||
ScalarArg::new(options.dilation[0]),
|
||||
ScalarArg::new(options.dilation[1]),
|
||||
ScalarArg::new(options.padding[0]),
|
||||
ScalarArg::new(options.padding[1]),
|
||||
ScalarArg::new(options.groups),
|
||||
),
|
||||
input.dtype.into(),
|
||||
)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
use burn_backend::ops::ConvTransposeOptions;
|
||||
use cubecl::tune::{LocalTuner, Tunable, TunableSet, local_tuner};
|
||||
|
||||
use crate::{
|
||||
CubeAutotuneKey, CubeRuntime, CubeTuneId,
|
||||
kernel::conv::{ConvTranspose2dAutotuneKey, conv_transpose2d_col2im, conv_transpose2d_direct},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
/// Executes autotune on conv2d operations
|
||||
pub fn conv_transpose2d_autotune<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
weights: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> CubeTensor<R> {
|
||||
let client = input.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tune_set = TUNER.init(|| {
|
||||
TunableSet::new(create_key::<R>, create_transpose2d_input::<R>)
|
||||
.with(Tunable::new(
|
||||
"conv_transpose2d_direct",
|
||||
conv_transpose2d_direct::<R>,
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"conv_transpose2d_col2im",
|
||||
conv_transpose2d_col2im::<R>,
|
||||
))
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&input.client, &input.device),
|
||||
&client,
|
||||
tune_set,
|
||||
(input, weights, bias, options),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_transpose2d_input<R: CubeRuntime>(
|
||||
_key: &CubeAutotuneKey,
|
||||
input: &CubeTensor<R>,
|
||||
weights: &CubeTensor<R>,
|
||||
bias: &Option<CubeTensor<R>>,
|
||||
options: &ConvTransposeOptions<2>,
|
||||
) -> (
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
Option<CubeTensor<R>>,
|
||||
ConvTransposeOptions<2>,
|
||||
) {
|
||||
(
|
||||
input.clone(),
|
||||
weights.clone(),
|
||||
bias.clone(),
|
||||
options.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_key<R: CubeRuntime>(
|
||||
input: &CubeTensor<R>,
|
||||
weights: &CubeTensor<R>,
|
||||
bias: &Option<CubeTensor<R>>,
|
||||
options: &ConvTransposeOptions<2>,
|
||||
) -> CubeAutotuneKey {
|
||||
let [batch_size, in_channels, height, width] = input.meta.shape().dims();
|
||||
let [out_channels, _, kernel_h, kernel_w] = weights.meta.shape().dims();
|
||||
let ConvTransposeOptions {
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
padding_out,
|
||||
} = options.clone();
|
||||
CubeAutotuneKey::ConvTranspose(ConvTranspose2dAutotuneKey::new(
|
||||
[kernel_h, kernel_w],
|
||||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
groups,
|
||||
in_channels,
|
||||
out_channels,
|
||||
height,
|
||||
width,
|
||||
batch_size,
|
||||
bias.is_some(),
|
||||
input.dtype,
|
||||
))
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::layout::linear::LinearView},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, decompose_linear, linear_view, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{Shape, ops::ConvTransposeOptions};
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct ConvArgs {
|
||||
conv_stride_0: usize,
|
||||
conv_stride_1: usize,
|
||||
conv_stride_2: usize,
|
||||
dilation_0: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
padding_0: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
groups: usize,
|
||||
}
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn conv_transpose3d_kernel<E: Numeric>(
|
||||
input: &Tensor<E>,
|
||||
weight: &Tensor<E>,
|
||||
bias: &Option<Tensor<E>>,
|
||||
output: &mut LinearView<E, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
args: ConvArgs,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
let in_channels = weight.shape(0);
|
||||
let out_c_per_group = weight.shape(1);
|
||||
let kernel_size_0 = weight.shape(2);
|
||||
let kernel_size_1 = weight.shape(3);
|
||||
let kernel_size_2 = weight.shape(4);
|
||||
|
||||
let stride_0_i = args.conv_stride_0 as i32;
|
||||
let stride_1_i = args.conv_stride_1 as i32;
|
||||
let stride_2_i = args.conv_stride_2 as i32;
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape);
|
||||
let [batch, out_c_out, out_z, out_y, out_x] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let groups = args.groups;
|
||||
let in_c_per_group = in_channels / groups;
|
||||
|
||||
let k = out_c_out / out_c_per_group;
|
||||
let group = k % groups;
|
||||
let out_channel = out_c_out - out_c_per_group * group;
|
||||
|
||||
let in_c_start = group * in_c_per_group;
|
||||
let in_c_end = in_c_start + in_c_per_group;
|
||||
|
||||
let kernel_d = (kernel_size_0 * args.dilation_0 - args.conv_stride_0) as i32;
|
||||
let kernel_h = (kernel_size_1 * args.dilation_1 - args.conv_stride_1) as i32;
|
||||
let kernel_w = (kernel_size_2 * args.dilation_2 - args.conv_stride_2) as i32;
|
||||
|
||||
let z_start = ((out_z + args.padding_0) as i32 - kernel_d) / stride_0_i;
|
||||
let y_start = ((out_y + args.padding_1) as i32 - kernel_h) / stride_1_i;
|
||||
let x_start = ((out_x + args.padding_2) as i32 - kernel_w) / stride_2_i;
|
||||
|
||||
let z_end = clamp(kernel_d + z_start + 1, 0, input.shape(2) as i32) as usize;
|
||||
let y_end = clamp(kernel_h + y_start + 1, 0, input.shape(3) as i32) as usize;
|
||||
let x_end = clamp(kernel_w + x_start + 1, 0, input.shape(4) as i32) as usize;
|
||||
|
||||
let z_start = clamp_min(z_start, 0) as usize;
|
||||
let y_start = clamp_min(y_start, 0) as usize;
|
||||
let x_start = clamp_min(x_start, 0) as usize;
|
||||
|
||||
let index_input_batch = batch * input.stride(0);
|
||||
let index_weight_out_c = out_channel * weight.stride(1);
|
||||
|
||||
let bias: Option<E> = bias.map(|bias| bias[out_c_out]);
|
||||
let mut sum = bias.unwrap_or_default();
|
||||
|
||||
let numerator_d_base = out_z + args.padding_0;
|
||||
let numerator_h_base = out_y + args.padding_1;
|
||||
let numerator_w_base = out_x + args.padding_2;
|
||||
|
||||
for in_c in in_c_start..in_c_end {
|
||||
let index_input_in_c = in_c * input.stride(1);
|
||||
let index_weight_in_c = in_c * weight.stride(0);
|
||||
|
||||
for in_z in z_start..z_end {
|
||||
let numerator_tmp = in_z * args.conv_stride_0;
|
||||
let numerator_d = numerator_d_base - numerator_tmp;
|
||||
|
||||
if numerator_d_base >= numerator_tmp && numerator_d.is_multiple_of(args.dilation_0) {
|
||||
let kernel_z = numerator_d / args.dilation_0;
|
||||
let index_input_z = in_z * input.stride(2);
|
||||
let index_weight_kz = kernel_z * weight.stride(2);
|
||||
|
||||
for in_y in y_start..y_end {
|
||||
let numerator_tmp = in_y * args.conv_stride_1;
|
||||
let numerator_h = numerator_h_base - numerator_tmp;
|
||||
|
||||
if numerator_h_base >= numerator_tmp
|
||||
&& numerator_h.is_multiple_of(args.dilation_1)
|
||||
{
|
||||
let kernel_y = numerator_h / args.dilation_1;
|
||||
let index_input_y = in_y * input.stride(3);
|
||||
let index_weight_ky = kernel_y * weight.stride(3);
|
||||
|
||||
for in_x in x_start..x_end {
|
||||
let numerator_tmp = in_x * args.conv_stride_2;
|
||||
let numerator_w = numerator_w_base - numerator_tmp;
|
||||
|
||||
if numerator_w_base >= numerator_tmp
|
||||
&& numerator_w.is_multiple_of(args.dilation_2)
|
||||
{
|
||||
let kernel_x = numerator_w / args.dilation_2;
|
||||
let index_input_x = in_x * input.stride(4);
|
||||
let index_weight_kx = kernel_x * weight.stride(4);
|
||||
|
||||
let index_input = index_input_batch
|
||||
+ index_input_in_c
|
||||
+ index_input_z
|
||||
+ index_input_y
|
||||
+ index_input_x;
|
||||
|
||||
let index_weight = index_weight_in_c
|
||||
+ index_weight_out_c
|
||||
+ index_weight_kz
|
||||
+ index_weight_ky
|
||||
+ index_weight_kx;
|
||||
|
||||
let value = input[index_input];
|
||||
let weight = weight[index_weight];
|
||||
|
||||
sum += value * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = sum;
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose3d<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> Result<CubeTensor<R>, LaunchError> {
|
||||
let [batch_size, _, in_depth, in_height, in_width] = input.meta.shape().dims();
|
||||
let [_, out_channels, kernel_0, kernel_1, kernel_2] = weight.meta.shape().dims();
|
||||
|
||||
let out_0 = (in_depth - 1) * options.stride[0]
|
||||
+ options.dilation[0] * (kernel_0 - 1)
|
||||
+ options.padding_out[0]
|
||||
- 2 * options.padding[0]
|
||||
+ 1;
|
||||
let out_1 = (in_height - 1) * options.stride[1]
|
||||
+ options.dilation[1] * (kernel_1 - 1)
|
||||
+ options.padding_out[1]
|
||||
- 2 * options.padding[1]
|
||||
+ 1;
|
||||
let out_2 = (in_width - 1) * options.stride[2]
|
||||
+ options.dilation[2] * (kernel_2 - 1)
|
||||
+ options.padding_out[2]
|
||||
- 2 * options.padding[2]
|
||||
+ 1;
|
||||
|
||||
let shape_out = Shape::new([
|
||||
batch_size,
|
||||
out_channels * options.groups,
|
||||
out_0,
|
||||
out_1,
|
||||
out_2,
|
||||
]);
|
||||
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape_out.clone(),
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let num_elems = output.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&input.client, num_elems);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim);
|
||||
|
||||
conv_transpose3d_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, weight, bias, output),
|
||||
input.as_tensor_arg(1),
|
||||
weight.as_tensor_arg(1),
|
||||
bias.as_ref().map(|bias| bias.as_tensor_arg(1)).into(),
|
||||
linear_view(&output, 1),
|
||||
shape_divmod(&output),
|
||||
ConvArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0]),
|
||||
ScalarArg::new(options.stride[1]),
|
||||
ScalarArg::new(options.stride[2]),
|
||||
ScalarArg::new(options.dilation[0]),
|
||||
ScalarArg::new(options.dilation[1]),
|
||||
ScalarArg::new(options.dilation[2]),
|
||||
ScalarArg::new(options.padding[0]),
|
||||
ScalarArg::new(options.padding[1]),
|
||||
ScalarArg::new(options.padding[2]),
|
||||
ScalarArg::new(options.groups),
|
||||
),
|
||||
input.dtype.into(),
|
||||
)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, FastDivmodArgs},
|
||||
};
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
use burn_backend::{
|
||||
Shape,
|
||||
ops::{DeformConvOptions, conv::calculate_conv_output_size},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
AddOp, into_contiguous_aligned, launch_binop,
|
||||
matmul::{MatmulStrategy, matmul},
|
||||
utils::address_type,
|
||||
},
|
||||
ops::{numeric::zeros_client, reshape, swap_dims},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct DeformConv2dArgs {
|
||||
conv_stride_h: usize,
|
||||
conv_stride_w: usize,
|
||||
dilation_h: usize,
|
||||
dilation_w: usize,
|
||||
padding_h: InputScalar,
|
||||
padding_w: InputScalar,
|
||||
offset_groups: usize,
|
||||
|
||||
kernel_height: usize,
|
||||
kernel_width: usize,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
}
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn deform_im2col_kernel<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
offset: &Tensor<F>,
|
||||
mask: &Option<Tensor<F>>,
|
||||
columns: &mut Tensor<F>,
|
||||
pos_shape: Sequence<FastDivmod<usize>>,
|
||||
args: &DeformConv2dArgs,
|
||||
#[comptime] kernel_h_unroll: Option<usize>,
|
||||
#[comptime] kernel_w_unroll: Option<usize>,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
// position shape: [in_channels, batch_size, out_h, out_w]
|
||||
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
|
||||
|
||||
let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height);
|
||||
let unroll_h = kernel_h_unroll.is_some();
|
||||
let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width);
|
||||
let unroll_w = kernel_w_unroll.is_some();
|
||||
|
||||
let out_h = args.out_h;
|
||||
let out_w = args.out_w;
|
||||
let in_channels = input.shape(1);
|
||||
let height = input.shape(2);
|
||||
let width = input.shape(3);
|
||||
let col_stride_0 = columns.stride(0);
|
||||
|
||||
let (rem, out_x) = pos_shape[3].div_mod(ABSOLUTE_POS);
|
||||
let (rem, out_y) = pos_shape[2].div_mod(rem);
|
||||
let (in_channel, batch) = pos_shape[1].div_mod(rem);
|
||||
|
||||
if in_channel >= in_channels {
|
||||
terminate!()
|
||||
}
|
||||
|
||||
let out_k_base = in_channel * kernel_height * kernel_width;
|
||||
let out_n = batch * out_h * out_w + out_y * out_w + out_x;
|
||||
|
||||
let channels_per_offset_group = in_channels / args.offset_groups;
|
||||
let group_index = in_channel / channels_per_offset_group;
|
||||
|
||||
let mut col_base_idx = out_k_base * columns.stride(0) + out_n * columns.stride(1);
|
||||
|
||||
let input_base_idx = batch * input.stride(0) + in_channel * input.stride(1);
|
||||
|
||||
let offset_base_idx = batch * offset.stride(0)
|
||||
+ group_index * kernel_height * kernel_width * 2 * offset.stride(1);
|
||||
|
||||
let mask_base_idx = mask.as_ref().map(|mask| {
|
||||
batch * mask.stride(0) + group_index * kernel_height * kernel_width * mask.stride(1)
|
||||
});
|
||||
|
||||
#[unroll(unroll_h)]
|
||||
for kernel_y in 0..kernel_height {
|
||||
#[unroll(unroll_w)]
|
||||
for kernel_x in 0..kernel_width {
|
||||
let mask_index = kernel_y * kernel_width + kernel_x;
|
||||
let offset_index = mask_index * 2;
|
||||
|
||||
let offset_y = offset[offset_base_idx
|
||||
+ offset_index * offset.stride(1)
|
||||
+ out_y * offset.stride(2)
|
||||
+ out_x * offset.stride(3)];
|
||||
let offset_x = offset[offset_base_idx
|
||||
+ (offset_index + 1) * offset.stride(1)
|
||||
+ out_y * offset.stride(2)
|
||||
+ out_x * offset.stride(3)];
|
||||
let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h)
|
||||
- args.padding_h.get::<F>()
|
||||
+ offset_y;
|
||||
let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w)
|
||||
- args.padding_w.get::<F>()
|
||||
+ offset_x;
|
||||
|
||||
let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx);
|
||||
let value = match mask.zip::<usize>(mask_base_idx) {
|
||||
Some((mask, base_idx)) => {
|
||||
let mask_value = mask[base_idx
|
||||
+ mask_index * mask.stride(1)
|
||||
+ out_y * mask.stride(2)
|
||||
+ out_x * mask.stride(3)];
|
||||
mask_value * interpolated
|
||||
}
|
||||
None => interpolated,
|
||||
};
|
||||
|
||||
columns[col_base_idx] = value;
|
||||
col_base_idx += col_stride_0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) fn bilinear_interpolate<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
y: F,
|
||||
x: F,
|
||||
offset: usize,
|
||||
) -> F {
|
||||
// To simplify code
|
||||
let y = f32::cast_from(y);
|
||||
let x = f32::cast_from(x);
|
||||
let stride_y = input.stride(2);
|
||||
let stride_x = input.stride(3);
|
||||
|
||||
let mut result = F::new(0.0);
|
||||
if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
|
||||
let y_low = y.floor();
|
||||
let x_low = x.floor();
|
||||
let y_high = (y_low + 1.) as usize;
|
||||
let x_high = (x_low + 1.) as usize;
|
||||
|
||||
let zero = F::new(0.0);
|
||||
let v1: F = if y_low >= 0. && x_low >= 0. {
|
||||
input[offset + y_low as usize * stride_y + x_low as usize * stride_x]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v2: F = if y_low >= 0. && x_high < width {
|
||||
input[offset + y_low as usize * stride_y + x_high * stride_x]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v3: F = if y_high < height && x_low >= 0. {
|
||||
input[offset + y_high * stride_y + x_low as usize * stride_x]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v4: F = if y_high < height && x_high < width {
|
||||
input[offset + y_high * stride_y + x_high * stride_x]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
|
||||
let l_y = y - y_low;
|
||||
let l_x = x - x_low;
|
||||
let h_y = 1.0 - l_y;
|
||||
let h_x = 1.0 - l_x;
|
||||
|
||||
let w1 = F::cast_from(h_y * h_x);
|
||||
let w2 = F::cast_from(h_y * l_x);
|
||||
let w3 = F::cast_from(l_y * h_x);
|
||||
let w4 = F::cast_from(l_y * l_x);
|
||||
|
||||
result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub(crate) fn deform_im2col<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
options: DeformConvOptions<2>,
|
||||
out_dims: (usize, usize),
|
||||
kernel_dims: (usize, usize),
|
||||
) -> Result<CubeTensor<R>, LaunchError> {
|
||||
let client = input.client.clone();
|
||||
let device = input.device.clone();
|
||||
let dtype = input.dtype;
|
||||
|
||||
let [batch_size, in_channels, _, _] = input.meta.shape().dims();
|
||||
let (out_height, out_width) = out_dims;
|
||||
let (kernel_height, kernel_width) = kernel_dims;
|
||||
|
||||
let shape_out = Shape::new([
|
||||
in_channels * kernel_height * kernel_width,
|
||||
batch_size * out_height * out_width,
|
||||
]);
|
||||
|
||||
let pos_shape = [in_channels, batch_size, out_height, out_width]
|
||||
.into_iter()
|
||||
.map(|s| FastDivmodArgs::new(&client, s))
|
||||
.collect();
|
||||
|
||||
let output = zeros_client(client.clone(), device.clone(), shape_out.clone(), dtype);
|
||||
|
||||
let num_kernels = in_channels * batch_size * out_height * out_width;
|
||||
let cube_dim = CubeDim::new(&input.client, num_kernels);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, num_kernels, cube_dim);
|
||||
|
||||
deform_im2col_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, offset, mask, output),
|
||||
input.as_tensor_arg(1),
|
||||
offset.as_tensor_arg(1),
|
||||
mask.as_ref().map(|mask| mask.as_tensor_arg(1)).into(),
|
||||
output.as_handle_ref().as_tensor_arg(1),
|
||||
pos_shape,
|
||||
DeformConv2dArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0]),
|
||||
ScalarArg::new(options.stride[1]),
|
||||
ScalarArg::new(options.dilation[0]),
|
||||
ScalarArg::new(options.dilation[1]),
|
||||
{
|
||||
let val = options.padding[0] as f32;
|
||||
InputScalar::new(val, dtype)
|
||||
},
|
||||
{
|
||||
let val = options.padding[1] as f32;
|
||||
InputScalar::new(val, dtype)
|
||||
},
|
||||
ScalarArg::new(options.offset_groups),
|
||||
ScalarArg::new(kernel_height),
|
||||
ScalarArg::new(kernel_width),
|
||||
ScalarArg::new(out_height),
|
||||
ScalarArg::new(out_width),
|
||||
),
|
||||
Some(kernel_height),
|
||||
Some(kernel_width),
|
||||
dtype.into(),
|
||||
)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) fn deform_conv2d<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let input = into_contiguous_aligned(input);
|
||||
let offset = into_contiguous_aligned(offset);
|
||||
let weight = into_contiguous_aligned(weight);
|
||||
let mask = mask.map(|it| into_contiguous_aligned(it));
|
||||
let bias = bias.map(|it| into_contiguous_aligned(it));
|
||||
|
||||
let [batch_size, _, in_height, in_width] = input.meta.shape().dims();
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.meta.shape().dims();
|
||||
let groups = options.weight_groups;
|
||||
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
in_height,
|
||||
);
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
in_width,
|
||||
);
|
||||
let out_dims = (out_h, out_w);
|
||||
|
||||
let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w))?;
|
||||
|
||||
let [col_size_0, col_size_1] = columns.meta.shape().dims();
|
||||
let col_size_0 = col_size_0 / groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let dtype = weight.dtype;
|
||||
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0]));
|
||||
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
|
||||
let out = matmul(weight, columns, None, MatmulStrategy::default(), dtype)?;
|
||||
|
||||
let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
|
||||
let out = swap_dims(out, 0, 1);
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));
|
||||
Ok(launch_binop::<R, AddOp>(out, bias))
|
||||
} else {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,720 @@
|
||||
use super::{bilinear_interpolate, deform_im2col, index};
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
cast, into_contiguous_aligned,
|
||||
matmul::{MatmulStrategy, matmul},
|
||||
reduce::reduce_dim,
|
||||
slice_assign,
|
||||
utils::{address_type, decompose_linear, linear_view},
|
||||
},
|
||||
ops::{
|
||||
numeric::{empty_device_dtype, zeros_client},
|
||||
reshape, swap_dims,
|
||||
},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{DType, Shape, TensorMetadata, ops::DeformConvOptions};
|
||||
use cubecl::{
|
||||
CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube,
|
||||
features::TypeUsage,
|
||||
ir::FloatKind,
|
||||
prelude::*,
|
||||
std::{FastDivmod, FastDivmodArgs, tensor::layout::linear::LinearView},
|
||||
};
|
||||
use cubek::{
|
||||
convolution::components::ConvSetupError,
|
||||
reduce::components::instructions::ReduceOperationConfig,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
|
||||
#[allow(
|
||||
clippy::single_range_in_vec_init,
|
||||
clippy::type_complexity,
|
||||
clippy::too_many_arguments
|
||||
)]
|
||||
pub(crate) fn deform_conv2d_backward<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
out_grad: CubeTensor<R>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> Result<
|
||||
(
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
Option<CubeTensor<R>>,
|
||||
Option<CubeTensor<R>>,
|
||||
),
|
||||
ConvSetupError,
|
||||
> {
|
||||
let [_, _, out_h, out_w] = out_grad.meta.shape().dims();
|
||||
let [_, _, kernel_h, kernel_w] = weight.meta.shape().dims();
|
||||
|
||||
let gradient_bias = bias.map(|bias| {
|
||||
let grad = reduce_dim(
|
||||
out_grad.clone(),
|
||||
None,
|
||||
0,
|
||||
Default::default(),
|
||||
ReduceOperationConfig::Sum,
|
||||
)
|
||||
.unwrap();
|
||||
let grad = reduce_dim(
|
||||
grad,
|
||||
None,
|
||||
2,
|
||||
Default::default(),
|
||||
ReduceOperationConfig::Sum,
|
||||
)
|
||||
.unwrap();
|
||||
let grad = reduce_dim(
|
||||
grad,
|
||||
None,
|
||||
3,
|
||||
Default::default(),
|
||||
ReduceOperationConfig::Sum,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
reshape(grad, bias.meta.shape)
|
||||
});
|
||||
|
||||
let input = into_contiguous_aligned(input);
|
||||
let offset = into_contiguous_aligned(offset);
|
||||
let weight = into_contiguous_aligned(weight);
|
||||
let mask = mask.map(|it| into_contiguous_aligned(it));
|
||||
|
||||
let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
|
||||
input.clone(),
|
||||
weight.clone(),
|
||||
offset.clone(),
|
||||
mask.clone(),
|
||||
out_grad.clone(),
|
||||
&options,
|
||||
(kernel_h, kernel_w),
|
||||
)?;
|
||||
|
||||
let weight_grad = compute_weight_grad(
|
||||
input,
|
||||
offset,
|
||||
mask,
|
||||
out_grad,
|
||||
options,
|
||||
(kernel_h, kernel_w),
|
||||
(out_h, out_w),
|
||||
)?;
|
||||
|
||||
Ok((
|
||||
input_gradient,
|
||||
offset_gradient,
|
||||
weight_grad,
|
||||
mask_gradient,
|
||||
gradient_bias,
|
||||
))
|
||||
}
|
||||
|
||||
fn compute_weight_grad<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
out_grad: CubeTensor<R>,
|
||||
options: DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
out_dims: (usize, usize),
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let [_, in_channels, _, _] = input.meta.shape().dims();
|
||||
let [_, out_channels, _, _] = out_grad.meta.shape().dims();
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
let groups = options.weight_groups;
|
||||
let dtype = input.dtype;
|
||||
|
||||
let in_c_per_group = in_channels / groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims)?;
|
||||
let [col_size_0, col_size_1] = columns.meta.shape().dims();
|
||||
let col_size_0 = col_size_0 / groups;
|
||||
|
||||
let out_grad = swap_dims(out_grad, 0, 1);
|
||||
let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1]));
|
||||
|
||||
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
|
||||
let columns = swap_dims(columns, 1, 2);
|
||||
|
||||
let grad_weight = matmul(out_grad, columns, None, MatmulStrategy::default(), dtype)?;
|
||||
|
||||
Ok(reshape(
|
||||
grad_weight,
|
||||
Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]),
|
||||
))
|
||||
}
|
||||
|
||||
type InputGradients<R> = (CubeTensor<R>, CubeTensor<R>, Option<CubeTensor<R>>);
|
||||
|
||||
fn backward_gradient_inputs<R: CubeRuntime>(
|
||||
image: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
out_grad: CubeTensor<R>,
|
||||
options: &DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
) -> Result<InputGradients<R>, ConvSetupError> {
|
||||
let client = out_grad.client.clone();
|
||||
let device = out_grad.device.clone();
|
||||
|
||||
let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();
|
||||
let [batch_size, _, out_h, out_w] = out_grad.meta.shape().dims();
|
||||
|
||||
let groups = options.weight_groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let col_shape_0 = in_c_per_group * kernel_h * kernel_w;
|
||||
let col_shape_1 = batch_size * out_h * out_w;
|
||||
let col_shape = Shape::new([groups, col_shape_0, col_shape_1]);
|
||||
let mut columns = empty_device_dtype(client, device, col_shape, weight.dtype);
|
||||
|
||||
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));
|
||||
|
||||
let out_grad = swap_dims(out_grad, 0, 1);
|
||||
let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]);
|
||||
let out_grad = reshape(out_grad, out_grad_shape);
|
||||
|
||||
for group in 0..groups {
|
||||
let dtype = weight.dtype;
|
||||
let weight = swap_dims(index(weight.clone(), group), 0, 1);
|
||||
let out_grad = index(out_grad.clone(), group);
|
||||
let values = matmul(weight, out_grad, None, MatmulStrategy::default(), dtype)?;
|
||||
let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1]));
|
||||
columns = slice_assign(
|
||||
columns,
|
||||
&[
|
||||
burn_backend::Slice::from(group..group + 1),
|
||||
burn_backend::Slice::from(0..col_shape_0),
|
||||
burn_backend::Slice::from(0..col_shape_1),
|
||||
],
|
||||
values,
|
||||
);
|
||||
}
|
||||
|
||||
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
|
||||
|
||||
let input_shape = image.shape();
|
||||
let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
|
||||
columns.clone(),
|
||||
image,
|
||||
offset.clone(),
|
||||
mask.clone(),
|
||||
options,
|
||||
kernel_dims,
|
||||
)?;
|
||||
|
||||
let input_gradient =
|
||||
compute_input_grad(columns, offset, mask, options, kernel_dims, input_shape)?;
|
||||
|
||||
Ok((input_gradient, offset_gradient, mask_gradient))
|
||||
}
|
||||
|
||||
fn compute_offset_and_mask_gradient<R: CubeRuntime>(
|
||||
columns: CubeTensor<R>,
|
||||
image: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
options: &DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
) -> Result<(CubeTensor<R>, Option<CubeTensor<R>>), ConvSetupError> {
|
||||
let client = offset.client.clone();
|
||||
let device = offset.device.clone();
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
|
||||
let [batches, _, out_h, out_w] = offset.meta.shape().dims();
|
||||
let offset_groups = options.offset_groups;
|
||||
|
||||
let pos_shape = [batches, offset_groups, kernel_h, kernel_w, 2, out_h, out_w];
|
||||
let pos_shape = pos_shape
|
||||
.into_iter()
|
||||
.map(|s| FastDivmodArgs::new(&client, s))
|
||||
.collect();
|
||||
|
||||
let grad_offset =
|
||||
empty_device_dtype(client.clone(), device.clone(), offset.shape(), offset.dtype);
|
||||
let grad_mask = mask
|
||||
.as_ref()
|
||||
.map(|mask| empty_device_dtype(client.clone(), device.clone(), mask.shape(), mask.dtype));
|
||||
|
||||
let num_elements_offset = offset.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&image.client, num_elements_offset);
|
||||
let cube_count = calculate_cube_count_elemwise(&image.client, num_elements_offset, cube_dim);
|
||||
|
||||
let dtype: StorageType = image.dtype.into();
|
||||
unsafe {
|
||||
deform_col2img_coord_kernel::launch_unchecked(
|
||||
&image.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(image, offset, mask, grad_offset, grad_mask),
|
||||
image.as_tensor_arg(1),
|
||||
offset.as_tensor_arg(1),
|
||||
mask.as_ref().map(|mask| mask.as_tensor_arg(1)).into(),
|
||||
columns.as_tensor_arg(1),
|
||||
linear_view(&grad_offset, 1),
|
||||
grad_mask
|
||||
.as_ref()
|
||||
.map(|grad_mask| grad_mask.as_tensor_arg(1))
|
||||
.into(),
|
||||
pos_shape,
|
||||
DeformConv2dCol2ImgCoordArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0]),
|
||||
ScalarArg::new(options.stride[1]),
|
||||
ScalarArg::new(options.dilation[0]),
|
||||
ScalarArg::new(options.dilation[1]),
|
||||
InputScalar::new(options.padding[0] as f32, dtype.elem_type()),
|
||||
InputScalar::new(options.padding[1] as f32, dtype.elem_type()),
|
||||
ScalarArg::new(offset_groups),
|
||||
ScalarArg::new(kernel_h),
|
||||
ScalarArg::new(kernel_w),
|
||||
),
|
||||
dtype,
|
||||
)
|
||||
}?;
|
||||
|
||||
Ok((grad_offset, grad_mask))
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct DeformConv2dCol2ImgCoordArgs {
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
dilation_h: usize,
|
||||
dilation_w: usize,
|
||||
pad_h: InputScalar,
|
||||
pad_w: InputScalar,
|
||||
offset_groups: usize,
|
||||
kernel_height: usize,
|
||||
kernel_width: usize,
|
||||
}
|
||||
|
||||
#[allow(clippy::collapsible_if)]
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn deform_col2img_coord_kernel<F: Float>(
|
||||
image: &Tensor<F>,
|
||||
offset: &Tensor<F>,
|
||||
mask: &Option<Tensor<F>>,
|
||||
columns: &Tensor<F>,
|
||||
grad_offset: &mut LinearView<F, ReadWrite>,
|
||||
grad_mask: &mut Option<Tensor<F>>,
|
||||
pos_shape: Sequence<FastDivmod<usize>>,
|
||||
args: &DeformConv2dCol2ImgCoordArgs,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
// Position format: [batch, [offset_groups, kernel_h, kernel_w, 2], out_h, out_w]
|
||||
// Columns format: [[in_channel, kernel_h, kernel_w], [batch, out_h, out_w]]
|
||||
// Alternatively : [batch, offset_channels, out_h, out_w]
|
||||
|
||||
if ABSOLUTE_POS >= grad_offset.shape() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let out_h = offset.shape(2);
|
||||
let out_w = offset.shape(3);
|
||||
let in_channels = image.shape(1);
|
||||
let height = image.shape(2);
|
||||
let width = image.shape(3);
|
||||
let kernel_w = args.kernel_width;
|
||||
let kernel_h = args.kernel_height;
|
||||
|
||||
let mut grad_offset_val = F::new(0.0);
|
||||
let mut grad_mask_val = F::new(0.0);
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);
|
||||
let [batch, offset_group, kernel_y, kernel_x, dir, out_y, out_x] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let channels_per_offset_group = in_channels / args.offset_groups;
|
||||
|
||||
let col_n = batch * out_h * out_w + out_y * out_w + out_x;
|
||||
|
||||
let col_base_idx =
|
||||
offset_group * channels_per_offset_group * kernel_h * kernel_w * columns.stride(0)
|
||||
+ col_n * columns.stride(1);
|
||||
let mut image_base_idx =
|
||||
batch * image.stride(0) + offset_group * channels_per_offset_group * image.stride(1);
|
||||
|
||||
let offset_pos_1 =
|
||||
offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;
|
||||
let offset_base_idx = batch * offset.stride(0)
|
||||
+ offset_pos_1 * offset.stride(1)
|
||||
+ out_y * offset.stride(2)
|
||||
+ out_x * offset.stride(3);
|
||||
|
||||
let offset_y_idx = offset_base_idx;
|
||||
let offset_x_idx = offset_base_idx + offset.stride(1);
|
||||
|
||||
let offset_y = offset[offset_y_idx];
|
||||
let offset_x = offset[offset_x_idx];
|
||||
|
||||
let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;
|
||||
let mask_value = match &mask {
|
||||
Some(mask) => {
|
||||
let mask_idx = batch * mask.stride(0)
|
||||
+ mask_pos_1 * mask.stride(1)
|
||||
+ out_y * mask.stride(2)
|
||||
+ out_x * mask.stride(3);
|
||||
mask[mask_idx]
|
||||
}
|
||||
None => F::new(1.0),
|
||||
};
|
||||
|
||||
let is_y_direction = dir == 0;
|
||||
|
||||
for col_c in 0..channels_per_offset_group {
|
||||
let col_pos = col_base_idx + col_c * kernel_h * kernel_w * columns.stride(0);
|
||||
|
||||
let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)
|
||||
- args.pad_h.get::<F>()
|
||||
+ offset_y;
|
||||
let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)
|
||||
- args.pad_w.get::<F>()
|
||||
+ offset_x;
|
||||
|
||||
let weight =
|
||||
get_coordinate_weight(image, image_base_idx, height, width, y, x, is_y_direction);
|
||||
let columns_value = columns[col_pos];
|
||||
|
||||
grad_offset_val += mask_value * weight * columns_value;
|
||||
|
||||
if grad_mask.is_some() && is_y_direction {
|
||||
grad_mask_val +=
|
||||
columns_value * bilinear_interpolate(image, height, width, y, x, image_base_idx);
|
||||
}
|
||||
|
||||
image_base_idx += image.stride(1);
|
||||
}
|
||||
|
||||
grad_offset[ABSOLUTE_POS] = grad_offset_val;
|
||||
|
||||
if let Some(grad_mask) = grad_mask {
|
||||
if is_y_direction {
|
||||
let idx = batch * grad_mask.stride(0)
|
||||
+ mask_pos_1 * grad_mask.stride(1)
|
||||
+ out_y * grad_mask.stride(2)
|
||||
+ out_x * grad_mask.stride(3);
|
||||
|
||||
grad_mask[idx] = grad_mask_val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn get_coordinate_weight<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
offset: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
y: F,
|
||||
x: F,
|
||||
is_y_direction: bool,
|
||||
) -> F {
|
||||
let stride_y = input.stride(2);
|
||||
let stride_x = input.stride(3);
|
||||
|
||||
let y = f32::cast_from(y);
|
||||
let x = f32::cast_from(x);
|
||||
|
||||
let y_low = f32::floor(y);
|
||||
let x_low = f32::floor(x);
|
||||
let y_high = y_low + 1.;
|
||||
let x_high = x_low + 1.;
|
||||
|
||||
let valid_y_low = y_low >= 0. && y_low < height as f32;
|
||||
let valid_y_high = y_high >= 0. && y_high < height as f32;
|
||||
let valid_x_low = x_low >= 0. && x_low < width as f32;
|
||||
let valid_x_high = x_high >= 0. && x_high < width as f32;
|
||||
|
||||
let bottom_left = if valid_y_low && valid_x_low {
|
||||
input[offset + y_low as usize * stride_y + x_low as usize * stride_x]
|
||||
} else {
|
||||
F::new(0.0)
|
||||
};
|
||||
let bottom_right = if valid_y_low && valid_x_high {
|
||||
input[offset + y_low as usize * stride_y + x_high as usize * stride_x]
|
||||
} else {
|
||||
F::new(0.0)
|
||||
};
|
||||
let top_left = if valid_y_high && valid_x_low {
|
||||
input[offset + y_high as usize * stride_y + x_low as usize * stride_x]
|
||||
} else {
|
||||
F::new(0.0)
|
||||
};
|
||||
let top_right = if valid_y_high && valid_x_high {
|
||||
input[offset + y_high as usize * stride_y + x_high as usize * stride_x]
|
||||
} else {
|
||||
F::new(0.0)
|
||||
};
|
||||
|
||||
if is_y_direction {
|
||||
let delta_x = F::cast_from(x - x_low);
|
||||
delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left)
|
||||
} else {
|
||||
let delta_y = F::cast_from(y - y_low);
|
||||
delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left)
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_input_grad<R: CubeRuntime>(
|
||||
columns: CubeTensor<R>,
|
||||
offset: CubeTensor<R>,
|
||||
mask: Option<CubeTensor<R>>,
|
||||
options: &DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
input_shape: Shape,
|
||||
) -> Result<CubeTensor<R>, LaunchError> {
|
||||
let client = offset.client.clone();
|
||||
let device = offset.device.clone();
|
||||
|
||||
let supports_fadd = client
|
||||
.properties()
|
||||
.type_usage(StorageType::Atomic(FloatKind::F32.into()))
|
||||
.contains(TypeUsage::AtomicAdd);
|
||||
let supports_same_type = client
|
||||
.properties()
|
||||
.type_usage(StorageType::Atomic(columns.dtype.into()))
|
||||
.contains(TypeUsage::AtomicAdd);
|
||||
|
||||
let [batches, in_channels, height, width] = input_shape.dims();
|
||||
let [_, _, out_h, out_w] = offset.meta.shape().dims();
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
|
||||
let pos_shape = [in_channels, kernel_h, kernel_w, batches, out_h, out_w];
|
||||
let pos_shape = pos_shape
|
||||
.into_iter()
|
||||
.map(|s| FastDivmodArgs::new(&client, s))
|
||||
.collect();
|
||||
|
||||
let shape = Shape::new([batches, in_channels, height, width]);
|
||||
let grad_in = match supports_fadd && supports_same_type {
|
||||
// Use type as is to save a cast
|
||||
true => zeros_client(client.clone(), device.clone(), shape, columns.dtype),
|
||||
// Force `f32` to enable bitcasting as `u32`, or use intrinsic when supported
|
||||
false => zeros_client(client.clone(), device.clone(), shape, DType::F32),
|
||||
};
|
||||
let grad_arg = grad_in.as_tensor_arg(1);
|
||||
|
||||
let num_elements = columns.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&offset.client, num_elements);
|
||||
let cube_count = calculate_cube_count_elemwise(&offset.client, num_elements, cube_dim);
|
||||
|
||||
let launch = match supports_fadd {
|
||||
true => deform_col2img_kernel::launch_unchecked::<IntrinsicFloatAtomicAddFamily, R>,
|
||||
false => deform_col2img_kernel::launch_unchecked::<CASFloatAtomicAdd, R>,
|
||||
};
|
||||
let dtype = offset.dtype;
|
||||
let dtypes: [StorageType; 2] = match supports_same_type {
|
||||
true => [dtype.into(), dtype.into()],
|
||||
false => [dtype.into(), DType::F32.into()],
|
||||
};
|
||||
|
||||
unsafe {
|
||||
launch(
|
||||
&offset.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(offset, mask, columns, grad_in),
|
||||
offset.as_tensor_arg(1),
|
||||
mask.as_ref().map(|mask| mask.as_tensor_arg(1)).into(),
|
||||
linear_view(&columns, 1),
|
||||
grad_arg,
|
||||
pos_shape,
|
||||
DeformConv2dCol2ImgArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0]),
|
||||
ScalarArg::new(options.stride[1]),
|
||||
ScalarArg::new(options.dilation[0]),
|
||||
ScalarArg::new(options.dilation[1]),
|
||||
InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()),
|
||||
InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()),
|
||||
ScalarArg::new(options.offset_groups),
|
||||
ScalarArg::new(kernel_h),
|
||||
ScalarArg::new(kernel_w),
|
||||
),
|
||||
dtypes,
|
||||
)
|
||||
}?;
|
||||
|
||||
Ok(if !supports_same_type || !supports_fadd {
|
||||
cast(grad_in, dtype)
|
||||
} else {
|
||||
grad_in
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct DeformConv2dCol2ImgArgs {
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
dilation_h: usize,
|
||||
dilation_w: usize,
|
||||
pad_h: InputScalar,
|
||||
pad_w: InputScalar,
|
||||
offset_groups: usize,
|
||||
kernel_height: usize,
|
||||
kernel_width: usize,
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn deform_col2img_kernel<F: Float, FP: Float, FAdd: FloatAtomicAddFamily>(
|
||||
offset: &Tensor<F>,
|
||||
mask: &Option<Tensor<F>>,
|
||||
columns: &LinearView<F>,
|
||||
grad_input: &mut Tensor<Atomic<ProxyType<FAdd, FP>>>,
|
||||
pos_shape: Sequence<FastDivmod<usize>>,
|
||||
args: &DeformConv2dCol2ImgArgs,
|
||||
#[define(F, FP)] _dtype: [StorageType; 2],
|
||||
) {
|
||||
// Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
|
||||
if ABSOLUTE_POS >= columns.shape() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let n_in_channels = grad_input.shape(1);
|
||||
let height = grad_input.shape(2);
|
||||
let width = grad_input.shape(3);
|
||||
let kernel_h = args.kernel_height;
|
||||
let kernel_w = args.kernel_width;
|
||||
let n_offset_groups = args.offset_groups;
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);
|
||||
let [in_channel, kernel_y, kernel_x, batch, out_y, out_x] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let channels_per_offset_group = n_in_channels / n_offset_groups;
|
||||
let offset_group = in_channel / channels_per_offset_group;
|
||||
|
||||
let offset_pos_1 =
|
||||
offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;
|
||||
let offset_base_idx = batch * offset.stride(0)
|
||||
+ offset_pos_1 * offset.stride(1)
|
||||
+ out_y * offset.stride(2)
|
||||
+ out_x * offset.stride(3);
|
||||
|
||||
let offset_y_idx = offset_base_idx;
|
||||
let offset_x_idx = offset_base_idx + offset.stride(1);
|
||||
|
||||
let offset_y = offset[offset_y_idx];
|
||||
let offset_x = offset[offset_x_idx];
|
||||
|
||||
let mask_value = match mask {
|
||||
Some(mask) => {
|
||||
let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;
|
||||
mask[batch * mask.stride(0)
|
||||
+ mask_pos_1 * mask.stride(1)
|
||||
+ out_y * mask.stride(2)
|
||||
+ out_x * mask.stride(3)]
|
||||
}
|
||||
None => F::new(1.0),
|
||||
};
|
||||
|
||||
let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)
|
||||
- args.pad_h.get::<F>()
|
||||
+ offset_y;
|
||||
let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)
|
||||
- args.pad_w.get::<F>()
|
||||
+ offset_x;
|
||||
|
||||
for dy in -1..=1i32 {
|
||||
#[unroll]
|
||||
for dx in -1..=1i32 {
|
||||
let yp = y.floor() + F::cast_from(dy);
|
||||
let xp = x.floor() + F::cast_from(dx);
|
||||
|
||||
if yp >= F::new(0.0)
|
||||
&& yp < F::cast_from(height)
|
||||
&& xp >= F::new(0.0)
|
||||
&& xp < F::cast_from(width)
|
||||
&& F::abs(y - yp) < F::new(1.0)
|
||||
&& F::abs(x - xp) < F::new(1.0)
|
||||
{
|
||||
let gradient_pos = batch * grad_input.stride(0)
|
||||
+ in_channel * grad_input.stride(1)
|
||||
+ usize::cast_from(yp) * grad_input.stride(2)
|
||||
+ usize::cast_from(xp) * grad_input.stride(3);
|
||||
|
||||
let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp));
|
||||
|
||||
let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];
|
||||
|
||||
FAdd::Op::<FP>::float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ProxyType<FADF, FP> = <<FADF as FloatAtomicAddFamily>::Op<FP> as FloatAtomicAdd>::ProxyType;
|
||||
|
||||
#[cube]
|
||||
trait FloatAtomicAddFamily: Send + Sync + 'static {
|
||||
type Op<ProxyType: Float>: FloatAtomicAdd;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
trait FloatAtomicAdd: Send + Sync + 'static {
|
||||
type ProxyType: Numeric;
|
||||
|
||||
fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F);
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
struct IntrinsicFloatAtomicAdd<F: Float> {
|
||||
#[cube(comptime)]
|
||||
_ty: PhantomData<F>,
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
struct CASFloatAtomicAdd;
|
||||
|
||||
struct IntrinsicFloatAtomicAddFamily;
|
||||
|
||||
impl FloatAtomicAddFamily for IntrinsicFloatAtomicAddFamily {
|
||||
type Op<ProxyType: Float> = IntrinsicFloatAtomicAdd<ProxyType>;
|
||||
}
|
||||
|
||||
impl FloatAtomicAddFamily for CASFloatAtomicAdd {
|
||||
type Op<ProxyType: Float> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<FAdd: Float> FloatAtomicAdd for IntrinsicFloatAtomicAdd<FAdd> {
|
||||
type ProxyType = FAdd;
|
||||
|
||||
fn float_atomic_add<F: Float>(ptr: &mut Atomic<FAdd>, value: F) {
|
||||
let value = FAdd::cast_from(value);
|
||||
ptr.fetch_add(value);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl FloatAtomicAdd for CASFloatAtomicAdd {
|
||||
type ProxyType = u32;
|
||||
|
||||
fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F) {
|
||||
let value = f32::cast_from(value);
|
||||
if value != 0.0 {
|
||||
let mut v = ptr.load();
|
||||
loop {
|
||||
let prev = v;
|
||||
let v_float = f32::from_bits(v);
|
||||
let new = (v_float + value).to_bits();
|
||||
v = ptr.compare_exchange_weak(v, new);
|
||||
if prev == v {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
into_contiguous_aligned,
|
||||
utils::{address_type, linear_view},
|
||||
},
|
||||
ops::max_line_size,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use crate::{kernel::utils::decompose_linear, ops::numeric::empty_device_dtype};
|
||||
use burn_backend::{
|
||||
TensorMetadata,
|
||||
ops::{ConvOptions, conv::calculate_conv_output_sizes},
|
||||
};
|
||||
use cubecl::std::{FastDivmod, FastDivmodArgs};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,
|
||||
tensor_line_size_parallel,
|
||||
};
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
#[derive(CubeLaunch, CubeType, Clone)]
|
||||
pub(crate) struct ConvParam {
|
||||
pub stride: u32,
|
||||
pub dilation: u32,
|
||||
pub padding: i32,
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct Conv2dArgs {
|
||||
conv_params: Sequence<ConvParam>,
|
||||
channels_per_group: u32,
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn direct_conv2d_kernel<E: Numeric>(
|
||||
input: &Tensor<Line<E>>,
|
||||
weight: &Tensor<Line<E>>,
|
||||
bias: Option<Tensor<Line<E>>>,
|
||||
output: &mut LinearView<Line<E>, ReadWrite>,
|
||||
args: Conv2dArgs,
|
||||
shape_out: Sequence<FastDivmod<u32>>,
|
||||
shape_out_c: FastDivmod<u32>,
|
||||
#[comptime] has_padding: bool,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let n_spatial = comptime![shape_out.len()];
|
||||
|
||||
let line_size_out = output.line_size();
|
||||
let pos = ABSOLUTE_POS * line_size_out;
|
||||
|
||||
let in_c_per_group = weight.shape(weight.rank() - 1) as u32;
|
||||
|
||||
let (rem, out_c) = shape_out_c.div_mod(pos as u32);
|
||||
let (b, spatial_pos) = decompose_linear(rem, &shape_out);
|
||||
|
||||
let g = out_c / args.channels_per_group;
|
||||
let ic_start = in_c_per_group * g;
|
||||
|
||||
let bias: Option<Line<E>> = bias.map(|bias| bias[out_c as usize / line_size_out]);
|
||||
let mut sum = bias.unwrap_or_else(|| Line::empty(line_size_out).fill(E::from_int(0)));
|
||||
|
||||
let in_offs = b as usize * input.stride(0) + ic_start as usize;
|
||||
|
||||
let stride_oc = weight.stride(0);
|
||||
|
||||
let mut in_shape = Sequence::new();
|
||||
let mut in_strides = Sequence::new();
|
||||
let mut kernel_shape = Sequence::new();
|
||||
let mut kernel_strides = Sequence::new();
|
||||
|
||||
#[unroll]
|
||||
for i in 0..n_spatial {
|
||||
in_shape.push(input.shape(i + 1) as u32);
|
||||
in_strides.push(input.stride(i + 1));
|
||||
kernel_shape.push(weight.shape(i + 1) as u32);
|
||||
kernel_strides.push(weight.stride(i + 1));
|
||||
}
|
||||
|
||||
let weight_offs = out_c as usize * stride_oc;
|
||||
|
||||
let loop_params = LoopParams {
|
||||
out_pos: spatial_pos,
|
||||
in_shape,
|
||||
in_strides,
|
||||
kernel_shape,
|
||||
kernel_strides,
|
||||
conv_params: args.conv_params,
|
||||
in_c_per_group,
|
||||
stride_oc,
|
||||
};
|
||||
|
||||
kernel_loop(
|
||||
input,
|
||||
weight,
|
||||
&mut sum,
|
||||
in_offs,
|
||||
true,
|
||||
weight_offs,
|
||||
&loop_params,
|
||||
0usize,
|
||||
has_padding,
|
||||
);
|
||||
|
||||
output[ABSOLUTE_POS] = sum;
|
||||
}
|
||||
|
||||
#[derive(CubeType, Clone)]
|
||||
struct LoopParams {
|
||||
out_pos: Sequence<u32>,
|
||||
in_shape: Sequence<u32>,
|
||||
in_strides: Sequence<usize>,
|
||||
kernel_shape: Sequence<u32>,
|
||||
kernel_strides: Sequence<usize>,
|
||||
conv_params: Sequence<ConvParam>,
|
||||
|
||||
in_c_per_group: u32,
|
||||
stride_oc: usize,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn kernel_loop<E: Numeric>(
|
||||
input: &Tensor<Line<E>>,
|
||||
weight: &Tensor<Line<E>>,
|
||||
sum: &mut Line<E>,
|
||||
in_offs: usize,
|
||||
in_bounds: bool,
|
||||
weight_offs: usize,
|
||||
params: &LoopParams,
|
||||
#[comptime] kernel_dim: usize,
|
||||
#[comptime] has_padding: bool,
|
||||
) {
|
||||
if comptime![kernel_dim < params.kernel_shape.len()] {
|
||||
let out_idx = *params.out_pos.index(kernel_dim);
|
||||
let conv = params.conv_params.index(kernel_dim);
|
||||
let shape = *params.in_shape.index(kernel_dim);
|
||||
let stride = *params.in_strides.index(kernel_dim);
|
||||
let k_stride = *params.kernel_strides.index(kernel_dim);
|
||||
|
||||
for pos in 0..*params.kernel_shape.index(kernel_dim) {
|
||||
let in_pos = (out_idx * conv.stride + pos * conv.dilation) as i32 - conv.padding;
|
||||
let in_offs = in_offs + in_pos as usize * stride;
|
||||
let weight_offs = weight_offs + pos as usize * k_stride;
|
||||
let mut in_bounds = in_bounds;
|
||||
|
||||
if has_padding {
|
||||
in_bounds &= in_pos >= 0 && (in_pos as u32) < shape;
|
||||
}
|
||||
|
||||
kernel_loop(
|
||||
input,
|
||||
weight,
|
||||
sum,
|
||||
in_offs,
|
||||
in_bounds,
|
||||
weight_offs,
|
||||
params,
|
||||
comptime![kernel_dim + 1],
|
||||
has_padding,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
kernel_loop_inner(
|
||||
input,
|
||||
weight,
|
||||
sum,
|
||||
in_offs,
|
||||
in_bounds,
|
||||
weight_offs,
|
||||
params.in_c_per_group,
|
||||
params.stride_oc,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn kernel_loop_inner<E: Numeric>(
|
||||
input: &Tensor<Line<E>>,
|
||||
weight: &Tensor<Line<E>>,
|
||||
sum: &mut Line<E>,
|
||||
in_offs: usize,
|
||||
in_bounds: bool,
|
||||
weight_offs: usize,
|
||||
in_c_per_group: u32,
|
||||
stride_oc: usize,
|
||||
) {
|
||||
let line_size_in = input.line_size();
|
||||
let line_size_out = sum.size();
|
||||
|
||||
if in_bounds {
|
||||
for in_c in range_stepped(0, in_c_per_group, line_size_in as u32) {
|
||||
let in_pos = in_offs + in_c as usize;
|
||||
let mut weight_pos = weight_offs + in_c as usize;
|
||||
|
||||
let val = input[in_pos / line_size_in];
|
||||
|
||||
#[unroll]
|
||||
for v in 0..line_size_out {
|
||||
let weight = weight[weight_pos / line_size_in];
|
||||
let val = val * weight;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..line_size_in {
|
||||
sum[v] += val[i];
|
||||
}
|
||||
weight_pos += stride_oc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a 2D convolution using the direct convolution algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
pub fn conv_direct<R: CubeRuntime, const N: usize>(
|
||||
mut input: CubeTensor<R>,
|
||||
mut weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let client = input.client.clone();
|
||||
let out_dtype = input.dtype;
|
||||
let rank = input.meta.shape().num_dims();
|
||||
let dim_c = rank - 1;
|
||||
|
||||
// We only care about the channels here, everything else can be permuted
|
||||
if input.meta.strides()[dim_c] != 1 {
|
||||
input = into_contiguous_aligned(input);
|
||||
}
|
||||
if weight.meta.strides()[dim_c] != 1 {
|
||||
weight = into_contiguous_aligned(weight);
|
||||
}
|
||||
|
||||
let batch_size = input.meta.shape()[0];
|
||||
let in_shape = &input.meta.shape()[1..dim_c];
|
||||
let out_channels = weight.meta.shape()[0];
|
||||
let kernel_shape = &weight.meta.shape()[1..dim_c];
|
||||
|
||||
let channels_per_group = out_channels / options.groups;
|
||||
|
||||
let out_size = calculate_conv_output_sizes(
|
||||
kernel_shape,
|
||||
&options.stride,
|
||||
&options.padding,
|
||||
&options.dilation,
|
||||
in_shape,
|
||||
);
|
||||
|
||||
let mut shape_out = vec![batch_size];
|
||||
shape_out.extend(out_size.iter().copied());
|
||||
shape_out.push(out_channels);
|
||||
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape_out.into(),
|
||||
out_dtype,
|
||||
);
|
||||
|
||||
// Need custom line size calculation here to account for the groups division. Need to vectorize
|
||||
// over `channels_per_group` instead.
|
||||
let mut grouped_out_shape = output.shape();
|
||||
grouped_out_shape[dim_c] = channels_per_group;
|
||||
let line_size_out = tensor_line_size_parallel(
|
||||
input.client.io_optimized_line_sizes(input.dtype.size()),
|
||||
&grouped_out_shape,
|
||||
output.meta.strides(),
|
||||
dim_c,
|
||||
);
|
||||
// Use channels_per_group instead of in_channels to avoid issues here
|
||||
let line_size_in = max_line_size(&weight);
|
||||
|
||||
let shape_out = output.meta.shape()[1..dim_c]
|
||||
.iter()
|
||||
.map(|s| FastDivmodArgs::<u32>::new(&client, *s as u32))
|
||||
.collect();
|
||||
let shape_out_c = FastDivmodArgs::<u32>::new(&client, out_channels as u32);
|
||||
|
||||
let mut conv_params = SequenceArg::new();
|
||||
|
||||
for i in 0..kernel_shape.len() {
|
||||
conv_params.push(ConvParamLaunch::new(
|
||||
ScalarArg::new(options.stride[i] as u32),
|
||||
ScalarArg::new(options.dilation[i] as u32),
|
||||
ScalarArg::new(options.padding[i] as i32),
|
||||
));
|
||||
}
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size_out;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
direct_conv2d_kernel::launch_unchecked(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, weight, bias, output),
|
||||
input.as_tensor_arg(line_size_in),
|
||||
weight.as_tensor_arg(line_size_in),
|
||||
bias.as_ref().map(|b| b.as_tensor_arg(line_size_out)).into(),
|
||||
linear_view(&output, line_size_out),
|
||||
Conv2dArgsLaunch::new(conv_params, ScalarArg::new(channels_per_group as u32)),
|
||||
shape_out,
|
||||
shape_out_c,
|
||||
options.padding.iter().any(|it| *it != 0),
|
||||
out_dtype.into(),
|
||||
)
|
||||
}?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
use burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes};
|
||||
use cubek::{
|
||||
convolution::{
|
||||
AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
|
||||
components::ConvSetupError, forward,
|
||||
},
|
||||
matmul::{
|
||||
definition::{MatmulElems, MatmulGlobalElems},
|
||||
launch::MatmulInputHandleRef,
|
||||
},
|
||||
};
|
||||
|
||||
/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul
|
||||
/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
pub fn conv_gemm_simple_sync<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let read_strategy = match tile_kind {
|
||||
AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,
|
||||
AcceleratedTileKind::Mma => ReadingStrategy::Strided,
|
||||
};
|
||||
launch_convolution_forward::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy,
|
||||
tile_kind,
|
||||
},
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn conv_gemm_simple_async<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
let read_strategy = match tile_kind {
|
||||
AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,
|
||||
AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,
|
||||
};
|
||||
launch_convolution_forward::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy,
|
||||
tile_kind,
|
||||
},
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul
|
||||
/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
pub fn conv_gemm_simple_tma<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
tile_kind: AcceleratedTileKind,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
launch_convolution_forward::<R, N>(
|
||||
&Strategy::Simple {
|
||||
read_strategy: ReadingStrategy::Tma,
|
||||
tile_kind,
|
||||
},
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul
|
||||
/// components, using the specified algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
pub fn launch_convolution_forward<R: CubeRuntime, const N: usize>(
|
||||
strategy: &Strategy,
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
if options.groups != 1 {
|
||||
return Err(ConvSetupError::Groups(options.groups));
|
||||
}
|
||||
|
||||
let out_dtype = input.dtype;
|
||||
let rank = input.meta.shape().num_dims();
|
||||
let batch_size = input.meta.shape()[0];
|
||||
let dim_c = rank - 1;
|
||||
let shape = &input.meta.shape()[1..dim_c];
|
||||
|
||||
let out_channels = weight.meta.shape()[0];
|
||||
let weight_shape = &weight.meta.shape()[1..dim_c];
|
||||
|
||||
let mut out_shape = calculate_conv_output_sizes(
|
||||
weight_shape,
|
||||
&options.stride,
|
||||
&options.padding,
|
||||
&options.dilation,
|
||||
shape,
|
||||
);
|
||||
|
||||
out_shape.insert(0, batch_size);
|
||||
out_shape.push(out_channels);
|
||||
|
||||
let out = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
out_shape.into(),
|
||||
out_dtype,
|
||||
);
|
||||
|
||||
let bias = bias
|
||||
.as_ref()
|
||||
.map(|bias| MatmulInputHandleRef::Normal(bias.as_handle_ref(), bias.dtype.into()));
|
||||
|
||||
let client = input.client.clone();
|
||||
let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
|
||||
lhs: input.dtype.into(),
|
||||
rhs: weight.dtype.into(),
|
||||
out: out_dtype.into(),
|
||||
});
|
||||
let input = MatmulInputHandleRef::new(input.as_handle_ref(), input.dtype.into());
|
||||
let weight = MatmulInputHandleRef::new(weight.as_handle_ref(), weight.dtype.into());
|
||||
|
||||
forward::launch_ref::<R, N>(
|
||||
strategy,
|
||||
&client,
|
||||
&input,
|
||||
&weight,
|
||||
&bias,
|
||||
&out.as_handle_ref(),
|
||||
ConvolutionArgs {
|
||||
stride: options.stride,
|
||||
padding: options.padding,
|
||||
dilation: options.dilation,
|
||||
},
|
||||
dtypes,
|
||||
)?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub mod launch;
|
||||
pub use launch::*;
|
||||
@@ -0,0 +1,7 @@
|
||||
pub mod implicit_gemm;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub mod tune;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) use tune::*;
|
||||
@@ -0,0 +1,174 @@
|
||||
use burn_backend::ops::ConvOptions;
|
||||
use cubecl::{
|
||||
ir::StorageType,
|
||||
tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},
|
||||
};
|
||||
use cubek::convolution::AcceleratedTileKind;
|
||||
|
||||
use crate::{
|
||||
CubeAutotuneKey, CubeRuntime, CubeTuneId,
|
||||
kernel::conv::{ConvAutotuneKey, conv_direct, conv_im2col_1x1, forward::implicit_gemm::*},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
/// Executes autotune on convolution operations
|
||||
pub fn conv_autotune<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
) -> CubeTensor<R> {
|
||||
let client = input.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
TunableSet::new(create_key::<R, N>, create_conv_input::<R, N>)
|
||||
.with(Tunable::new("conv_direct", conv_direct::<R, N>))
|
||||
.with(Tunable::new("conv_im2col_1x1", conv_im2col_1x1::<R, N>))
|
||||
.with(Tunable::new(
|
||||
"simple_sync_cmma",
|
||||
|input, weight, bias, options| {
|
||||
conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_sync_mma",
|
||||
|input, weight, bias, options| {
|
||||
conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_async_cmma",
|
||||
|input, weight, bias, options| {
|
||||
conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_async_mma",
|
||||
|input, weight, bias, options| {
|
||||
conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_tma_cmma",
|
||||
|input, weight, bias, options| {
|
||||
conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Cmma)
|
||||
},
|
||||
))
|
||||
.with(Tunable::new(
|
||||
"simple_tma_mma",
|
||||
|input, weight, bias, options| {
|
||||
conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Mma)
|
||||
},
|
||||
))
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&input.client, &input.device),
|
||||
&client,
|
||||
tunables,
|
||||
(input, weight, bias, options),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_conv_input<R: CubeRuntime, const N: usize>(
|
||||
_key: &CubeAutotuneKey,
|
||||
input: &CubeTensor<R>,
|
||||
weights: &CubeTensor<R>,
|
||||
bias: &Option<CubeTensor<R>>,
|
||||
options: &ConvOptions<N>,
|
||||
) -> (
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
Option<CubeTensor<R>>,
|
||||
ConvOptions<N>,
|
||||
) {
|
||||
(
|
||||
input.clone(),
|
||||
weights.clone(),
|
||||
bias.clone(),
|
||||
options.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_key<R: CubeRuntime, const N: usize>(
|
||||
input: &CubeTensor<R>,
|
||||
weights: &CubeTensor<R>,
|
||||
bias: &Option<CubeTensor<R>>,
|
||||
options: &ConvOptions<N>,
|
||||
) -> CubeAutotuneKey {
|
||||
let dtype = input.dtype;
|
||||
let rank = input.meta.shape().num_dims();
|
||||
let dim_c = rank - 1;
|
||||
|
||||
let batch_size = input.meta.shape()[0];
|
||||
let in_channels = input.meta.shape()[dim_c];
|
||||
let out_channels = weights.meta.shape()[0];
|
||||
|
||||
let kernel_size = weights.meta.shape()[1..dim_c].to_vec();
|
||||
let in_shape = input.meta.shape()[1..dim_c]
|
||||
.iter()
|
||||
.map(|shape| anchor(*shape, None, None, None))
|
||||
.collect();
|
||||
|
||||
let ConvOptions {
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
} = options.clone();
|
||||
|
||||
let lhs_stride_align = if input.meta.strides()[dim_c] == 1 {
|
||||
stride_align(input.meta.strides(), input.dtype.into())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let lhs_shape_align = pow2_factor(in_channels).min(lhs_stride_align);
|
||||
let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 {
|
||||
stride_align(weights.meta.strides(), weights.dtype.into())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);
|
||||
|
||||
CubeAutotuneKey::Conv(ConvAutotuneKey::new(
|
||||
kernel_size,
|
||||
stride.to_vec(),
|
||||
padding.to_vec(),
|
||||
dilation.to_vec(),
|
||||
groups,
|
||||
in_channels,
|
||||
out_channels,
|
||||
in_shape,
|
||||
batch_size,
|
||||
bias.is_some(),
|
||||
dtype,
|
||||
lhs_shape_align,
|
||||
lhs_stride_align,
|
||||
rhs_shape_align,
|
||||
rhs_stride_align,
|
||||
))
|
||||
}
|
||||
|
||||
/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's
|
||||
/// repeat number, so it's the largest align that can have performance impacts.
|
||||
const MAX_STRIDE_FACTOR: u32 = 10;
|
||||
|
||||
/// Defines the non-contiguous stride alignment in terms of powers of two
|
||||
fn stride_align(strides: &[usize], elem: StorageType) -> u8 {
|
||||
let max = MAX_STRIDE_FACTOR;
|
||||
let dim_c = strides.len() - 1;
|
||||
let factor = strides[..dim_c]
|
||||
.iter()
|
||||
.map(|it| (*it * elem.size_bits()) / 8)
|
||||
.map(|it| it.trailing_zeros())
|
||||
.min()
|
||||
.unwrap_or(max);
|
||||
factor.min(max) as u8
|
||||
}
|
||||
|
||||
/// Defines the potential vectorization.
|
||||
fn pow2_factor(axis: usize) -> u8 {
|
||||
axis.trailing_zeros().min(4) as u8
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
use burn_backend::{
|
||||
DType,
|
||||
ops::{ConvOptions, conv::calculate_conv_output_sizes},
|
||||
};
|
||||
use burn_std::{Metadata, Shape};
|
||||
use core::iter;
|
||||
use cubecl::{
|
||||
prelude::*,
|
||||
std::tensor::{TensorHandle, into_contiguous_pitched_ref},
|
||||
};
|
||||
use cubek::convolution::components::ConvSetupError;
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
AddOp, into_contiguous_aligned, launch_binop,
|
||||
matmul::{MatmulStrategy, matmul},
|
||||
utils::split_dim,
|
||||
},
|
||||
ops::{reshape, swap_dims},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) fn batches_per_run(
|
||||
batch_size: usize,
|
||||
out_shape: usize,
|
||||
plane_size: usize,
|
||||
) -> Result<usize, ConvSetupError> {
|
||||
use cubek::matmul::definition::MatmulAvailabilityError;
|
||||
|
||||
let cube_count_per_batch = out_shape.div_ceil(plane_size);
|
||||
let max_cube_count = u16::MAX as usize;
|
||||
let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size);
|
||||
if max_simultaneous == 0 {
|
||||
return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(
|
||||
cube_count_per_batch as u32,
|
||||
1,
|
||||
1,
|
||||
))
|
||||
.into());
|
||||
}
|
||||
Ok((0..=max_simultaneous)
|
||||
.rev()
|
||||
.find(|per_run| batch_size.is_multiple_of(*per_run))
|
||||
.expect("Logically not possible"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(unused)]
|
||||
pub(crate) fn batches_per_run(
|
||||
batch_size: usize,
|
||||
out_shape: usize,
|
||||
plane_size: usize,
|
||||
) -> Result<usize, ConvSetupError> {
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
pub fn conv_im2col_1x1<R: CubeRuntime, const N: usize>(
|
||||
input: CubeTensor<R>,
|
||||
mut weight: CubeTensor<R>,
|
||||
bias: Option<CubeTensor<R>>,
|
||||
options: ConvOptions<N>,
|
||||
) -> Result<CubeTensor<R>, ConvSetupError> {
|
||||
if options.groups != 1 {
|
||||
return Err(ConvSetupError::Groups(options.groups));
|
||||
}
|
||||
|
||||
let rank = input.meta.num_dims();
|
||||
let dim_c = rank - 1;
|
||||
|
||||
let batch_size = input.meta.shape()[0];
|
||||
let in_channels = input.meta.shape()[dim_c];
|
||||
let in_shape = &input.meta.shape()[1..dim_c];
|
||||
let out_channels = weight.meta.shape()[0];
|
||||
let kernel_shape = &weight.meta.shape()[1..dim_c];
|
||||
|
||||
if kernel_shape.iter().any(|s| *s != 1) {
|
||||
return Err(ConvSetupError::Unknown);
|
||||
}
|
||||
|
||||
let out_shape = calculate_conv_output_sizes(
|
||||
kernel_shape,
|
||||
&options.stride,
|
||||
&options.padding,
|
||||
&options.dilation,
|
||||
in_shape,
|
||||
);
|
||||
|
||||
let mut split_m = vec![batch_size];
|
||||
split_m.extend(out_shape.iter().copied());
|
||||
|
||||
if kernel_shape.iter().any(|it| *it != 1) || in_shape != out_shape {
|
||||
return Err(ConvSetupError::Unknown);
|
||||
}
|
||||
|
||||
let input = reshape_input(input); // [(NHW), C] : [M, K]
|
||||
let dtype = input.dtype;
|
||||
|
||||
// Efficient permutation that takes the stride required for TMA into account
|
||||
let weight = if weight.meta.strides()[dim_c] != 1 {
|
||||
// Remove kernel dims so padded dim is channels
|
||||
*weight.meta = Metadata::new(
|
||||
[out_channels, in_channels], // [N, K]
|
||||
[weight.meta.strides()[0], weight.meta.strides()[dim_c]],
|
||||
);
|
||||
// Pitched contiguous to skip running another kernel for TMA
|
||||
into_contiguous_aligned(weight)
|
||||
} else {
|
||||
// Already compatible, skip initial reshape
|
||||
*weight.meta = Metadata::new([out_channels, in_channels], [weight.meta.strides()[0], 1]);
|
||||
weight
|
||||
};
|
||||
|
||||
// Permute to N-major, while keeping memory layout K-major. K-major for both sides is the most
|
||||
// efficient for matmul, and allows skipping a contiguous kernel
|
||||
let weight = swap_dims(weight, 0, 1); // [K, N]
|
||||
|
||||
let out = matmul(input, weight, None, MatmulStrategy::default(), dtype)?; // [M, N]
|
||||
|
||||
// Skip reshape to avoid potential `into_contiguous`. We're only splitting dims so it's safe.
|
||||
let mut out = split_dim(out, 0, &split_m); // [N, H, W, C]
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let mut bias_shape = iter::repeat_n(1, rank - 1).collect::<Vec<_>>();
|
||||
bias_shape.push(out_channels);
|
||||
let bias = reshape(bias, bias_shape.into());
|
||||
out = launch_binop::<R, AddOp>(out, bias);
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Reshapes NHWC input to [(N, H, W), C]
|
||||
fn reshape_input<R: CubeRuntime>(mut input: CubeTensor<R>) -> CubeTensor<R> {
|
||||
let rank = input.meta.num_dims();
|
||||
let dim_c = rank - 1;
|
||||
let dtype = input.dtype;
|
||||
|
||||
let batch_size = input.meta.shape()[0];
|
||||
let in_c: usize = input.meta.shape()[dim_c];
|
||||
let in_shape: Shape = input.meta.shape()[1..dim_c].into();
|
||||
|
||||
if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) {
|
||||
let contiguous =
|
||||
into_contiguous_pitched_ref(&input.client, &input.as_handle_ref(), dtype.into())
|
||||
.expect("Kernel to never fail");
|
||||
input = from_handle(&input.client, &input.device, contiguous, dtype);
|
||||
}
|
||||
*input.meta = Metadata::new(
|
||||
[batch_size * in_shape.num_elements(), in_c], // [M, K]
|
||||
[input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]],
|
||||
);
|
||||
input
|
||||
}
|
||||
|
||||
fn is_spatial_contiguous(shape: &[usize], strides: &[usize]) -> bool {
|
||||
let rank = shape.len();
|
||||
|
||||
let mut ordered = strides.to_vec();
|
||||
ordered.sort();
|
||||
if ordered != strides {
|
||||
return false;
|
||||
}
|
||||
|
||||
for i in (1..rank - 2).rev() {
|
||||
if strides[i + 1] * shape[i + 1] != strides[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn from_handle<R: CubeRuntime>(
|
||||
client: &ComputeClient<R>,
|
||||
device: &R::Device,
|
||||
handle: TensorHandle<R>,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
CubeTensor::new(
|
||||
client.clone(),
|
||||
handle.handle,
|
||||
*handle.metadata,
|
||||
device.clone(),
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
mod backward_data;
|
||||
mod backward_weight;
|
||||
mod base;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
mod deform_conv_transpose2d;
|
||||
mod direct;
|
||||
mod forward;
|
||||
mod im2col;
|
||||
|
||||
mod tune_key;
|
||||
|
||||
pub(crate) use backward_data::*;
|
||||
pub(crate) use conv_transpose2d::*;
|
||||
pub(crate) use conv_transpose3d::*;
|
||||
pub(crate) use deform_conv_transpose2d::*;
|
||||
pub(crate) use deform_conv2d::*;
|
||||
pub(crate) use direct::*;
|
||||
pub(crate) use im2col::*;
|
||||
|
||||
pub use base::*;
|
||||
pub use conv_transpose2d::{ConvTranspose2dStrategy, conv_transpose2d};
|
||||
|
||||
pub(crate) use tune_key::*;
|
||||
@@ -0,0 +1,50 @@
|
||||
use burn_backend::DType;
|
||||
use cubecl::AutotuneKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
/// Autotune key representative of matmul versions
|
||||
pub struct ConvAutotuneKey {
|
||||
pub kernel_size: Vec<usize>,
|
||||
pub stride: Vec<usize>,
|
||||
pub padding: Vec<usize>,
|
||||
pub dilation: Vec<usize>,
|
||||
pub groups: usize,
|
||||
#[autotune(anchor)]
|
||||
pub in_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub out_channels: usize,
|
||||
pub shape: Vec<usize>,
|
||||
#[autotune(anchor)]
|
||||
pub batch_size: usize,
|
||||
pub has_bias: bool,
|
||||
pub dtype: DType,
|
||||
|
||||
pub lhs_shape_align: u8,
|
||||
pub lhs_stride_align: u8,
|
||||
pub rhs_shape_align: u8,
|
||||
pub rhs_stride_align: u8,
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
/// Autotune key representative of matmul versions
|
||||
pub struct ConvTranspose2dAutotuneKey {
|
||||
pub kernel_size: [usize; 2],
|
||||
pub stride: [usize; 2],
|
||||
pub padding: [usize; 2],
|
||||
pub padding_out: [usize; 2],
|
||||
pub dilation: [usize; 2],
|
||||
pub groups: usize,
|
||||
#[autotune(anchor)]
|
||||
pub in_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub out_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub height: usize,
|
||||
#[autotune(anchor)]
|
||||
pub width: usize,
|
||||
#[autotune(anchor)]
|
||||
pub batch_size: usize,
|
||||
pub has_bias: bool,
|
||||
pub dtype: DType,
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, broadcast_shape, linear_view, linear_view_ref},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use cubecl::std::tensor::layout::linear::LinearView;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn cross_kernel<E: Float>(
|
||||
lhs: &LinearView<Line<E>>,
|
||||
rhs: &LinearView<Line<E>>,
|
||||
output: &mut LinearView<Line<E>, ReadWrite>,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
// Each thread processes one 3-element vector
|
||||
let vector_idx = ABSOLUTE_POS;
|
||||
let base_pos = vector_idx * 3;
|
||||
|
||||
if !output.is_in_bounds(base_pos) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
// Extract vectors
|
||||
let a0 = lhs[base_pos];
|
||||
let a1 = lhs[base_pos + 1];
|
||||
let a2 = lhs[base_pos + 2];
|
||||
let b0 = rhs[base_pos];
|
||||
let b1 = rhs[base_pos + 1];
|
||||
let b2 = rhs[base_pos + 2];
|
||||
|
||||
// Compute cross product: a × b
|
||||
let x = a1 * b2 - a2 * b1;
|
||||
let y = a2 * b0 - a0 * b2;
|
||||
let z = a0 * b1 - a1 * b0;
|
||||
|
||||
// Store result
|
||||
output[base_pos] = x;
|
||||
output[base_pos + 1] = y;
|
||||
output[base_pos + 2] = z;
|
||||
}
|
||||
|
||||
pub(crate) fn cross<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
dim: usize,
|
||||
) -> CubeTensor<R> {
|
||||
let ndims = lhs.meta.num_dims();
|
||||
|
||||
// Validate that the cross dimension has size 3
|
||||
if lhs.meta.shape()[dim] != 3 || rhs.meta.shape()[dim] != 3 {
|
||||
panic!(
|
||||
"Cross product requires dimension {} to have size 3, but got {} and {}",
|
||||
dim,
|
||||
lhs.meta.shape()[dim],
|
||||
rhs.meta.shape()[dim]
|
||||
);
|
||||
}
|
||||
|
||||
// For now, only support cross on the last dimension
|
||||
if dim != ndims - 1 {
|
||||
unimplemented!(
|
||||
"Cross product on non-last dimension not yet implemented for CubeCL backend"
|
||||
);
|
||||
}
|
||||
|
||||
let output_shape = broadcast_shape(&[&lhs, &rhs]);
|
||||
|
||||
// Since the cross dimension is forced to be size 3, line size would be restricted to 1 anyway
|
||||
let line_size = 1;
|
||||
|
||||
let output = empty_device_dtype(
|
||||
lhs.client.clone(),
|
||||
lhs.device.clone(),
|
||||
output_shape.clone(),
|
||||
lhs.dtype,
|
||||
);
|
||||
|
||||
// Number of vectors to process
|
||||
let num_vectors = output_shape.num_elements() / 3;
|
||||
|
||||
let cube_dim = CubeDim::new(&lhs.client, num_vectors);
|
||||
let cube_count = calculate_cube_count_elemwise(&lhs.client, num_vectors, cube_dim);
|
||||
|
||||
unsafe {
|
||||
cross_kernel::launch_unchecked(
|
||||
&lhs.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(lhs, rhs, output),
|
||||
linear_view_ref(&lhs, &output, line_size),
|
||||
linear_view_ref(&rhs, &output, line_size),
|
||||
linear_view(&output, line_size),
|
||||
lhs.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
use cubecl::prelude::*;
|
||||
|
||||
use crate::{CubeRuntime, tensor::CubeTensor};
|
||||
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
|
||||
|
||||
use super::bilinear::grid_sample_bilinear_launch;
|
||||
|
||||
/// Grid sample operation supporting bilinear interpolation
|
||||
pub fn grid_sample<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
grid: CubeTensor<R>,
|
||||
options: GridSampleOptions,
|
||||
) -> CubeTensor<R> {
|
||||
match options.mode {
|
||||
InterpolateMode::Bilinear => grid_sample_bilinear_launch(input, grid, options),
|
||||
_ => panic!(
|
||||
"Unsupported grid_sample interpolation mode: {:?}",
|
||||
options.mode
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile-time padding mode for kernel specialization
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum PaddingMode {
|
||||
/// Fill with zeros for out-of-bounds coordinates.
|
||||
Zeros,
|
||||
/// Clamp coordinates to the border (use nearest edge value).
|
||||
Border,
|
||||
/// Reflect coordinates at the boundary.
|
||||
Reflection,
|
||||
}
|
||||
|
||||
impl From<GridSamplePaddingMode> for PaddingMode {
|
||||
fn from(mode: GridSamplePaddingMode) -> Self {
|
||||
match mode {
|
||||
GridSamplePaddingMode::Zeros => PaddingMode::Zeros,
|
||||
GridSamplePaddingMode::Border => PaddingMode::Border,
|
||||
GridSamplePaddingMode::Reflection => PaddingMode::Reflection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch value based on padding mode (dispatch to appropriate handler)
|
||||
#[cube]
|
||||
pub(crate) fn fetch_value<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
base: usize,
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
y: i32,
|
||||
x: i32,
|
||||
h: i32,
|
||||
w: i32,
|
||||
#[comptime] padding_mode: PaddingMode,
|
||||
) -> F {
|
||||
match padding_mode {
|
||||
PaddingMode::Zeros => fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w),
|
||||
PaddingMode::Border => fetch_with_border(input, base, stride_h, stride_w, y, x, h, w),
|
||||
PaddingMode::Reflection => {
|
||||
fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch value with zeros padding (return 0 for out-of-bounds).
|
||||
#[cube]
|
||||
pub(crate) fn fetch_with_zeros<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
base: usize,
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
y: i32,
|
||||
x: i32,
|
||||
h: i32,
|
||||
w: i32,
|
||||
) -> F {
|
||||
let in_bounds = x >= 0 && x < w && y >= 0 && y < h;
|
||||
let x_clamped = clamp(x, 0, w - 1) as usize;
|
||||
let y_clamped = clamp(y, 0, h - 1) as usize;
|
||||
let idx = base + y_clamped * stride_h + x_clamped * stride_w;
|
||||
select(in_bounds, input[idx], F::new(0.0))
|
||||
}
|
||||
|
||||
/// Fetch value with border padding (clamp to edge).
|
||||
#[cube]
|
||||
pub(crate) fn fetch_with_border<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
base: usize,
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
y: i32,
|
||||
x: i32,
|
||||
h: i32,
|
||||
w: i32,
|
||||
) -> F {
|
||||
let x_clamped = clamp(x, 0, w - 1) as usize;
|
||||
let y_clamped = clamp(y, 0, h - 1) as usize;
|
||||
let idx = base + y_clamped * stride_h + x_clamped * stride_w;
|
||||
input[idx]
|
||||
}
|
||||
|
||||
/// Fetch value with reflection padding.
|
||||
/// Assumes float reflection was applied to center, so indices are at most 2 steps out of bounds.
|
||||
#[cube]
|
||||
pub(crate) fn fetch_with_reflection<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
base: usize,
|
||||
stride_h: usize,
|
||||
stride_w: usize,
|
||||
y: i32,
|
||||
x: i32,
|
||||
h: i32,
|
||||
w: i32,
|
||||
) -> F {
|
||||
let x_reflected = reflect_coord_bounded(x, w);
|
||||
let y_reflected = reflect_coord_bounded(y, h);
|
||||
let idx = base + y_reflected * stride_h + x_reflected * stride_w;
|
||||
input[idx]
|
||||
}
|
||||
|
||||
/// Reflect an integer index that may be out of bounds.
|
||||
/// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear).
|
||||
#[cube]
|
||||
fn reflect_coord_bounded(idx: i32, size: i32) -> usize {
|
||||
let max_idx = size - 1;
|
||||
let neg_reflected = -idx - 1;
|
||||
let pos_reflected = 2 * max_idx + 1 - idx;
|
||||
let result = select(
|
||||
idx < 0,
|
||||
neg_reflected,
|
||||
select(idx > max_idx, pos_reflected, idx),
|
||||
);
|
||||
clamp(result, 0, max_idx) as usize
|
||||
}
|
||||
|
||||
/// Reflect a float coordinate into the valid sampling range.
|
||||
#[cube]
|
||||
pub(crate) fn reflect_coord<F: Float>(coord: F, size: u32, #[comptime] align_corners: bool) -> F {
|
||||
let size_f = F::cast_from(size);
|
||||
if align_corners {
|
||||
reflect_float_impl::<F>(coord, F::new(0.0), size_f - F::new(1.0))
|
||||
} else {
|
||||
reflect_float_impl::<F>(coord, F::new(-0.5), size_f - F::new(0.5))
|
||||
}
|
||||
}
|
||||
|
||||
/// Reflect a float coordinate into [min_val, max_val] using a triangle wave pattern.
|
||||
#[cube]
|
||||
fn reflect_float_impl<F: Float>(coord: F, min_val: F, max_val: F) -> F {
|
||||
let span = max_val - min_val;
|
||||
|
||||
let is_valid = span > F::new(0.0);
|
||||
let safe_span = select(is_valid, span, F::new(1.0));
|
||||
|
||||
// Triangle wave formula: span - |((x mod 2*span) - span)| + min_val
|
||||
let period = safe_span * F::new(2.0);
|
||||
let x = (coord - min_val).abs();
|
||||
let x_mod = x - (x / period).floor() * period;
|
||||
let reflected = safe_span - (x_mod - safe_span).abs() + min_val;
|
||||
|
||||
select(is_valid, reflected, min_val)
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
use cubecl::std::{FastDivmod, FastDivmodArgs};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime, kernel::utils::address_type, ops::numeric::empty_device_dtype, tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{Shape, ops::GridSampleOptions};
|
||||
|
||||
use super::base::{PaddingMode, fetch_value, reflect_coord};
|
||||
|
||||
/// Grid sample with bilinear interpolation.
|
||||
///
|
||||
/// Each thread processes all channels for one spatial output position:
|
||||
/// 1. Reading (x, y) coordinates from the grid tensor (once per spatial position)
|
||||
/// 2. Converting normalized [-1, 1] coords to pixel coordinates (once)
|
||||
/// 3. For each channel: fetch 4 corner values, interpolate, and write output
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn grid_sample_bilinear_kernel<F: Float>(
|
||||
input: &Tensor<F>, // [N, C, H_in, W_in]
|
||||
grid: &Tensor<F>, // [N, H_out, W_out, 2]
|
||||
output: &mut Tensor<F>, // [N, C, H_out, W_out]
|
||||
shape_spatial: Sequence<FastDivmod<usize>>, // [N, H_out, W_out] for thread decomposition
|
||||
#[comptime] align_corners: bool,
|
||||
#[comptime] pad_mode: PaddingMode,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
// Thread index maps to spatial position (n, h_out, w_out) only
|
||||
let spatial_idx = ABSOLUTE_POS;
|
||||
let num_spatial = output.shape(0) * output.shape(2) * output.shape(3);
|
||||
if spatial_idx >= num_spatial {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
// Decompose spatial index into (n, h_out, w_out)
|
||||
let (rem, w_out) = shape_spatial[2].div_mod(spatial_idx);
|
||||
let (n, h_out) = shape_spatial[1].div_mod(rem);
|
||||
|
||||
let channels = input.shape(1) as u32;
|
||||
let h_in = input.shape(2) as u32;
|
||||
let w_in = input.shape(3) as u32;
|
||||
|
||||
// Read grid coordinates once per spatial position
|
||||
let grid_offset = n * grid.stride(0) + h_out * grid.stride(1) + w_out * grid.stride(2);
|
||||
let gx = grid[grid_offset]; // x coordinate in [-1, 1]
|
||||
let gy = grid[grid_offset + 1]; // y coordinate in [-1, 1]
|
||||
|
||||
// Convert normalized coordinates to pixel coordinates
|
||||
let (px, py) = if align_corners {
|
||||
let px = (gx + F::new(1.0)) * F::cast_from((w_in - 1) as f32) / F::new(2.0);
|
||||
let py = (gy + F::new(1.0)) * F::cast_from((h_in - 1) as f32) / F::new(2.0);
|
||||
(px, py)
|
||||
} else {
|
||||
let px = (gx + F::new(1.0)) * F::cast_from(w_in as f32) / F::new(2.0) - F::new(0.5);
|
||||
let py = (gy + F::new(1.0)) * F::cast_from(h_in as f32) / F::new(2.0) - F::new(0.5);
|
||||
(px, py)
|
||||
};
|
||||
|
||||
// For reflection padding, reflect the coordinate into the valid sampling range.
|
||||
// This ensures integer indices are at most 1 step out of bounds.
|
||||
let (px, py) = if comptime!(pad_mode == PaddingMode::Reflection) {
|
||||
let px = reflect_coord::<F>(px, w_in, align_corners);
|
||||
let py = reflect_coord::<F>(py, h_in, align_corners);
|
||||
(px, py)
|
||||
} else {
|
||||
(px, py)
|
||||
};
|
||||
|
||||
// Compute floor and ceil indices
|
||||
let x0_f = px.floor();
|
||||
let y0_f = py.floor();
|
||||
let x1_f = x0_f + F::new(1.0);
|
||||
let y1_f = y0_f + F::new(1.0);
|
||||
|
||||
// Compute interpolation weights
|
||||
let wx = px - x0_f;
|
||||
let wy = py - y0_f;
|
||||
let wx_ = F::new(1.0) - wx;
|
||||
let wy_ = F::new(1.0) - wy;
|
||||
|
||||
// Convert to integers for indexing
|
||||
let x0 = i32::cast_from(x0_f);
|
||||
let y0 = i32::cast_from(y0_f);
|
||||
let x1 = i32::cast_from(x1_f);
|
||||
let y1 = i32::cast_from(y1_f);
|
||||
|
||||
let w_in = w_in as i32;
|
||||
let h_in = h_in as i32;
|
||||
|
||||
// Pre-compute strides
|
||||
let stride_n = input.stride(0);
|
||||
let stride_c = input.stride(1);
|
||||
let stride_h = input.stride(2);
|
||||
let stride_w = input.stride(3);
|
||||
let out_stride_n = output.stride(0);
|
||||
let out_stride_c = output.stride(1);
|
||||
let out_stride_h = output.stride(2);
|
||||
let out_stride_w = output.stride(3);
|
||||
|
||||
// Base offsets for this spatial position
|
||||
let in_base_n = n * stride_n;
|
||||
let out_base_spatial = n * out_stride_n + h_out * out_stride_h + w_out * out_stride_w;
|
||||
|
||||
// Loop over all channels - grid coords and weights are reused
|
||||
for c in 0..channels {
|
||||
let in_base = in_base_n + c as usize * stride_c;
|
||||
|
||||
let v00 = fetch_value(
|
||||
input, in_base, stride_h, stride_w, y0, x0, h_in, w_in, pad_mode,
|
||||
);
|
||||
let v01 = fetch_value(
|
||||
input, in_base, stride_h, stride_w, y1, x0, h_in, w_in, pad_mode,
|
||||
);
|
||||
let v10 = fetch_value(
|
||||
input, in_base, stride_h, stride_w, y0, x1, h_in, w_in, pad_mode,
|
||||
);
|
||||
let v11 = fetch_value(
|
||||
input, in_base, stride_h, stride_w, y1, x1, h_in, w_in, pad_mode,
|
||||
);
|
||||
|
||||
// Bilinear interpolation
|
||||
let result = wx_ * wy_ * v00 + wx_ * wy * v01 + wx * wy_ * v10 + wx * wy * v11;
|
||||
|
||||
let out_idx = out_base_spatial + c as usize * out_stride_c;
|
||||
output[out_idx] = result;
|
||||
}
|
||||
}
|
||||
|
||||
/// Launch the grid sample bilinear kernel
|
||||
pub(crate) fn grid_sample_bilinear_launch<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
grid: CubeTensor<R>,
|
||||
options: GridSampleOptions,
|
||||
) -> CubeTensor<R> {
|
||||
let [batch_size, channels, _h_in, _w_in] = input.meta.shape().dims();
|
||||
let [_n, h_out, w_out, two] = grid.meta.shape().dims();
|
||||
assert_eq!(two, 2, "Grid last dimension must be 2");
|
||||
|
||||
// Create output tensor [N, C, H_out, W_out]
|
||||
let output_shape = Shape::new([batch_size, channels, h_out, w_out]);
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
output_shape,
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
// Spatial threading: one thread per (n, h_out, w_out)
|
||||
let spatial_shape = Shape::new([batch_size, h_out, w_out]);
|
||||
let num_spatial = spatial_shape.num_elements();
|
||||
|
||||
let mut shape_spatial = SequenceArg::new();
|
||||
for dim in spatial_shape.iter() {
|
||||
shape_spatial.push(FastDivmodArgs::new(&input.client, *dim));
|
||||
}
|
||||
|
||||
let cube_dim = CubeDim::new(&input.client, num_spatial);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, num_spatial, cube_dim);
|
||||
|
||||
let padding_mode: PaddingMode = options.padding_mode.into();
|
||||
|
||||
grid_sample_bilinear_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, grid, output),
|
||||
input.as_tensor_arg(1),
|
||||
grid.as_tensor_arg(1),
|
||||
output.as_tensor_arg(1),
|
||||
shape_spatial,
|
||||
options.align_corners,
|
||||
padding_mode,
|
||||
input.dtype.into(),
|
||||
)
|
||||
.expect("Grid sample kernel failed");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
mod base;
|
||||
mod bilinear;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,99 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{DType, TensorMetadata};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::layout::linear::LinearView},
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn flip_kernel<E: Numeric, Bool: Int>(
|
||||
input: &Tensor<E>,
|
||||
output: &mut LinearView<E, ReadWrite>,
|
||||
in_shape: Sequence<FastDivmod<usize>>,
|
||||
indices: Sequence<InputScalar>,
|
||||
#[define(E, Bool)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = in_shape.len().comptime();
|
||||
|
||||
let mut offset = ABSOLUTE_POS;
|
||||
let mut offset_input = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let dim = rank - i - 1;
|
||||
let shape = input.shape(dim);
|
||||
|
||||
let (rem, offset_local) = in_shape[dim].div_mod(offset);
|
||||
offset = rem;
|
||||
|
||||
let flip = indices.index(dim).get::<Bool>() == Bool::from_int(1);
|
||||
let offset_local = select(flip, shape - offset_local - 1, offset_local);
|
||||
|
||||
offset_input += offset_local * input.stride(dim);
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = input[offset_input];
|
||||
}
|
||||
|
||||
pub(crate) fn flip<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
indices: &[usize],
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
tensor.dtype,
|
||||
);
|
||||
flip_on_output(tensor, output, indices, dtype_bool)
|
||||
}
|
||||
|
||||
pub(crate) fn flip_on_output<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
indices: &[usize],
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let dtype_input = tensor.dtype;
|
||||
let ndims = tensor.meta.num_dims();
|
||||
let mut indices_sequence = SequenceArg::<'_, R, InputScalar>::new();
|
||||
|
||||
for i in 0..ndims {
|
||||
indices_sequence.push({
|
||||
let val = indices.contains(&i) as u8;
|
||||
InputScalar::new(val, dtype_bool)
|
||||
});
|
||||
}
|
||||
|
||||
let num_elements = output.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&tensor.client, num_elements);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, num_elements, cube_dim);
|
||||
|
||||
unsafe {
|
||||
flip_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&output, 1),
|
||||
shape_divmod(&tensor),
|
||||
indices_sequence,
|
||||
[dtype_input.into(), dtype_bool.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, broadcast_strides, linear_view, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::frontend::{ABSOLUTE_POS, Numeric, Tensor};
|
||||
use cubecl::std::{FastDivmod, tensor::index_offset_contiguous_fastdivmod};
|
||||
use cubecl::{CubeDim, std::tensor::layout::linear::LinearView};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn gather_kernel<T: Numeric, I: Numeric>(
|
||||
input: &Tensor<Line<T>>,
|
||||
indices: &LinearView<Line<I>>,
|
||||
output: &mut LinearView<Line<T>, ReadWrite>,
|
||||
in_strides: Sequence<usize>, // zeroed out for broadcast dims and `dim`
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
dim: usize,
|
||||
#[define(T, I)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if !indices.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let mut offset = index_offset_contiguous_fastdivmod(
|
||||
ABSOLUTE_POS,
|
||||
&out_shape,
|
||||
&in_strides,
|
||||
input.line_size(),
|
||||
);
|
||||
|
||||
offset += usize::cast_from(indices[ABSOLUTE_POS]) * input.stride(dim);
|
||||
|
||||
output[ABSOLUTE_POS] = input[offset];
|
||||
}
|
||||
|
||||
pub(crate) fn gather<R: CubeRuntime>(
|
||||
dim: usize,
|
||||
tensor: CubeTensor<R>,
|
||||
indices: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let shape_output = indices.shape();
|
||||
let total_elem = shape_output.num_elements();
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
shape_output,
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
let cube_dim = CubeDim::new(&tensor.client, total_elem);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, total_elem, cube_dim);
|
||||
let mut in_strides = broadcast_strides(&output, &tensor);
|
||||
in_strides.values[dim] = ScalarArg::new(0); // Zero `dim` to exclude it from the indexing
|
||||
|
||||
unsafe {
|
||||
gather_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, indices, output),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&indices, 1),
|
||||
linear_view(&output, 1),
|
||||
in_strides,
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(dim),
|
||||
[tensor.dtype.into(), indices.dtype.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
mod flip;
|
||||
mod gather;
|
||||
mod repeat_dim;
|
||||
mod scatter;
|
||||
mod select;
|
||||
mod select_assign;
|
||||
mod slice;
|
||||
mod slice_assign;
|
||||
|
||||
pub(crate) use flip::*;
|
||||
pub(crate) use repeat_dim::*;
|
||||
pub(crate) use select::*;
|
||||
pub(crate) use select_assign::*;
|
||||
pub use slice::*;
|
||||
pub(crate) use slice_assign::*;
|
||||
|
||||
pub(crate) use gather::*;
|
||||
pub(crate) use scatter::*;
|
||||
@@ -0,0 +1,93 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, FastDivmodArgs},
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn repeat_dim_kernel<E: Numeric>(
|
||||
input: &Tensor<E>,
|
||||
output: &mut Tensor<E>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
in_shape: FastDivmod<usize>,
|
||||
#[comptime] dim: usize,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = out_shape.len().comptime();
|
||||
|
||||
let mut pos = ABSOLUTE_POS;
|
||||
let mut offset_input = 0;
|
||||
let mut offset_output = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let i = rank - i - 1;
|
||||
|
||||
let (rem, mut local_pos) = out_shape[i].div_mod(pos);
|
||||
pos = rem;
|
||||
|
||||
offset_output += local_pos * output.stride(i);
|
||||
|
||||
if i == dim {
|
||||
local_pos = in_shape.modulo(local_pos);
|
||||
}
|
||||
|
||||
offset_input += local_pos * input.stride(i);
|
||||
}
|
||||
|
||||
output[offset_output] = input[offset_input];
|
||||
}
|
||||
|
||||
pub(crate) fn repeat_dim<R: CubeRuntime>(
|
||||
mut input: CubeTensor<R>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> CubeTensor<R> {
|
||||
if input.meta.shape()[dim] == 1 {
|
||||
input.meta.strides[dim] = 0;
|
||||
input.meta.shape = input.meta.shape.repeat(dim, times).unwrap();
|
||||
return input;
|
||||
}
|
||||
|
||||
let shape = input.meta.shape.clone().repeat(dim, times).unwrap();
|
||||
|
||||
// Create output handle
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape,
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let working_units = output.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
repeat_dim_kernel::launch_unchecked(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, output),
|
||||
input.as_tensor_arg(1),
|
||||
output.as_tensor_arg(1),
|
||||
shape_divmod(&output),
|
||||
FastDivmodArgs::new(&input.client, input.meta.shape()[dim]),
|
||||
dim,
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
AddOp, BinaryOp, BinaryOpFamily, OrOp,
|
||||
utils::{address_type, shape_divmod},
|
||||
},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use cubecl::{CubeDim, calculate_cube_count_elemwise};
|
||||
use cubecl::{prelude::*, std::FastDivmod};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn scatter_kernel<T: Numeric, I: Int, Op: BinaryOpFamily>(
|
||||
input: &mut Tensor<T>,
|
||||
indices: &Tensor<I>,
|
||||
value: &Tensor<T>,
|
||||
in_shape: Sequence<FastDivmod<usize>>,
|
||||
#[comptime] dim: usize,
|
||||
#[define(T, I)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
let rank = in_shape.len().comptime();
|
||||
let stride_input = input.stride(dim);
|
||||
let stride_value = value.stride(dim);
|
||||
let stride_indices = indices.stride(dim);
|
||||
let shape_value = value.shape(dim);
|
||||
|
||||
let mut offset = ABSOLUTE_POS;
|
||||
let mut offset_input = 0;
|
||||
let mut offset_indices = 0;
|
||||
let mut offset_value = 0;
|
||||
let mut num_elems = 1;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let i = rank - i - 1;
|
||||
if i != dim {
|
||||
let shape_input_loop = input.shape(i);
|
||||
|
||||
let (rem, local_pos) = in_shape[i].div_mod(offset);
|
||||
offset = rem;
|
||||
|
||||
offset_input += local_pos * input.stride(i);
|
||||
offset_indices += local_pos * indices.stride(i);
|
||||
offset_value += local_pos * value.stride(i);
|
||||
|
||||
num_elems *= shape_input_loop;
|
||||
}
|
||||
}
|
||||
|
||||
let should_stop = ABSOLUTE_POS >= num_elems;
|
||||
if should_stop {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
for i in 0..shape_value {
|
||||
let value_idx = (stride_value * i) + offset_value;
|
||||
let index_idx = (stride_indices * i) + offset_indices;
|
||||
|
||||
let value = value[value_idx];
|
||||
let index = usize::cast_from(indices[index_idx]);
|
||||
|
||||
let input_idx = (stride_input * index) + offset_input;
|
||||
|
||||
let value =
|
||||
Op::BinaryOp::<T>::execute(Line::cast_from(input[input_idx]), Line::cast_from(value));
|
||||
input[input_idx] = value[0];
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scatter<R: CubeRuntime>(
|
||||
dim: usize,
|
||||
tensor: CubeTensor<R>,
|
||||
indices: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
is_bool: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim];
|
||||
|
||||
let working_units = num_elems;
|
||||
let cube_dim = CubeDim::new(&indices.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
|
||||
|
||||
let launch = match is_bool {
|
||||
true => scatter_kernel::launch_unchecked::<OrOp, R>,
|
||||
false => scatter_kernel::launch_unchecked::<AddOp, R>,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
launch(
|
||||
&indices.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, indices, value),
|
||||
tensor.as_tensor_arg(1),
|
||||
indices.as_tensor_arg(1),
|
||||
value.as_tensor_arg(1),
|
||||
shape_divmod(&tensor),
|
||||
dim,
|
||||
[tensor.dtype.into(), indices.dtype.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
tensor
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use crate::{CubeRuntime, kernel::utils::address_type, tensor::CubeTensor};
|
||||
use crate::{
|
||||
kernel::utils::{linear_view, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView};
|
||||
use cubecl::{prelude::*, std::FastDivmod};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn select_kernel<T: Numeric, I: Numeric>(
|
||||
input: &Tensor<T>,
|
||||
indices: &LinearView<I>,
|
||||
output: &mut LinearView<T, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
dim: usize,
|
||||
#[define(T, I)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.shape() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = out_shape.len().comptime();
|
||||
|
||||
let mut offset = ABSOLUTE_POS;
|
||||
let mut offset_input = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let i = rank - i - 1;
|
||||
let (rem, offset_local) = out_shape[i].div_mod(offset);
|
||||
offset = rem;
|
||||
|
||||
let offset_local = cubecl::prelude::select(
|
||||
i == dim,
|
||||
usize::cast_from(indices[offset_local]),
|
||||
offset_local,
|
||||
);
|
||||
|
||||
offset_input += offset_local * input.stride(i);
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = input[offset_input];
|
||||
}
|
||||
|
||||
pub(crate) fn select<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
dim: usize,
|
||||
indices: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let mut shape_output = tensor.shape();
|
||||
shape_output[dim] = indices.meta.shape()[0];
|
||||
let total_elem = shape_output.num_elements();
|
||||
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
shape_output,
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
let working_units = total_elem;
|
||||
let cube_dim = CubeDim::new(&indices.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
select_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, indices, output),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&indices, 1),
|
||||
linear_view(&output, 1),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(dim),
|
||||
[tensor.dtype.into(), indices.dtype.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use crate::kernel::{
|
||||
AddOp, BinaryOp, BinaryOpFamily, OrOp,
|
||||
utils::{address_type, linear_view, shape_divmod},
|
||||
};
|
||||
use crate::{CubeRuntime, tensor::CubeTensor};
|
||||
use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView};
|
||||
use cubecl::{prelude::*, std::FastDivmod};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn select_assign_kernel<F: Numeric, I: Numeric, Op: BinaryOpFamily>(
|
||||
tensor: &mut Tensor<F>,
|
||||
indices: &LinearView<I>,
|
||||
value: &Tensor<F>,
|
||||
value_shape: Sequence<FastDivmod<usize>>,
|
||||
num_elems: usize,
|
||||
#[comptime] dim: usize,
|
||||
#[define(F, I)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if ABSOLUTE_POS >= num_elems {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = value_shape.len().comptime();
|
||||
|
||||
let mut offset = ABSOLUTE_POS;
|
||||
let mut offset_tensor = 0;
|
||||
let mut offset_value = 0;
|
||||
|
||||
// Calculate offsets and num_elems
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let i = rank - i - 1;
|
||||
if i != dim {
|
||||
let (rem, local_pos) = value_shape[i].div_mod(offset);
|
||||
offset = rem;
|
||||
|
||||
offset_tensor += local_pos * tensor.stride(i);
|
||||
offset_value += local_pos * value.stride(i);
|
||||
}
|
||||
}
|
||||
|
||||
let strides_tensor_dim = tensor.stride(dim);
|
||||
let strides_value_dim = value.stride(dim);
|
||||
|
||||
// Main operation
|
||||
for i in 0..value.shape(dim) {
|
||||
let index_tensor = usize::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
|
||||
let index_value = i * strides_value_dim + offset_value;
|
||||
|
||||
let value = Op::BinaryOp::<F>::execute(
|
||||
Line::cast_from(tensor[index_tensor]),
|
||||
Line::cast_from(value[index_value]),
|
||||
);
|
||||
tensor[index_tensor] = F::cast_from(value);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn select_assign<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
dim: usize,
|
||||
indices: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
is_bool: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim];
|
||||
let working_units = num_elems;
|
||||
let cube_dim = CubeDim::new(&indices.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
|
||||
|
||||
let launch = match is_bool {
|
||||
true => select_assign_kernel::launch_unchecked::<OrOp, R>,
|
||||
false => select_assign_kernel::launch_unchecked::<AddOp, R>,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
launch(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, indices, value),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&indices, 1),
|
||||
value.as_tensor_arg(1),
|
||||
shape_divmod(&value),
|
||||
ScalarArg::new(num_elems),
|
||||
dim,
|
||||
[tensor.dtype.into(), indices.dtype.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
tensor
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, shape_divmod},
|
||||
ops::numeric::empty_device_dtype,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{Slice, TensorMetadata};
|
||||
use burn_std::{Metadata, SliceOps};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise, intrinsic,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::layout::linear::LinearView},
|
||||
};
|
||||
use std::ops::Range;
|
||||
|
||||
/// Slice a jit tensor with a set of ranges
|
||||
pub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {
|
||||
let mut dims = tensor.shape();
|
||||
let mut offset_start = 0u64;
|
||||
let mut offset_end = 0u64;
|
||||
|
||||
for i in 0..indices.len() {
|
||||
offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64;
|
||||
offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64;
|
||||
dims[i] = indices[i].end - indices[i].start;
|
||||
}
|
||||
|
||||
let offset_start = offset_start * tensor.dtype.size() as u64;
|
||||
let offset_end = offset_end * tensor.dtype.size() as u64;
|
||||
|
||||
let memory_offset_alignment = tensor.client.properties().memory.alignment;
|
||||
|
||||
if offset_start.is_multiple_of(memory_offset_alignment)
|
||||
&& offset_end.is_multiple_of(memory_offset_alignment)
|
||||
{
|
||||
CubeTensor::new(
|
||||
tensor.client,
|
||||
tensor
|
||||
.handle
|
||||
.offset_start(offset_start)
|
||||
.offset_end(offset_end),
|
||||
Metadata::new(dims, tensor.meta.strides),
|
||||
tensor.device,
|
||||
tensor.dtype,
|
||||
)
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
dims,
|
||||
tensor.dtype,
|
||||
);
|
||||
slice_on_output(tensor, output, indices)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn slice_kernel<E: Numeric>(
|
||||
input: &Tensor<E>,
|
||||
output: &mut LinearView<E, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
indices: Sequence<usize>,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = comptime![out_shape.len()];
|
||||
let mut offset_output = ABSOLUTE_POS;
|
||||
let mut offset_input = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
// Iterate in reverse to use divmod
|
||||
let dim = rank - i - 1;
|
||||
|
||||
let range_start = indices[dim];
|
||||
let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
|
||||
offset_output = rem;
|
||||
|
||||
let offset_local = offset_local + range_start;
|
||||
|
||||
offset_input += offset_local * input.stride(dim);
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = input[offset_input];
|
||||
}
|
||||
|
||||
pub(crate) fn slice_on_output<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
indices: &[Range<usize>],
|
||||
) -> CubeTensor<R> {
|
||||
let ndims = tensor.meta.num_dims();
|
||||
let mut indices_sequence = SequenceArg::<R, usize>::new();
|
||||
|
||||
for i in 0..ndims {
|
||||
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
|
||||
indices_sequence.push(ScalarArg::new(start));
|
||||
}
|
||||
|
||||
let working_units = output.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
slice_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&output, 1),
|
||||
shape_divmod(&output),
|
||||
indices_sequence,
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Kernel for slicing with steps
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn slice_with_steps_kernel<E: Numeric>(
|
||||
input: &Tensor<E>,
|
||||
output: &mut LinearView<E, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
starts: Sequence<usize>,
|
||||
ends: Sequence<usize>,
|
||||
steps: Sequence<i32>,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = comptime![out_shape.len()];
|
||||
let mut output_offset = ABSOLUTE_POS;
|
||||
let mut input_offset = 0;
|
||||
|
||||
// Calculate the input offset based on output position and slice info
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
// Iterate in reverse to use divmod
|
||||
let dim = rank - i - 1;
|
||||
let start = starts[dim];
|
||||
let end = ends[dim];
|
||||
let step = steps[dim];
|
||||
|
||||
let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
|
||||
output_offset = rem;
|
||||
|
||||
let input_idx = if step > 0 {
|
||||
// Forward stepping
|
||||
start + output_idx * (step as usize)
|
||||
} else {
|
||||
// Backward stepping - start from end-1
|
||||
let abs_step = (-step) as usize;
|
||||
let end_minus_1 = end - 1;
|
||||
end_minus_1 - output_idx * abs_step
|
||||
};
|
||||
|
||||
input_offset += input_idx * input.stride(dim);
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = input[input_offset];
|
||||
}
|
||||
|
||||
/// Slice a tensor with steps
|
||||
pub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {
|
||||
// Check if all steps are 1 - if so, use the optimized regular slice
|
||||
let all_steps_one = slices.iter().all(|info| info.step == 1);
|
||||
|
||||
if all_steps_one {
|
||||
// Convert Slice to Range for step=1
|
||||
let simple_ranges: Vec<Range<usize>> = slices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
|
||||
.collect();
|
||||
return slice(tensor, &simple_ranges);
|
||||
}
|
||||
|
||||
// Calculate output shape
|
||||
let shape_output = tensor.shape().slice(slices).unwrap();
|
||||
|
||||
// Create output tensor
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
shape_output.clone(),
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
// Prepare three separate sequences for kernel
|
||||
let mut starts = SequenceArg::<R, usize>::new();
|
||||
let mut ends = SequenceArg::<R, usize>::new();
|
||||
let mut steps = SequenceArg::<R, i32>::new();
|
||||
|
||||
for (dim, slice) in slices.iter().enumerate() {
|
||||
let range = slice.to_range(tensor.meta.shape()[dim]);
|
||||
starts.push(ScalarArg::new(range.start));
|
||||
ends.push(ScalarArg::new(range.end));
|
||||
steps.push(ScalarArg::new(slice.step as i32));
|
||||
}
|
||||
|
||||
// Pad with default values if needed to match tensor dimensions
|
||||
for dim in slices.len()..tensor.meta.num_dims() {
|
||||
starts.push(ScalarArg::new(0));
|
||||
ends.push(ScalarArg::new(tensor.meta.shape()[dim]));
|
||||
steps.push(ScalarArg::new(1));
|
||||
}
|
||||
|
||||
// Launch kernel
|
||||
let working_units = shape_output.num_elements();
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
slice_with_steps_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&output, 1),
|
||||
shape_divmod(&output),
|
||||
starts,
|
||||
ends,
|
||||
steps,
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// This is annoying and we need to find a way to do this automatically at some point
|
||||
#[allow(unused)]
|
||||
#[cube]
|
||||
fn unwrap(value: u32) -> comptime_type!(u32) {
|
||||
intrinsic!(|_| value.constant().unwrap().as_u32())
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, shape_divmod},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise, intrinsic,
|
||||
prelude::*,
|
||||
std::{FastDivmod, FastDivmodArgs, tensor::layout::linear::LinearView},
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn slice_assign_kernel<E: Numeric>(
|
||||
input: &mut Tensor<Line<E>>,
|
||||
value: &LinearView<Line<E>>,
|
||||
slice_shape: Sequence<FastDivmod<usize>>,
|
||||
slice_offsets: Sequence<usize>,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if !value.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!()
|
||||
}
|
||||
|
||||
let rank = comptime!(slice_shape.len());
|
||||
|
||||
let line_size = input.line_size();
|
||||
let mut offset_remainder = ABSOLUTE_POS * line_size;
|
||||
let mut offset_input = 0;
|
||||
|
||||
#[allow(clippy::explicit_counter_loop)]
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let dim = rank - i - 1;
|
||||
let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);
|
||||
|
||||
let range_start = slice_offsets[dim];
|
||||
let offset_local_input = offset_local + range_start;
|
||||
|
||||
offset_input += offset_local_input * input.stride(dim);
|
||||
offset_remainder = rem;
|
||||
}
|
||||
|
||||
// Value tensor is accessed linearly since it's a LinearView
|
||||
input[offset_input / line_size] = value[ABSOLUTE_POS];
|
||||
}
|
||||
|
||||
/// Kernel for slice assign with steps
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn slice_assign_with_steps_kernel<E: Numeric>(
|
||||
input: &mut Tensor<E>,
|
||||
value: &LinearView<E>,
|
||||
value_shape: Sequence<FastDivmod<usize>>,
|
||||
starts: Sequence<usize>,
|
||||
ends: Sequence<usize>,
|
||||
steps: Sequence<i32>,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if !value.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let rank = comptime![value_shape.len()];
|
||||
let mut value_offset = ABSOLUTE_POS;
|
||||
let mut input_offset = 0;
|
||||
|
||||
// Calculate the input offset based on value position and slice info
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
// Iterate in reverse to use divmod
|
||||
let dim = rank - i - 1;
|
||||
let start = starts[dim];
|
||||
let end = ends[dim];
|
||||
let step = steps[dim];
|
||||
|
||||
let (rem, value_idx) = value_shape[dim].div_mod(value_offset);
|
||||
value_offset = rem;
|
||||
|
||||
let input_idx = if step > 0 {
|
||||
// Forward stepping
|
||||
start + value_idx * (step as usize)
|
||||
} else if step < 0 {
|
||||
// Backward stepping - start from end-1
|
||||
// For negative steps, we iterate backwards through the selected indices
|
||||
let abs_step = (-step) as usize;
|
||||
let end_minus_1 = end - 1;
|
||||
end_minus_1 - value_idx * abs_step
|
||||
} else {
|
||||
// step == 0, shouldn't happen
|
||||
value_idx
|
||||
};
|
||||
|
||||
input_offset += input_idx * input.stride(dim);
|
||||
}
|
||||
|
||||
input[input_offset] = value[ABSOLUTE_POS];
|
||||
}
|
||||
|
||||
pub(crate) fn slice_assign<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
indices: &[burn_backend::Slice],
|
||||
value: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
// Check if any slice has non-unit step
|
||||
let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0);
|
||||
|
||||
if has_non_unit_step {
|
||||
// Use slice_assign_with_steps
|
||||
return slice_assign_with_steps(tensor, indices, value);
|
||||
}
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
let ndims = tensor.meta.num_dims();
|
||||
|
||||
let line_size = if tensor.meta.strides()[ndims - 1] == 1 && value.meta.strides()[ndims - 1] == 1
|
||||
{
|
||||
let last = indices
|
||||
.get(ndims - 1)
|
||||
.cloned()
|
||||
.unwrap_or(burn_backend::Slice {
|
||||
start: 0,
|
||||
end: Some(tensor.meta.shape()[ndims - 1] as isize),
|
||||
step: 1,
|
||||
});
|
||||
let end = last.end.unwrap_or(tensor.meta.shape()[ndims - 1] as isize);
|
||||
let shape = (end - last.start) as usize;
|
||||
let offset = last.start as usize;
|
||||
client
|
||||
.io_optimized_line_sizes(tensor.dtype.size())
|
||||
.filter(|&it| {
|
||||
shape.is_multiple_of(it)
|
||||
&& strides_compatible(tensor.meta.strides(), it)
|
||||
&& strides_compatible(value.meta.strides(), it)
|
||||
&& offset.is_multiple_of(it)
|
||||
})
|
||||
.max()
|
||||
.unwrap_or(1)
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();
|
||||
let mut offsets = SequenceArg::<R, usize>::new();
|
||||
|
||||
for i in 0..ndims {
|
||||
let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {
|
||||
start: 0,
|
||||
end: Some(tensor.meta.shape()[i] as isize),
|
||||
step: 1,
|
||||
});
|
||||
let start = slice.start as usize;
|
||||
let end = slice.end.unwrap_or(tensor.meta.shape()[i] as isize);
|
||||
let length = (end - slice.start) as usize;
|
||||
|
||||
shape.push(FastDivmodArgs::<usize>::new(&client, length));
|
||||
offsets.push(ScalarArg::new(start));
|
||||
}
|
||||
|
||||
let working_units = value.meta.num_elements() / line_size;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
slice_assign_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, value),
|
||||
tensor.as_tensor_arg(line_size),
|
||||
linear_view(&value, line_size),
|
||||
shape,
|
||||
offsets,
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Slice assign with steps support
|
||||
///
|
||||
/// This function handles slice assignment with arbitrary step values, including negative steps.
|
||||
/// It follows NumPy/PyTorch semantics where values[i] is assigned to selected_indices[i].
|
||||
///
|
||||
/// For example, with s![0..6;-1] which selects indices [5,4,3,2,1,0]:
|
||||
/// - values[0] goes to index 5
|
||||
/// - values[1] goes to index 4
|
||||
/// - etc.
|
||||
pub(crate) fn slice_assign_with_steps<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
// Prepare sequences for kernel
|
||||
let mut starts = SequenceArg::<R, usize>::new();
|
||||
let mut ends = SequenceArg::<R, usize>::new();
|
||||
let mut steps = SequenceArg::<R, i32>::new();
|
||||
|
||||
for (dim, slice) in slices.iter().enumerate() {
|
||||
let range = slice.to_range(tensor.meta.shape()[dim]);
|
||||
starts.push(ScalarArg::new(range.start));
|
||||
ends.push(ScalarArg::new(range.end));
|
||||
steps.push(ScalarArg::new(slice.step as i32));
|
||||
}
|
||||
|
||||
// Pad with default values if needed to match tensor dimensions
|
||||
for dim in slices.len()..tensor.meta.num_dims() {
|
||||
starts.push(ScalarArg::new(0));
|
||||
ends.push(ScalarArg::new(tensor.meta.shape()[dim]));
|
||||
steps.push(ScalarArg::new(1));
|
||||
}
|
||||
|
||||
// Launch kernel
|
||||
let working_units = value.meta.num_elements();
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
slice_assign_with_steps_kernel::launch_unchecked(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, value),
|
||||
tensor.as_tensor_arg(1),
|
||||
linear_view(&value, 1),
|
||||
shape_divmod(&value),
|
||||
starts,
|
||||
ends,
|
||||
steps,
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
fn strides_compatible(strides: &[usize], vec: usize) -> bool {
|
||||
strides
|
||||
.iter()
|
||||
.all(|stride| *stride % vec == 0 || *stride == 1)
|
||||
}
|
||||
|
||||
/// Helper function for unwrap
|
||||
#[allow(unused)]
|
||||
#[cube]
|
||||
fn unwrap(value: u32) -> comptime_type!(u32) {
|
||||
intrinsic!(|_| value.constant().unwrap().as_u32())
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::into_contiguous,
|
||||
ops::{numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{
|
||||
Shape, TensorMetadata,
|
||||
ops::{InterpolateMode, InterpolateOptions},
|
||||
};
|
||||
|
||||
use super::{
|
||||
bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch,
|
||||
nearest::interpolate_nearest_launch, nearest_backward::interpolate_nearest_backward_launch,
|
||||
};
|
||||
|
||||
/// Interpolate operation
|
||||
///
|
||||
/// Supports nearest, bilinear and bicubic modes
|
||||
pub fn interpolate<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> CubeTensor<R> {
|
||||
let [batch_size, channels, _, _] = input.meta.shape().dims();
|
||||
let [out_height, out_width] = output_size;
|
||||
|
||||
let input = into_contiguous(permute_nchw_to_nhwc(input));
|
||||
|
||||
let shape_out = Shape::new([batch_size, out_height, out_width, channels]);
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape_out,
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let align_corners = options.align_corners;
|
||||
let output = match options.mode {
|
||||
InterpolateMode::Nearest => interpolate_nearest_launch(input, output),
|
||||
InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output, align_corners),
|
||||
InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output, align_corners),
|
||||
};
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
|
||||
/// Backward interpolate operation
|
||||
///
|
||||
/// Note: only nearest mode is supported
|
||||
pub fn interpolate_backward<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
_output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> CubeTensor<R> {
|
||||
let input = permute_nchw_to_nhwc(input);
|
||||
let out_grad = permute_nchw_to_nhwc(out_grad);
|
||||
|
||||
let output_shape = input.shape();
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
output_shape,
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let output = match options.mode {
|
||||
InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output),
|
||||
InterpolateMode::Bilinear => {
|
||||
panic!("bilinear interpolation backward is not supported by JIT backend")
|
||||
}
|
||||
InterpolateMode::Bicubic => {
|
||||
panic!("bicubic interpolation backward is not supported by JIT backend")
|
||||
}
|
||||
};
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
use cubecl::std::{
|
||||
FastDivmod,
|
||||
tensor::layout::{linear::LinearLayout, *},
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_layout, shape_divmod},
|
||||
ops::max_line_size,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn interpolate_bicubic_kernel<F: Float>(
|
||||
input: &Tensor<Line<F>>,
|
||||
output: &mut Tensor<Line<F>>,
|
||||
shape_out: Sequence<FastDivmod<usize>>,
|
||||
out_layout: LinearLayout,
|
||||
#[comptime] align_corners: bool,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let line_size = input.line_size();
|
||||
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
|
||||
|
||||
let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
|
||||
let (rem, x) = shape_out[2].div_mod(rem);
|
||||
let (b, y) = shape_out[1].div_mod(rem);
|
||||
|
||||
let input_height = input.shape(1) - 1;
|
||||
let input_height_f = input_height as f32;
|
||||
|
||||
let frac = if align_corners {
|
||||
let output_height = clamp_min(output.shape(1) - 1, 1) as f32;
|
||||
(y * input_height) as f32 / output_height
|
||||
} else {
|
||||
let in_size = (input_height + 1) as f32;
|
||||
let out_size = output.shape(1) as f32;
|
||||
(y as f32 + 0.5) * (in_size / out_size) - 0.5
|
||||
};
|
||||
let y_in_f = frac.floor();
|
||||
let yw = Line::empty(line_size).fill(F::cast_from(frac - y_in_f));
|
||||
|
||||
// Clamp indices in float space to handle negative coordinates from half_pixel
|
||||
let y0 = clamp(y_in_f - 1.0, 0.0, input_height_f) as usize;
|
||||
let y1 = clamp(y_in_f, 0.0, input_height_f) as usize;
|
||||
let y2 = clamp(y_in_f + 1.0, 0.0, input_height_f) as usize;
|
||||
let y3 = clamp(y_in_f + 2.0, 0.0, input_height_f) as usize;
|
||||
|
||||
let input_width = input.shape(2) - 1;
|
||||
let input_width_f = input_width as f32;
|
||||
|
||||
let frac = if align_corners {
|
||||
let output_width = clamp_min(output.shape(2) - 1, 1) as f32;
|
||||
(x * input_width) as f32 / output_width
|
||||
} else {
|
||||
let in_size = (input_width + 1) as f32;
|
||||
let out_size = output.shape(2) as f32;
|
||||
(x as f32 + 0.5) * (in_size / out_size) - 0.5
|
||||
};
|
||||
let x_in_f = frac.floor();
|
||||
let xw = Line::empty(line_size).fill(F::cast_from(frac - x_in_f));
|
||||
|
||||
// Clamp indices in float space to handle negative coordinates from half_pixel
|
||||
let x0 = clamp(x_in_f - 1.0, 0.0, input_width_f) as usize;
|
||||
let x1 = clamp(x_in_f, 0.0, input_width_f) as usize;
|
||||
let x2 = clamp(x_in_f + 1.0, 0.0, input_width_f) as usize;
|
||||
let x3 = clamp(x_in_f + 2.0, 0.0, input_width_f) as usize;
|
||||
|
||||
let index_base = b * input.stride(0) + c * input.stride(3);
|
||||
let in_stride_y = input.stride(1);
|
||||
let in_stride_x = input.stride(2);
|
||||
|
||||
let y0_stride = y0 * in_stride_y;
|
||||
let y1_stride = y1 * in_stride_y;
|
||||
let y2_stride = y2 * in_stride_y;
|
||||
let y3_stride = y3 * in_stride_y;
|
||||
let x0_stride = x0 * in_stride_x;
|
||||
let x1_stride = x1 * in_stride_x;
|
||||
let x2_stride = x2 * in_stride_x;
|
||||
let x3_stride = x3 * in_stride_x;
|
||||
|
||||
let inp_0 = input[(index_base + y0_stride + x0_stride) / line_size];
|
||||
let inp_1 = input[(index_base + y0_stride + x1_stride) / line_size];
|
||||
let inp_2 = input[(index_base + y0_stride + x2_stride) / line_size];
|
||||
let inp_3 = input[(index_base + y0_stride + x3_stride) / line_size];
|
||||
|
||||
let coefficients0 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
|
||||
|
||||
let inp_0 = input[(index_base + y1_stride + x0_stride) / line_size];
|
||||
let inp_1 = input[(index_base + y1_stride + x1_stride) / line_size];
|
||||
let inp_2 = input[(index_base + y1_stride + x2_stride) / line_size];
|
||||
let inp_3 = input[(index_base + y1_stride + x3_stride) / line_size];
|
||||
|
||||
let coefficients1 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
|
||||
|
||||
let inp_0 = input[(index_base + y2_stride + x0_stride) / line_size];
|
||||
let inp_1 = input[(index_base + y2_stride + x1_stride) / line_size];
|
||||
let inp_2 = input[(index_base + y2_stride + x2_stride) / line_size];
|
||||
let inp_3 = input[(index_base + y2_stride + x3_stride) / line_size];
|
||||
|
||||
let coefficients2 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
|
||||
|
||||
let inp_0 = input[(index_base + y3_stride + x0_stride) / line_size];
|
||||
let inp_1 = input[(index_base + y3_stride + x1_stride) / line_size];
|
||||
let inp_2 = input[(index_base + y3_stride + x2_stride) / line_size];
|
||||
let inp_3 = input[(index_base + y3_stride + x3_stride) / line_size];
|
||||
|
||||
let coefficients3 = cubic_interp_1d::<F>(inp_0, inp_1, inp_2, inp_3, xw);
|
||||
|
||||
let val = cubic_interp_1d::<F>(
|
||||
coefficients0,
|
||||
coefficients1,
|
||||
coefficients2,
|
||||
coefficients3,
|
||||
yw,
|
||||
);
|
||||
|
||||
output[out_idx] = val;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn cubic_interp_1d<F: Float>(
|
||||
x0: Line<F>,
|
||||
x1: Line<F>,
|
||||
x2: Line<F>,
|
||||
x3: Line<F>,
|
||||
t: Line<F>,
|
||||
) -> Line<F> {
|
||||
let a = lined(&x0, -0.75);
|
||||
|
||||
let coeffs0 = cubic_convolution_2::<F>(t + lined(&x0, 1.0), a);
|
||||
let coeffs1 = cubic_convolution_1::<F>(t, a);
|
||||
let coeffs2 = cubic_convolution_1::<F>(lined(&x0, 1.0) - t, a);
|
||||
let coeffs3 = cubic_convolution_2::<F>(lined(&x0, 2.0) - t, a);
|
||||
|
||||
x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn cubic_convolution_1<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> {
|
||||
let conv = (a + lined(&x, 2.0)) * x;
|
||||
let tmp = a + lined(&x, 3.0);
|
||||
(conv - tmp) * x * x + lined(&x, 1.0)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn cubic_convolution_2<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> {
|
||||
let conv = a * x;
|
||||
let conv = (conv - lined(&x, 5.0) * a) * x;
|
||||
let tmp = lined(&x, 8.0) * a;
|
||||
let conv = (conv + tmp) * x;
|
||||
|
||||
conv - lined(&x, 4.0) * a
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn lined<F: Float>(x: &Line<F>, #[comptime] v: f32) -> Line<F> {
|
||||
Line::empty(x.size()).fill(F::new(v))
|
||||
}
|
||||
|
||||
pub(crate) fn interpolate_bicubic_launch<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
align_corners: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size(&input);
|
||||
let out_shape = shape_divmod(&output);
|
||||
let out_layout = linear_layout(&output, line_size);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
interpolate_bicubic_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, output),
|
||||
input.as_tensor_arg(line_size),
|
||||
output.as_tensor_arg(line_size),
|
||||
out_shape,
|
||||
out_layout,
|
||||
align_corners,
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
use cubecl::std::{
|
||||
FastDivmod,
|
||||
tensor::layout::{linear::LinearLayout, *},
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_layout, shape_divmod},
|
||||
ops::max_line_size,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn interpolate_bilinear_kernel<F: Float>(
|
||||
input: &Tensor<Line<F>>,
|
||||
output: &mut Tensor<Line<F>>,
|
||||
shape_out: Sequence<FastDivmod<usize>>,
|
||||
out_layout: LinearLayout,
|
||||
#[comptime] align_corners: bool,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let line_size = input.line_size();
|
||||
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
|
||||
|
||||
let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
|
||||
let (rem, x) = shape_out[2].div_mod(rem);
|
||||
let (b, y) = shape_out[1].div_mod(rem);
|
||||
|
||||
let frac = if align_corners {
|
||||
let numerator = (input.shape(1) - 1) as f32;
|
||||
let denominator = clamp_min(output.shape(1) - 1, 1) as f32;
|
||||
y as f32 * (numerator / denominator)
|
||||
} else {
|
||||
let in_size = input.shape(1) as f32;
|
||||
let out_size = output.shape(1) as f32;
|
||||
clamp(
|
||||
(y as f32 + 0.5) * (in_size / out_size) - 0.5,
|
||||
0.0,
|
||||
in_size - 1.0,
|
||||
)
|
||||
};
|
||||
|
||||
let v0 = frac.floor();
|
||||
let v1 = frac.ceil();
|
||||
let yw = F::cast_from(frac - v0);
|
||||
let yw_ = Line::empty(line_size).fill(F::new(1.0) - yw);
|
||||
let yw = Line::empty(line_size).fill(yw);
|
||||
let y0_ok = v0 >= 0.0;
|
||||
let y0 = v0 as usize;
|
||||
let y1 = v1 as usize;
|
||||
|
||||
let frac = if align_corners {
|
||||
let numerator = (input.shape(2) - 1) as f32;
|
||||
let denominator = clamp_min(output.shape(2) - 1, 1) as f32;
|
||||
x as f32 * (numerator / denominator)
|
||||
} else {
|
||||
let in_size = input.shape(2) as f32;
|
||||
let out_size = output.shape(2) as f32;
|
||||
clamp(
|
||||
(x as f32 + 0.5) * (in_size / out_size) - 0.5,
|
||||
0.0,
|
||||
in_size - 1.0,
|
||||
)
|
||||
};
|
||||
let v0 = frac.floor();
|
||||
let v1 = frac.ceil();
|
||||
let xw = F::cast_from(frac - v0);
|
||||
let xw_ = Line::empty(line_size).fill(F::new(1.0) - xw);
|
||||
let xw = Line::empty(line_size).fill(xw);
|
||||
let x0_ok = v0 >= 0.0;
|
||||
let x0 = v0 as usize;
|
||||
let x1 = v1 as usize;
|
||||
|
||||
let index_base = b * input.stride(0) + c * input.stride(3);
|
||||
|
||||
let in_stride_y = input.stride(1);
|
||||
let in_stride_x = input.stride(2);
|
||||
|
||||
let y0_stride = y0 * in_stride_y;
|
||||
let y1_stride = y1 * in_stride_y;
|
||||
let x0_stride = x0 * in_stride_x;
|
||||
let x1_stride = x1 * in_stride_x;
|
||||
|
||||
let height = input.shape(1);
|
||||
let width = input.shape(2);
|
||||
|
||||
let y1_ok = y1 < height;
|
||||
let x1_ok = x1 < width;
|
||||
|
||||
let zero = Line::empty(line_size).fill(F::new(0.0));
|
||||
|
||||
let p_a = select(
|
||||
x0_ok && y0_ok,
|
||||
input[(index_base + y0_stride + x0_stride) / line_size] * xw_ * yw_,
|
||||
zero,
|
||||
);
|
||||
let p_b = select(
|
||||
x1_ok && y0_ok,
|
||||
input[(index_base + y0_stride + x1_stride) / line_size] * xw * yw_,
|
||||
zero,
|
||||
);
|
||||
let p_c = select(
|
||||
x0_ok && y1_ok,
|
||||
input[(index_base + y1_stride + x0_stride) / line_size] * xw_ * yw,
|
||||
zero,
|
||||
);
|
||||
let p_d = select(
|
||||
x1_ok && y1_ok,
|
||||
input[(index_base + y1_stride + x1_stride) / line_size] * xw * yw,
|
||||
zero,
|
||||
);
|
||||
|
||||
output[out_idx] = p_a + p_b + p_c + p_d;
|
||||
}
|
||||
|
||||
pub(crate) fn interpolate_bilinear_launch<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
align_corners: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size(&input);
|
||||
let out_shape = shape_divmod(&output);
|
||||
let out_layout = linear_layout(&output, line_size);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
interpolate_bilinear_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, output),
|
||||
input.as_tensor_arg(line_size),
|
||||
output.as_tensor_arg(line_size),
|
||||
out_shape,
|
||||
out_layout,
|
||||
align_corners,
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod base;
|
||||
mod bicubic;
|
||||
mod bilinear;
|
||||
mod nearest;
|
||||
mod nearest_backward;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,80 @@
|
||||
use cubecl::std::{
|
||||
FastDivmod,
|
||||
tensor::layout::{linear::LinearLayout, *},
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_layout, shape_divmod},
|
||||
ops::max_line_size,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn interpolate_nearest_kernel<F: Float>(
|
||||
input: &Tensor<Line<F>>,
|
||||
output: &mut Tensor<Line<F>>,
|
||||
shape_out: Sequence<FastDivmod<usize>>,
|
||||
out_layout: LinearLayout,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let line_size = input.line_size();
|
||||
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
|
||||
|
||||
let out_pos = ABSOLUTE_POS * line_size;
|
||||
|
||||
let (h_in, w_in) = (input.shape(1) as f32, input.shape(2) as f32);
|
||||
let (h_out, w_out) = (output.shape(1) as f32, output.shape(2) as f32);
|
||||
|
||||
let (rem, c) = shape_out[3].div_mod(out_pos);
|
||||
let (rem, x) = shape_out[2].div_mod(rem);
|
||||
let (b, y) = shape_out[1].div_mod(rem);
|
||||
|
||||
let y = y as f32 * (h_in / h_out);
|
||||
let x = x as f32 * (w_in / w_out);
|
||||
|
||||
let in_idx = b * input.stride(0)
|
||||
+ y as usize * input.stride(1)
|
||||
+ x as usize * input.stride(2)
|
||||
+ c * input.stride(3);
|
||||
|
||||
output[out_idx] = input[in_idx / line_size];
|
||||
}
|
||||
|
||||
pub(crate) fn interpolate_nearest_launch<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let client = input.client.clone();
|
||||
|
||||
let line_size = max_line_size(&input);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
let shape_out = shape_divmod(&output);
|
||||
let out_layout = linear_layout(&output, line_size);
|
||||
|
||||
unsafe {
|
||||
interpolate_nearest_kernel::launch_unchecked(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, output),
|
||||
input.as_tensor_arg(line_size),
|
||||
output.as_tensor_arg(line_size),
|
||||
shape_out,
|
||||
out_layout,
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use cubecl::std::{
|
||||
FastDivmod,
|
||||
tensor::layout::{linear::LinearLayout, *},
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_layout, shape_divmod},
|
||||
ops::max_line_size,
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn interpolate_nearest_backward_kernel<F: Float>(
|
||||
grad: &Tensor<Line<F>>,
|
||||
output: &mut Tensor<Line<F>>,
|
||||
shape_out: Sequence<FastDivmod<usize>>,
|
||||
out_layout: LinearLayout,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= output.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let line_size = grad.line_size();
|
||||
let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
|
||||
|
||||
let out_h = output.shape(1);
|
||||
let out_w = output.shape(2);
|
||||
let grad_h = grad.shape(1);
|
||||
let grad_w = grad.shape(2);
|
||||
|
||||
let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
|
||||
let (rem, out_x) = shape_out[2].div_mod(rem);
|
||||
let (b, out_y) = shape_out[1].div_mod(rem);
|
||||
|
||||
let grad_y_start = start_index::<F>(out_y, grad_h, out_h);
|
||||
let grad_y_end = end_index::<F>(out_y, grad_h, out_h);
|
||||
let grad_x_start = start_index::<F>(out_x, grad_w, out_w);
|
||||
let grad_x_end = end_index::<F>(out_x, grad_w, out_w);
|
||||
|
||||
let index_grad_base = b * grad.stride(0) + c * grad.stride(3);
|
||||
|
||||
let mut sum = Line::empty(line_size).fill(F::new(0.0));
|
||||
|
||||
for grad_y in grad_y_start..grad_y_end {
|
||||
for grad_x in grad_x_start..grad_x_end {
|
||||
let index_grad = index_grad_base + grad_y * grad.stride(1) + grad_x * grad.stride(2);
|
||||
|
||||
sum += grad[index_grad];
|
||||
}
|
||||
}
|
||||
|
||||
output[out_idx] = sum;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn start_index<F: Float>(input_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
let numerator = F::cast_from(input_index * output_size);
|
||||
let div = (numerator / F::cast_from(input_size)).ceil();
|
||||
|
||||
usize::cast_from(div)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn end_index<F: Float>(input_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
let numerator = F::cast_from((input_index + 1) * output_size);
|
||||
let div = (numerator / F::cast_from(input_size)).ceil();
|
||||
let index = usize::cast_from(div);
|
||||
|
||||
clamp_max(index, output_size)
|
||||
}
|
||||
|
||||
pub(crate) fn interpolate_nearest_backward_launch<R: CubeRuntime>(
|
||||
out_grad: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size(&out_grad);
|
||||
let out_shape = shape_divmod(&output);
|
||||
let out_layout = linear_layout(&output, line_size);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&out_grad.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&out_grad.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
interpolate_nearest_backward_kernel::launch_unchecked(
|
||||
&out_grad.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(out_grad, output),
|
||||
out_grad.as_tensor_arg(line_size),
|
||||
output.as_tensor_arg(line_size),
|
||||
out_shape,
|
||||
out_layout,
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
use burn_backend::DType;
|
||||
use cubecl::prelude::InputScalar;
|
||||
|
||||
use super::{MaskFillStrategy, mask_where::MaskWhereStrategy};
|
||||
use crate::{CubeRuntime, tensor::CubeTensor};
|
||||
|
||||
/// Execute the mask fill kernel.
|
||||
pub(crate) fn mask_fill_auto<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
mask: CubeTensor<R>,
|
||||
value: InputScalar,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let strategy = if tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
MaskFillStrategy::Inplace
|
||||
} else {
|
||||
MaskFillStrategy::Readonly
|
||||
};
|
||||
|
||||
super::mask_fill(tensor, mask, value, strategy, dtype_bool)
|
||||
}
|
||||
|
||||
/// Execute the mask where kernel.
|
||||
pub(crate) fn mask_where_auto<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
mask: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let strategy = if tensor.can_mut_broadcast(&value) {
|
||||
MaskWhereStrategy::InplaceLhs
|
||||
} else if value.can_mut_broadcast(&tensor) {
|
||||
MaskWhereStrategy::InplaceRhs
|
||||
} else {
|
||||
MaskWhereStrategy::Readonly
|
||||
};
|
||||
|
||||
super::mask_where(tensor, mask, value, strategy, dtype_bool)
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use burn_backend::{DType, TensorMetadata};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, linear_view_alias, linear_view_ref},
|
||||
ops::{max_line_size_many, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn mask_fill_kernel<T: Numeric, B: Int>(
|
||||
input: &LinearView<Line<T>>,
|
||||
mask: &LinearView<Line<B>>,
|
||||
output: &mut LinearView<Line<T>, ReadWrite>,
|
||||
value: InputScalar,
|
||||
#[define(T, B)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let mask = Line::cast_from(mask[ABSOLUTE_POS]);
|
||||
let input = input[ABSOLUTE_POS];
|
||||
let value = Line::new(value.get::<T>());
|
||||
|
||||
output[ABSOLUTE_POS] = select_many(mask, value, input);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
/// Define how to run the mask fill kernel.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// All assertions should be done before choosing the strategy.
|
||||
pub enum MaskFillStrategy {
|
||||
/// Don't mutate any input.
|
||||
Readonly,
|
||||
/// Reuse the input tensor inplace.
|
||||
Inplace,
|
||||
}
|
||||
|
||||
/// Execute the mask fill kernel with the given strategy.
|
||||
pub fn mask_fill<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
mask: CubeTensor<R>,
|
||||
value: InputScalar,
|
||||
strategy: MaskFillStrategy,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let ndims = input.meta.num_dims();
|
||||
let output = match strategy {
|
||||
MaskFillStrategy::Readonly => empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
input.shape(),
|
||||
input.dtype,
|
||||
),
|
||||
MaskFillStrategy::Inplace => input.clone(),
|
||||
};
|
||||
|
||||
let line_size = max_line_size_many(&[&input, &mask], ndims - 1);
|
||||
let working_units = input.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
let out_arg = match strategy {
|
||||
MaskFillStrategy::Readonly => linear_view(&output, line_size),
|
||||
MaskFillStrategy::Inplace => linear_view_alias(&output, line_size, 0),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
mask_fill_kernel::launch_unchecked(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, mask, output),
|
||||
linear_view(&input, line_size),
|
||||
linear_view_ref(&mask, &input, line_size),
|
||||
out_arg,
|
||||
value,
|
||||
[output.dtype.into(), dtype_bool.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
use burn_backend::DType;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{
|
||||
address_type, broadcast_shape, linear_view, linear_view_alias, linear_view_ref,
|
||||
},
|
||||
ops::{max_line_size_many, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn mask_where_kernel<T: Numeric, B: Int>(
|
||||
input: &LinearView<Line<T>>,
|
||||
value: &LinearView<Line<T>>,
|
||||
mask: &LinearView<Line<B>>,
|
||||
output: &mut LinearView<Line<T>, ReadWrite>,
|
||||
#[define(T, B)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
let pos = ABSOLUTE_POS;
|
||||
if !output.is_in_bounds(pos) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[pos] = select_many(Line::cast_from(mask[pos]), value[pos], input[pos]);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
/// Define how to run the mask where kernel.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// All assertions should be done before choosing the strategy.
|
||||
pub enum MaskWhereStrategy {
|
||||
/// Don't mutate any input.
|
||||
Readonly,
|
||||
/// Reuse the lhs tensor inplace.
|
||||
InplaceLhs,
|
||||
/// Reuse the rhs tensor inplace.
|
||||
InplaceRhs,
|
||||
}
|
||||
|
||||
/// Execute the mask where kernel with the given strategy.
|
||||
pub fn mask_where<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
mask: CubeTensor<R>,
|
||||
value: CubeTensor<R>,
|
||||
strategy: MaskWhereStrategy,
|
||||
dtype_bool: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let line_size = max_line_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1);
|
||||
|
||||
let working_units = input.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
let out_shape = broadcast_shape(&[&input, &mask, &value]);
|
||||
|
||||
let output = match strategy {
|
||||
MaskWhereStrategy::Readonly => empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
out_shape,
|
||||
input.dtype,
|
||||
),
|
||||
MaskWhereStrategy::InplaceLhs => input.clone(),
|
||||
MaskWhereStrategy::InplaceRhs => value.clone(),
|
||||
};
|
||||
|
||||
let out = match strategy {
|
||||
MaskWhereStrategy::Readonly => linear_view(&output, line_size),
|
||||
MaskWhereStrategy::InplaceLhs => linear_view_alias(&output, line_size, 0),
|
||||
MaskWhereStrategy::InplaceRhs => linear_view_alias(&output, line_size, 1),
|
||||
};
|
||||
|
||||
mask_where_kernel::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, value, mask, output),
|
||||
linear_view_ref(&input, &output, line_size),
|
||||
linear_view_ref(&value, &output, line_size),
|
||||
linear_view_ref(&mask, &output, line_size),
|
||||
out,
|
||||
[output.dtype.into(), dtype_bool.into()],
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod base;
|
||||
mod mask_fill;
|
||||
mod mask_where;
|
||||
|
||||
pub(crate) use base::*;
|
||||
|
||||
pub use mask_fill::*;
|
||||
pub use mask_where::*;
|
||||
@@ -0,0 +1,155 @@
|
||||
use super::init_matmul_output;
|
||||
use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};
|
||||
use burn_backend::{DType, QTensorPrimitive};
|
||||
use burn_std::QuantLevel;
|
||||
use cubek::matmul::{
|
||||
definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
|
||||
launch::{MatmulInputHandleRef, Strategy},
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::matmul_autotune;
|
||||
|
||||
/// The strategy to be used when launching a matmul kernel.
|
||||
pub enum MatmulStrategy {
|
||||
#[cfg(feature = "autotune")]
|
||||
/// Using autotune to choose the best kernel based on runtime information.
|
||||
Autotune,
|
||||
/// Cube implementation of matmul.
|
||||
Cube,
|
||||
}
|
||||
|
||||
impl Default for MatmulStrategy {
|
||||
fn default() -> Self {
|
||||
// if autotune is enabled, default to autotune
|
||||
#[cfg(feature = "autotune")]
|
||||
return MatmulStrategy::Autotune;
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
MatmulStrategy::Cube
|
||||
}
|
||||
}
|
||||
|
||||
/// Launch a matmul kernel using the given strategy.
|
||||
pub fn matmul<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
out: Option<CubeTensor<R>>,
|
||||
strategy: MatmulStrategy,
|
||||
out_dtype: DType,
|
||||
) -> Result<CubeTensor<R>, MatmulSetupError> {
|
||||
match strategy {
|
||||
MatmulStrategy::Cube => {
|
||||
let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
|
||||
launch_matmul(&Default::default(), lhs, rhs, out.clone())?;
|
||||
Ok(out)
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_matmul_naive<R: CubeRuntime>(
|
||||
strategy: &Strategy,
|
||||
mut lhs: CubeTensor<R>,
|
||||
mut rhs: CubeTensor<R>,
|
||||
out: CubeTensor<R>,
|
||||
) -> Result<(), MatmulSetupError> {
|
||||
// Naive has very specific layout requirements for block scaled tensors, so we need to manually
|
||||
// dequantize if it fails to launch normally. This is because naive is assumed to always work.
|
||||
if lhs.qparams.is_some() || rhs.qparams.is_some() {
|
||||
match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {
|
||||
Err(_) => {
|
||||
if lhs.qparams.is_some() {
|
||||
lhs = dequantize(lhs, out.dtype);
|
||||
}
|
||||
if rhs.qparams.is_some() {
|
||||
rhs = dequantize(rhs, out.dtype);
|
||||
}
|
||||
launch_matmul(strategy, lhs, rhs, out)
|
||||
}
|
||||
Ok(_) => Ok(()),
|
||||
}
|
||||
} else {
|
||||
launch_matmul(strategy, lhs, rhs, out)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn launch_matmul<R: CubeRuntime>(
|
||||
strategy: &Strategy,
|
||||
lhs: CubeTensor<R>,
|
||||
mut rhs: CubeTensor<R>,
|
||||
out: CubeTensor<R>,
|
||||
) -> Result<(), MatmulSetupError> {
|
||||
let client = &lhs.client;
|
||||
|
||||
let lhs_quant_handles = lhs.quantized_handles();
|
||||
let out_dtype: DType = out.dtype;
|
||||
|
||||
let (lhs_dtype, lhs_handle) = match &lhs_quant_handles {
|
||||
None => (
|
||||
lhs.dtype,
|
||||
MatmulInputHandleRef::new(lhs.as_handle_ref(), lhs.dtype.into()),
|
||||
),
|
||||
Some((data, scale)) => (
|
||||
out_dtype,
|
||||
MatmulInputHandleRef::quantized(
|
||||
data.as_handle_ref(),
|
||||
scale.as_handle_ref(),
|
||||
lhs.meta.shape(),
|
||||
lhs.scheme(),
|
||||
data.dtype.into(),
|
||||
scale.dtype.into(),
|
||||
),
|
||||
),
|
||||
};
|
||||
|
||||
let rhs_quant_handles = rhs.quantized_handles();
|
||||
|
||||
let (rhs_dtype, rhs_handle) = match &rhs_quant_handles {
|
||||
None => (
|
||||
lhs.dtype,
|
||||
MatmulInputHandleRef::new(rhs.as_handle_ref(), lhs.dtype.into()),
|
||||
),
|
||||
Some((data, scale)) => {
|
||||
// Extremely hacky fix to ensure naive can run in every case
|
||||
if matches!(strategy, Strategy::Naive)
|
||||
&& matches!(rhs.scheme().level, QuantLevel::Block(_))
|
||||
{
|
||||
rhs = dequantize(rhs.clone(), lhs.dtype);
|
||||
(
|
||||
lhs.dtype,
|
||||
MatmulInputHandleRef::new(rhs.as_handle_ref(), rhs.dtype.into()),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
out_dtype,
|
||||
MatmulInputHandleRef::quantized(
|
||||
data.as_handle_ref(),
|
||||
scale.as_handle_ref(),
|
||||
rhs.meta.shape(),
|
||||
rhs.scheme(),
|
||||
data.dtype.into(),
|
||||
scale.dtype.into(),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
|
||||
lhs: lhs_dtype.into(),
|
||||
rhs: rhs_dtype.into(),
|
||||
out: out_dtype.into(),
|
||||
});
|
||||
cubek::matmul::launch::launch_ref(
|
||||
strategy,
|
||||
client,
|
||||
&lhs_handle,
|
||||
&rhs_handle,
|
||||
&out.as_handle_ref(),
|
||||
&mut dtypes,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
mod base;
|
||||
mod tune;
|
||||
|
||||
/// Contains utilities for matmul operation
|
||||
pub mod utils;
|
||||
|
||||
pub use base::*;
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use tune::*;
|
||||
pub use utils::*;
|
||||
@@ -0,0 +1,409 @@
|
||||
use crate::{
|
||||
CubeRuntime, CubeTuneId,
|
||||
kernel::matmul::{launch_matmul, launch_matmul_naive, utils::init_matmul_output},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::DType;
|
||||
use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};
|
||||
use cubek::matmul::{
|
||||
definition::MatmulKind,
|
||||
launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering},
|
||||
routines::{
|
||||
BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs,
|
||||
double_unit::DoubleUnitSelectionArgs, ordered_double_buffering::OrderedSelectionArgs,
|
||||
simple::SimpleArgs, simple_unit::SimpleUnitSelectionArgs,
|
||||
},
|
||||
};
|
||||
|
||||
fn matmul_input_gen<R: CubeRuntime>(
|
||||
_key: &MatmulAutotuneKey,
|
||||
lhs: &CubeTensor<R>,
|
||||
rhs: &CubeTensor<R>,
|
||||
out: &CubeTensor<R>,
|
||||
) -> (CubeTensor<R>, CubeTensor<R>, CubeTensor<R>) {
|
||||
(lhs.clone(), rhs.clone(), out.copy())
|
||||
}
|
||||
|
||||
/// Executes autotune on matmul operations
|
||||
pub fn matmul_autotune<R: CubeRuntime>(
|
||||
lhs: CubeTensor<R>,
|
||||
rhs: CubeTensor<R>,
|
||||
out: Option<CubeTensor<R>>,
|
||||
out_dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let output = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));
|
||||
|
||||
let client = lhs.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<MatmulAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
const PRIORITY_MAX: i8 = 3;
|
||||
const PRIORITY_HIGH: i8 = 2;
|
||||
const PRIORITY_MEDIUM: i8 = 1;
|
||||
const PRIORITY_MIN: i8 = 0;
|
||||
const PRIORITY_NEVER: i8 = -1;
|
||||
|
||||
let cmma = TuneGroup::<MatmulAutotuneKey>::new("cmma", |key| {
|
||||
if matches!(
|
||||
key.analysis.kind,
|
||||
MatmulKind::General
|
||||
// Those variants are just because the unit alternatives aren't very good yet.
|
||||
| MatmulKind::VecMat | MatmulKind::MatVec
|
||||
) {
|
||||
PRIORITY_HIGH
|
||||
} else {
|
||||
PRIORITY_MEDIUM
|
||||
}
|
||||
});
|
||||
|
||||
let mma = TuneGroup::<MatmulAutotuneKey>::new("mma", |key| {
|
||||
if matches!(
|
||||
key.analysis.kind,
|
||||
// General is usually bad, but I think shapes like 16x8196 would be classed as
|
||||
// general and are very good with MMA
|
||||
// Should highly degenerated matrices that aren't VecMat have their own class?
|
||||
MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec
|
||||
) {
|
||||
PRIORITY_HIGH
|
||||
} else {
|
||||
PRIORITY_MEDIUM
|
||||
}
|
||||
});
|
||||
|
||||
let unit = TuneGroup::<MatmulAutotuneKey>::new("unit", |key| {
|
||||
if !matches!(key.analysis.kind, MatmulKind::General)
|
||||
|| matches!(key.analysis.scale_global, MatmulGlobalScale::Small)
|
||||
{
|
||||
PRIORITY_HIGH
|
||||
} else {
|
||||
PRIORITY_MIN
|
||||
}
|
||||
});
|
||||
|
||||
let tma = TuneGroup::<MatmulAutotuneKey>::new("tma", |key| {
|
||||
// For large matmul, we set the max priority to TMA kernels, higher than any other
|
||||
// matmuls, since they are the best kernels no matter what.
|
||||
//
|
||||
// But only when all axis are large.
|
||||
let max_axis = usize::max(key.definition.m, key.definition.n);
|
||||
let max_axis = usize::max(key.definition.k, max_axis);
|
||||
|
||||
let min_axis = usize::min(key.definition.m, key.definition.n);
|
||||
let min_axis = usize::min(key.definition.k, min_axis);
|
||||
|
||||
let skewed_factor = max_axis / min_axis;
|
||||
|
||||
let priority_max = if matches!(key.analysis.kind, MatmulKind::General)
|
||||
&& matches!(key.analysis.scale_global, MatmulGlobalScale::Large)
|
||||
&& skewed_factor < 4
|
||||
{
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_HIGH
|
||||
};
|
||||
|
||||
if key.definition.lhs_stride_factor >= 4 && key.definition.rhs_stride_factor >= 4 {
|
||||
priority_max
|
||||
} else {
|
||||
PRIORITY_NEVER
|
||||
}
|
||||
});
|
||||
|
||||
fn double_buffering_priority(key: &MatmulAutotuneKey, max: i8, min: i8) -> i8 {
|
||||
if should_tune_double_buffering(false, key) {
|
||||
max
|
||||
} else {
|
||||
min
|
||||
}
|
||||
}
|
||||
|
||||
let mut set = TunableSet::new(create_key::<R>, matmul_input_gen::<R>);
|
||||
|
||||
// First entry should always work, since it is considered the fallback.
|
||||
set = set.with(
|
||||
Tunable::new("matmul_naive", |lhs, rhs, out| {
|
||||
launch_matmul_naive::<R>(&Strategy::Naive, lhs, rhs, out)
|
||||
.map_err(|err| std::format!("{err:?}"))
|
||||
})
|
||||
.group(&unit, |key| {
|
||||
if matches!(key.analysis.scale_global, MatmulGlobalScale::Small)
|
||||
|| matches!(key.analysis.kind, MatmulKind::InnerProduct)
|
||||
{
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MIN
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Unit VecMat
|
||||
for (strategy, double_buf) in [
|
||||
(
|
||||
Strategy::SimpleVecMat(BlueprintStrategy::Inferred(().into())),
|
||||
false,
|
||||
),
|
||||
(
|
||||
Strategy::DoubleVecMat(BlueprintStrategy::Inferred(().into())),
|
||||
true,
|
||||
),
|
||||
] {
|
||||
set = set.with(
|
||||
Tunable::new(strategy.to_string(), move |lhs, rhs, out| {
|
||||
launch_matmul::<R>(&strategy, lhs, rhs, out)
|
||||
.map_err(|err| std::format!("{err:?}"))
|
||||
})
|
||||
.group(&unit, move |key| match double_buf {
|
||||
false => PRIORITY_MAX,
|
||||
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Unit matmuls
|
||||
for tile_size in [
|
||||
TileSizeSelection::MaxTileSize,
|
||||
TileSizeSelection::MinTileSize,
|
||||
] {
|
||||
for (strategy, double_buf) in [
|
||||
(
|
||||
Strategy::SimpleUnit(BlueprintStrategy::Inferred(SimpleUnitSelectionArgs {
|
||||
tile_size,
|
||||
})),
|
||||
false,
|
||||
),
|
||||
(
|
||||
Strategy::DoubleUnit(BlueprintStrategy::Inferred(DoubleUnitSelectionArgs {
|
||||
tile_size,
|
||||
})),
|
||||
true,
|
||||
),
|
||||
] {
|
||||
set = set.with(
|
||||
Tunable::new(strategy.to_string(), move |lhs, rhs, out| {
|
||||
launch_matmul::<R>(&strategy, lhs, rhs, out)
|
||||
.map_err(|err| format!("{err:?}"))
|
||||
})
|
||||
.group(&unit, move |key| match double_buf {
|
||||
false => PRIORITY_MAX,
|
||||
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Accelerated matmuls
|
||||
for (strategy, double_buf, group_extra, tile_group) in [
|
||||
(
|
||||
Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: false,
|
||||
})),
|
||||
false,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: false,
|
||||
})),
|
||||
false,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: true,
|
||||
})),
|
||||
false,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: true,
|
||||
})),
|
||||
false,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
|
||||
partition_k: Some(2),
|
||||
row_count: Some(4),
|
||||
rows_per_plane: Some(2),
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
|
||||
partition_k: Some(2),
|
||||
row_count: Some(4),
|
||||
rows_per_plane: Some(2),
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
|
||||
partition_k: Some(2),
|
||||
row_count: Some(8),
|
||||
rows_per_plane: Some(2),
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs {
|
||||
partition_k: Some(2),
|
||||
row_count: Some(8),
|
||||
rows_per_plane: Some(2),
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
|
||||
specialized: false,
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
|
||||
specialized: false,
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
|
||||
specialized: true,
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs {
|
||||
specialized: true,
|
||||
})),
|
||||
true,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::SpecializedCyclicCmma(BlueprintStrategy::Inferred(().into())),
|
||||
true,
|
||||
None,
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::SpecializedCyclicMma(BlueprintStrategy::Inferred(().into())),
|
||||
true,
|
||||
None,
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: false,
|
||||
})),
|
||||
false,
|
||||
Some(&tma),
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: false,
|
||||
})),
|
||||
false,
|
||||
Some(&tma),
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: true,
|
||||
})),
|
||||
false,
|
||||
Some(&tma),
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs {
|
||||
multi_rows: true,
|
||||
})),
|
||||
false,
|
||||
Some(&tma),
|
||||
&mma,
|
||||
),
|
||||
(
|
||||
Strategy::SpecializedTmaCmma(BlueprintStrategy::Inferred(().into())),
|
||||
true,
|
||||
Some(&tma),
|
||||
&cmma,
|
||||
),
|
||||
(
|
||||
Strategy::SpecializedTmaMma(BlueprintStrategy::Inferred(().into())),
|
||||
true,
|
||||
Some(&tma),
|
||||
&mma,
|
||||
),
|
||||
] {
|
||||
let priority_within_group = |key: &MatmulAutotuneKey, double_buf: bool| match double_buf
|
||||
{
|
||||
false => PRIORITY_MAX,
|
||||
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
|
||||
};
|
||||
let mut tunable = Tunable::new(strategy.to_string(), move |lhs, rhs, out| {
|
||||
launch_matmul::<R>(&strategy, lhs, rhs, out).map_err(|err| format!("{err:?}"))
|
||||
});
|
||||
|
||||
// tile group
|
||||
tunable = tunable.group(tile_group, move |key| {
|
||||
priority_within_group(key, double_buf)
|
||||
});
|
||||
|
||||
// extra group
|
||||
if let Some(group) = group_extra {
|
||||
tunable = tunable.group(group, move |key| priority_within_group(key, double_buf));
|
||||
}
|
||||
set = set.with(tunable);
|
||||
}
|
||||
|
||||
set
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&lhs.client, &lhs.device),
|
||||
&client,
|
||||
tunables,
|
||||
(lhs, rhs, output.clone()),
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn create_key<R: CubeRuntime>(
|
||||
lhs: &CubeTensor<R>,
|
||||
rhs: &CubeTensor<R>,
|
||||
out: &CubeTensor<R>,
|
||||
) -> MatmulAutotuneKey {
|
||||
MatmulAutotuneKey::generate(
|
||||
&lhs.client,
|
||||
lhs.meta.shape(),
|
||||
rhs.meta.shape(),
|
||||
lhs.meta.strides(),
|
||||
rhs.meta.strides(),
|
||||
lhs.dtype.into(),
|
||||
rhs.dtype.into(),
|
||||
out.dtype.into(),
|
||||
lhs.try_scheme(),
|
||||
rhs.try_scheme(),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#[cfg(feature = "autotune")]
|
||||
mod base;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use base::matmul_autotune;
|
||||
@@ -0,0 +1,16 @@
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
use burn_backend::{DType, calculate_matmul_output};
|
||||
|
||||
/// Creates an empty output tensor with matmul output shape
|
||||
pub fn init_matmul_output<R: CubeRuntime>(
|
||||
lhs: &CubeTensor<R>,
|
||||
rhs: &CubeTensor<R>,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
empty_device_dtype(
|
||||
lhs.client.clone(),
|
||||
lhs.device.clone(),
|
||||
calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(),
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
mod binary;
|
||||
mod binary_float;
|
||||
mod binary_int;
|
||||
mod cast;
|
||||
mod clamp;
|
||||
mod comparison;
|
||||
mod contiguous;
|
||||
mod cross;
|
||||
mod index;
|
||||
mod mask;
|
||||
mod unary_float;
|
||||
mod unary_int;
|
||||
mod unary_numeric;
|
||||
|
||||
pub(crate) use binary::*;
|
||||
pub(crate) use binary_float::*;
|
||||
pub(crate) use binary_int::*;
|
||||
pub use cast::*;
|
||||
pub use contiguous::*;
|
||||
pub(crate) use cross::*;
|
||||
pub use mask::*;
|
||||
pub(crate) use unary_float::*;
|
||||
pub(crate) use unary_int::*;
|
||||
pub(crate) use unary_numeric::*;
|
||||
|
||||
pub use crate::cubecl::prelude::KernelMetadata;
|
||||
|
||||
/// Attention kernels
|
||||
pub mod attention;
|
||||
/// Convolution kernels
|
||||
pub mod conv;
|
||||
/// Grid sampling kernels
|
||||
pub mod grid_sample;
|
||||
/// Interpolation kernels
|
||||
pub mod interpolate;
|
||||
/// Matmul kernels
|
||||
pub mod matmul;
|
||||
/// Pooling kernels
|
||||
pub mod pool;
|
||||
/// Pseudo-random number generator kernels
|
||||
pub mod prng;
|
||||
/// Quantization operations
|
||||
pub mod quantization;
|
||||
/// Reduction algorithms
|
||||
pub mod reduce;
|
||||
|
||||
pub(crate) use clamp::*;
|
||||
pub(crate) use comparison::*;
|
||||
pub use index::*;
|
||||
|
||||
pub(crate) mod utils;
|
||||
@@ -0,0 +1,117 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
into_contiguous_aligned,
|
||||
pool::pool2d::{Position, view4d},
|
||||
utils::{address_type, decompose_linear, shape_divmod},
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::Shape;
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::View},
|
||||
};
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn adaptive_avg_pool2d_direct<E: Numeric>(
|
||||
input: &Tensor<Line<E>>,
|
||||
output: &mut View<Line<E>, Position, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
working_units: usize,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= working_units {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
|
||||
let [b, oh, ow, c] = *pos else { unreachable!() };
|
||||
|
||||
let (_, out_h, out_w, _) = output.shape();
|
||||
let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));
|
||||
let (in_h, in_w) = (input.shape(1), input.shape(2));
|
||||
|
||||
let ih_start = start_index(oh, out_h, in_h);
|
||||
let ih_end = end_index(oh, out_h, in_h);
|
||||
|
||||
let iw_start = start_index(ow, out_w, in_w);
|
||||
let iw_end = end_index(ow, out_w, in_w);
|
||||
|
||||
let mut sum = Line::empty(input.line_size()).fill(E::from_int(0));
|
||||
|
||||
let index_input_base = b * input.stride(0) + c * input.stride(3);
|
||||
|
||||
for ih in ih_start..ih_end {
|
||||
let index_input_2 = ih * in_stride_h;
|
||||
|
||||
for iw in iw_start..iw_end {
|
||||
let index_input_3 = iw * in_stride_w;
|
||||
|
||||
let index_input = index_input_base + index_input_2 + index_input_3;
|
||||
sum += input[index_input / input.line_size()];
|
||||
}
|
||||
}
|
||||
|
||||
let num_ih = ih_end - ih_start;
|
||||
let num_iw = iw_end - iw_start;
|
||||
|
||||
output[(b, oh, ow, c)] = sum / Line::cast_from(num_ih * num_iw);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
(output_size_index * input_size) / output_size
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
let index = (output_size_index + 1) * input_size;
|
||||
let index = index.div_ceil(output_size);
|
||||
|
||||
if input_size < index {
|
||||
input_size
|
||||
} else {
|
||||
index
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d<R: CubeRuntime>(
|
||||
input: CubeTensor<R>,
|
||||
output_size: [usize; 2],
|
||||
) -> CubeTensor<R> {
|
||||
let [batch_size, channels, _, _] = input.meta.shape().dims();
|
||||
|
||||
let input = into_contiguous_aligned(permute_nchw_to_nhwc(input));
|
||||
let line_size = max_line_size(&input);
|
||||
|
||||
let output_shape = Shape::new([batch_size, output_size[0], output_size[1], channels]);
|
||||
let num_elems: usize = output_shape.num_elements();
|
||||
let output = empty_device_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
output_shape,
|
||||
input.dtype,
|
||||
);
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&input.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
|
||||
|
||||
adaptive_avg_pool2d_direct::launch(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(input, output),
|
||||
input.as_tensor_arg(line_size),
|
||||
view4d(&output, line_size),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
into_contiguous_aligned,
|
||||
pool::pool2d::{Position, view4d},
|
||||
utils::{address_type, decompose_linear, shape_divmod},
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::Shape;
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::View},
|
||||
};
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
fn adaptive_avg_pool2d_backward_direct<E: Numeric>(
|
||||
grad: &Tensor<Line<E>>,
|
||||
output: &mut View<Line<E>, Position, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
working_units: usize,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= working_units {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let (_, out_h, out_w, _) = output.shape();
|
||||
let (grad_stride_h, grad_stride_w) = (grad.stride(1), grad.stride(2));
|
||||
let (grad_h, grad_w) = (grad.shape(1), grad.shape(2));
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
|
||||
let [b, ih, iw, c] = *pos else { unreachable!() };
|
||||
|
||||
let oh_start = start_index(ih, out_h, grad_h);
|
||||
let oh_end = end_index(ih, out_h, grad_h);
|
||||
|
||||
let ow_start = start_index(iw, out_w, grad_w);
|
||||
let ow_end = end_index(iw, out_w, grad_w);
|
||||
|
||||
let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
|
||||
|
||||
let index_base = b * grad.stride(0) + (c * grad.stride(3));
|
||||
|
||||
for oh in oh_start..oh_end {
|
||||
let ih_start = start_index(oh, grad_h, out_h);
|
||||
let ih_end = end_index(oh, grad_h, out_h);
|
||||
|
||||
if ih >= ih_start && ih < ih_end {
|
||||
for ow in ow_start..ow_end {
|
||||
let iw_start = start_index(ow, grad_w, out_w);
|
||||
let iw_end = end_index(ow, grad_w, out_w);
|
||||
|
||||
if iw >= iw_start && iw < iw_end {
|
||||
let num_ih = ih_end - ih_start;
|
||||
let num_iw = iw_end - iw_start;
|
||||
|
||||
let index = index_base + (oh * grad_stride_h) + (ow * grad_stride_w);
|
||||
grad_acc += grad[index / grad.line_size()] / Line::cast_from(num_iw * num_ih);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[(b, ih, iw, c)] = grad_acc;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
(output_size_index * input_size) / output_size
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
let index = (output_size_index + 1) * input_size;
|
||||
let index = index.div_ceil(output_size);
|
||||
|
||||
if input_size < index {
|
||||
input_size
|
||||
} else {
|
||||
index
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
out_grad: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
let [batches, channels, height, width] = x.meta.shape().dims();
|
||||
|
||||
let out_grad = into_contiguous_aligned(permute_nchw_to_nhwc(out_grad));
|
||||
let line_size = max_line_size(&out_grad);
|
||||
|
||||
let out_shape = Shape::new([batches, height, width, channels]);
|
||||
let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);
|
||||
|
||||
let num_elems = output.meta.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&x.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
|
||||
|
||||
adaptive_avg_pool2d_backward_direct::launch(
|
||||
&x.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(out_grad, output),
|
||||
out_grad.as_tensor_arg(line_size),
|
||||
view4d(&output, line_size),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
use super::pool2d::{
|
||||
Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,
|
||||
};
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
into_contiguous_aligned,
|
||||
pool::pool2d::{Position, view4d},
|
||||
utils::{address_type, shape_divmod},
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{Shape, ops::conv::calculate_pool_output_size};
|
||||
use cubecl::{CubeDim, calculate_cube_count_elemwise, prelude::ScalarArg};
|
||||
use cubecl::{prelude::*, std::tensor::View};
|
||||
|
||||
struct AvgPoolStrategy;
|
||||
|
||||
impl Pool2dDirectStrategyFamily for AvgPoolStrategy {
|
||||
type Indices = ();
|
||||
type Config = AvgPoolStrategyConfig;
|
||||
type Pool2d<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[derive(CubeType, Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub struct AvgPoolStrategyConfig {
|
||||
count_include_pad: bool,
|
||||
/// Total padded height (input_height + 2 * padding_0)
|
||||
padded_h: u32,
|
||||
/// Total padded width (input_width + 2 * padding_1)
|
||||
padded_w: u32,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> Pool2dDirectStrategy<N> for AvgPoolStrategy {
|
||||
type Accumulator = (Line<N>, u32);
|
||||
type Config = AvgPoolStrategyConfig;
|
||||
type Indices = ();
|
||||
|
||||
fn initialize(
|
||||
#[comptime] _config: &Self::Config,
|
||||
#[comptime] line_size: LineSize,
|
||||
) -> Self::Accumulator {
|
||||
let sum = Line::empty(line_size).fill(N::from_int(0));
|
||||
// Count will be set dynamically: either by accumulate (count_include_pad=false)
|
||||
// or by set_padded_count (count_include_pad=true)
|
||||
let count = 0u32;
|
||||
|
||||
(sum, count)
|
||||
}
|
||||
|
||||
fn accumulate(
|
||||
#[comptime] config: &Self::Config,
|
||||
accumulator: &mut Self::Accumulator,
|
||||
_index: usize,
|
||||
result: Line<N>,
|
||||
) {
|
||||
let (sum, count) = accumulator;
|
||||
|
||||
// Only count valid positions when count_include_pad=false
|
||||
if comptime![!config.count_include_pad] {
|
||||
*count += 1;
|
||||
}
|
||||
|
||||
*sum += result;
|
||||
}
|
||||
|
||||
fn count_position(
|
||||
#[comptime] config: &Self::Config,
|
||||
accumulator: &mut Self::Accumulator,
|
||||
ih: u32,
|
||||
iw: u32,
|
||||
) {
|
||||
// When count_include_pad=true, count positions within padded bounds
|
||||
// (excludes ceil_mode extensions beyond the padded input)
|
||||
if comptime![config.count_include_pad] && ih < config.padded_h && iw < config.padded_w {
|
||||
let (_sum, count) = accumulator;
|
||||
*count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn store(
|
||||
#[comptime] _config: &Self::Config,
|
||||
position: Position,
|
||||
output: &mut View<Line<N>, Position, ReadWrite>,
|
||||
_output_indices: &mut (),
|
||||
accumulator: Self::Accumulator,
|
||||
) {
|
||||
let (sum, count) = accumulator;
|
||||
output[position] = sum / Line::cast_from(count);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let [batch_size, channels, in_h, in_w] = x.meta.shape().dims();
|
||||
let dilation = 1;
|
||||
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation,
|
||||
in_h,
|
||||
ceil_mode,
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation,
|
||||
in_w,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
// Padded dimensions (for count_include_pad with ceil_mode)
|
||||
let padded_0 = in_h + 2 * padding[0];
|
||||
let padded_1 = in_w + 2 * padding[1];
|
||||
|
||||
let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
|
||||
let line_size = max_line_size(&x);
|
||||
|
||||
let shape_out = Shape::new([batch_size, size_0, size_1, channels]);
|
||||
let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&x.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
|
||||
|
||||
pool2d_direct::launch::<AvgPoolStrategy, R>(
|
||||
&x.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(x, output),
|
||||
x.as_tensor_arg(line_size),
|
||||
view4d(&output, line_size),
|
||||
(),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
Pool2dDirectArgsLaunch::new(
|
||||
ScalarArg::new(stride[0] as u32),
|
||||
ScalarArg::new(stride[1] as u32),
|
||||
ScalarArg::new(dilation as u32),
|
||||
ScalarArg::new(dilation as u32),
|
||||
ScalarArg::new(padding[0] as u32),
|
||||
ScalarArg::new(padding[1] as u32),
|
||||
),
|
||||
(kernel_size[0] as u32, kernel_size[1] as u32),
|
||||
AvgPoolStrategyConfig {
|
||||
count_include_pad,
|
||||
padded_h: padded_0 as u32,
|
||||
padded_w: padded_1 as u32,
|
||||
},
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
pool::pool2d::{Position, view4d},
|
||||
utils::{address_type, decompose_linear, shape_divmod},
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::Shape;
|
||||
use cubecl::{
|
||||
calculate_cube_count_elemwise,
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::View},
|
||||
};
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
pub(crate) struct PoolBackwardArgs {
|
||||
pub stride_0: i32,
|
||||
pub stride_1: i32,
|
||||
pub dilation_0: i32,
|
||||
pub dilation_1: i32,
|
||||
pub padding_0: i32,
|
||||
pub padding_1: i32,
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn avg_pool2d_backward_kernel<E: Numeric>(
|
||||
grad: &Tensor<Line<E>>,
|
||||
output: &mut View<Line<E>, Position, ReadWrite>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
working_units: usize,
|
||||
args: &PoolBackwardArgs,
|
||||
#[comptime] kernel_size_0: i32,
|
||||
#[comptime] kernel_size_1: i32,
|
||||
#[comptime] count_include_pad: bool,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= working_units {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let line_size = grad.line_size();
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
|
||||
let [batch, ih, iw, channel] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
|
||||
|
||||
let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
|
||||
ih as i32,
|
||||
iw as i32,
|
||||
grad.shape(1) as u32,
|
||||
grad.shape(2) as u32,
|
||||
args,
|
||||
kernel_size_0,
|
||||
kernel_size_1,
|
||||
);
|
||||
|
||||
let padding_0 = args.padding_0 as u32;
|
||||
let padding_1 = args.padding_1 as u32;
|
||||
let stride_0 = args.stride_0 as u32;
|
||||
let stride_1 = args.stride_1 as u32;
|
||||
let kernel_size_0 = comptime![kernel_size_0 as u32];
|
||||
let kernel_size_1 = comptime![kernel_size_1 as u32];
|
||||
|
||||
let index_base = batch * grad.stride(0) + channel * grad.stride(3);
|
||||
let border_bottom = output.shape().1 as u32 + padding_0;
|
||||
let border_right = output.shape().2 as u32 + padding_1;
|
||||
let begin_h = ih as u32 + padding_0;
|
||||
let begin_w = iw as u32 + padding_1;
|
||||
|
||||
for oh in oh_start..oh_end {
|
||||
let ih_start = oh * stride_0;
|
||||
let ih_end = clamp_max(ih_start + kernel_size_0, border_bottom);
|
||||
let ih_start = clamp_min(ih_start, padding_0);
|
||||
|
||||
if begin_h >= ih_start && (ih as u32) < ih_end {
|
||||
for ow in ow_start..ow_end {
|
||||
let index =
|
||||
index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);
|
||||
|
||||
let iw_start = ow * stride_1;
|
||||
let iw_end = clamp_max(iw_start + kernel_size_1, border_right);
|
||||
let iw_start = clamp_min(iw_start, padding_1);
|
||||
|
||||
if begin_w >= iw_start && (iw as u32) < iw_end {
|
||||
if count_include_pad {
|
||||
grad_acc += grad[index / line_size]
|
||||
/ Line::cast_from(kernel_size_0 * kernel_size_1);
|
||||
} else {
|
||||
let ih_diff = ih_end - ih_start;
|
||||
let iw_diff = iw_end - iw_start;
|
||||
let count = Line::cast_from(ih_diff * iw_diff);
|
||||
grad_acc += grad[index / line_size] / count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[(batch, ih, iw, channel)] = grad_acc;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn loop_ranges(
|
||||
ih: i32,
|
||||
iw: i32,
|
||||
grad_h: u32,
|
||||
grad_w: u32,
|
||||
args: &PoolBackwardArgs,
|
||||
#[comptime] kernel_size_0: i32,
|
||||
#[comptime] kernel_size_1: i32,
|
||||
) -> (u32, u32, u32, u32) {
|
||||
let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0;
|
||||
let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1;
|
||||
|
||||
let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32;
|
||||
let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32;
|
||||
let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1;
|
||||
let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1;
|
||||
|
||||
(oh_start, oh_end, ow_start, ow_end)
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d_backward<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
grad: CubeTensor<R>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
_ceil_mode: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let [batches, channels, height, width] = x.meta.shape().dims();
|
||||
|
||||
let grad = permute_nchw_to_nhwc(grad);
|
||||
|
||||
let line_size = if x.meta.strides()[3] == grad.meta.strides()[3] {
|
||||
max_line_size(&x)
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
let dilation = 1;
|
||||
|
||||
let out_shape = Shape::new([batches, height, width, channels]);
|
||||
let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&x.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
avg_pool2d_backward_kernel::launch_unchecked(
|
||||
&grad.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(grad, output),
|
||||
grad.as_tensor_arg(line_size),
|
||||
view4d(&output, line_size),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
PoolBackwardArgsLaunch::new(
|
||||
ScalarArg::new(stride[0] as i32),
|
||||
ScalarArg::new(stride[1] as i32),
|
||||
ScalarArg::new(dilation),
|
||||
ScalarArg::new(dilation),
|
||||
ScalarArg::new(padding[0] as i32),
|
||||
ScalarArg::new(padding[1] as i32),
|
||||
),
|
||||
kernel_size[0] as i32,
|
||||
kernel_size[1] as i32,
|
||||
count_include_pad,
|
||||
output.dtype.into(),
|
||||
)
|
||||
}
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
use super::pool2d::{
|
||||
Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,
|
||||
};
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
into_contiguous_aligned,
|
||||
pool::pool2d::{Position, view4d},
|
||||
utils::{address_type, shape_divmod},
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{DType, Shape, ops::conv::calculate_pool_output_size};
|
||||
use cubecl::{CubeDim, calculate_cube_count_elemwise, prelude::*, std::tensor::View};
|
||||
|
||||
struct MaxPoolStrategy;
|
||||
struct MaxPoolWithIndicesStrategy;
|
||||
|
||||
impl Pool2dDirectStrategyFamily for MaxPoolStrategy {
|
||||
type Indices = ();
|
||||
type Config = ();
|
||||
type Pool2d<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
impl Pool2dDirectStrategyFamily for MaxPoolWithIndicesStrategy {
|
||||
type Indices = View<Line<i32>, Position, ReadWrite>;
|
||||
type Config = ();
|
||||
type Pool2d<N: Numeric> = Self;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> Pool2dDirectStrategy<N> for MaxPoolStrategy {
|
||||
type Accumulator = Line<N>;
|
||||
type Config = ();
|
||||
type Indices = ();
|
||||
|
||||
fn initialize(
|
||||
#[comptime] _config: &Self::Config,
|
||||
#[comptime] line_size: LineSize,
|
||||
) -> Self::Accumulator {
|
||||
Line::empty(line_size).fill(N::min_value())
|
||||
}
|
||||
|
||||
fn accumulate(
|
||||
#[comptime] _config: &Self::Config,
|
||||
accumulator: &mut Self::Accumulator,
|
||||
_index: LineSize,
|
||||
result: Line<N>,
|
||||
) {
|
||||
*accumulator = max(*accumulator, result);
|
||||
}
|
||||
|
||||
fn count_position(
|
||||
#[comptime] _config: &Self::Config,
|
||||
_accumulator: &mut Self::Accumulator,
|
||||
_ih: u32,
|
||||
_iw: u32,
|
||||
) {
|
||||
}
|
||||
|
||||
fn store(
|
||||
#[comptime] _config: &Self::Config,
|
||||
position: Position,
|
||||
output: &mut View<Line<N>, Position, ReadWrite>,
|
||||
_output_indices: &mut (),
|
||||
accumulator: Self::Accumulator,
|
||||
) {
|
||||
output[position] = accumulator;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> Pool2dDirectStrategy<N> for MaxPoolWithIndicesStrategy {
|
||||
type Accumulator = (Line<N>, Line<i32>);
|
||||
type Config = ();
|
||||
type Indices = View<Line<i32>, Position, ReadWrite>;
|
||||
|
||||
fn initialize(
|
||||
#[comptime] _config: &Self::Config,
|
||||
#[comptime] line_size: LineSize,
|
||||
) -> Self::Accumulator {
|
||||
let val = Line::empty(line_size).fill(N::min_value());
|
||||
let idx = Line::empty(line_size).fill(0i32);
|
||||
(val, idx)
|
||||
}
|
||||
|
||||
fn accumulate(
|
||||
#[comptime] _config: &Self::Config,
|
||||
accumulator: &mut Self::Accumulator,
|
||||
index: usize,
|
||||
result: Line<N>,
|
||||
) {
|
||||
let indices = Line::cast_from(index);
|
||||
accumulator.1 = select_many(result.greater_than(accumulator.0), indices, accumulator.1);
|
||||
accumulator.0 = max(result, accumulator.0);
|
||||
}
|
||||
|
||||
fn count_position(
|
||||
#[comptime] _config: &Self::Config,
|
||||
_accumulator: &mut Self::Accumulator,
|
||||
_ih: u32,
|
||||
_iw: u32,
|
||||
) {
|
||||
}
|
||||
|
||||
fn store(
|
||||
#[comptime] _config: &Self::Config,
|
||||
position: Position,
|
||||
output: &mut View<Line<N>, Position, ReadWrite>,
|
||||
output_indices: &mut View<Line<i32>, Position, ReadWrite>,
|
||||
accumulator: Self::Accumulator,
|
||||
) {
|
||||
output[position] = accumulator.0;
|
||||
output_indices[position] = accumulator.1;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let [batch_size, channels, height, width] = x.meta.shape().dims();
|
||||
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
height,
|
||||
ceil_mode,
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
width,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
|
||||
|
||||
let line_size = max_line_size(&x);
|
||||
|
||||
let shape_out = Shape::new([batch_size, size_0, size_1, channels]);
|
||||
let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&x.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
|
||||
|
||||
pool2d_direct::launch::<MaxPoolStrategy, R>(
|
||||
&x.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(x, output),
|
||||
x.as_tensor_arg(line_size),
|
||||
view4d(&output, line_size),
|
||||
(),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
Pool2dDirectArgsLaunch::new(
|
||||
ScalarArg::new(stride[0] as u32),
|
||||
ScalarArg::new(stride[1] as u32),
|
||||
ScalarArg::new(dilation[0] as u32),
|
||||
ScalarArg::new(dilation[1] as u32),
|
||||
ScalarArg::new(padding[0] as u32),
|
||||
ScalarArg::new(padding[1] as u32),
|
||||
),
|
||||
(kernel_size[0] as u32, kernel_size[1] as u32),
|
||||
(),
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_with_indices<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
dtype_indices: DType,
|
||||
) -> (CubeTensor<R>, CubeTensor<R>) {
|
||||
let [batch_size, channels, size_0, size_1] = x.meta.shape().dims();
|
||||
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
size_0,
|
||||
ceil_mode,
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
size_1,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
|
||||
let line_size = max_line_size(&x);
|
||||
|
||||
let shape_out = Shape::new([batch_size, size_0, size_1, channels]);
|
||||
let output = empty_device_dtype(
|
||||
x.client.clone(),
|
||||
x.device.clone(),
|
||||
shape_out.clone(),
|
||||
x.dtype,
|
||||
);
|
||||
let indices = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, dtype_indices);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&x.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
|
||||
|
||||
pool2d_direct::launch::<MaxPoolWithIndicesStrategy, R>(
|
||||
&x.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(x, output, indices),
|
||||
x.as_tensor_arg(line_size),
|
||||
view4d(&output, line_size),
|
||||
view4d(&indices, line_size),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
Pool2dDirectArgsLaunch::new(
|
||||
ScalarArg::new(stride[0] as u32),
|
||||
ScalarArg::new(stride[1] as u32),
|
||||
ScalarArg::new(dilation[0] as u32),
|
||||
ScalarArg::new(dilation[1] as u32),
|
||||
ScalarArg::new(padding[0] as u32),
|
||||
ScalarArg::new(padding[1] as u32),
|
||||
),
|
||||
(kernel_size[0] as u32, kernel_size[1] as u32),
|
||||
(),
|
||||
output.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
let output = permute_nhwc_to_nchw(output);
|
||||
let indices = permute_nhwc_to_nchw(indices);
|
||||
(output, indices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::{
|
||||
into_contiguous_aligned,
|
||||
utils::{address_type, decompose_linear, shape_divmod},
|
||||
},
|
||||
ops::{max_line_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::Shape;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::FastDivmod};
|
||||
|
||||
use super::{PoolBackwardArgs, PoolBackwardArgsLaunch};
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn max_pool2d_with_indices_backward_kernel<E: Numeric, I: Int>(
|
||||
grad: &Tensor<Line<E>>,
|
||||
indices: &Tensor<Line<I>>,
|
||||
output: &mut Tensor<Line<E>>,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
working_units: usize,
|
||||
args: &PoolBackwardArgs,
|
||||
#[comptime] kernel_size_0: i32,
|
||||
#[comptime] kernel_size_1: i32,
|
||||
#[define(E, I)] _dtypes: [StorageType; 2],
|
||||
) {
|
||||
if ABSOLUTE_POS >= working_units {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
|
||||
let [batch, ih, iw, channel] = *pos else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let line_size = grad.line_size();
|
||||
|
||||
let index_current = ih * output.shape(2) + iw;
|
||||
|
||||
let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
|
||||
ih as i32,
|
||||
iw as i32,
|
||||
grad.shape(1) as u32,
|
||||
grad.shape(2) as u32,
|
||||
args,
|
||||
kernel_size_0,
|
||||
kernel_size_1,
|
||||
);
|
||||
|
||||
let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
|
||||
|
||||
let grad_idx_base = batch * grad.stride(0) + channel * grad.stride(3);
|
||||
let ind_idx_base = batch * indices.stride(0) + channel * indices.stride(3);
|
||||
|
||||
for oh in oh_start..oh_end {
|
||||
for ow in ow_start..ow_end {
|
||||
let grad_index =
|
||||
grad_idx_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);
|
||||
let indices_index =
|
||||
ind_idx_base + oh as usize * indices.stride(1) + ow as usize * indices.stride(2);
|
||||
let index_max = Line::<u32>::cast_from(indices[indices_index / line_size]);
|
||||
|
||||
grad_acc += select_many(
|
||||
index_max.equal(Line::cast_from(index_current)),
|
||||
grad[grad_index / line_size],
|
||||
Line::new(E::from_int(0)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let index_output = batch * output.stride(0)
|
||||
+ ih * output.stride(1)
|
||||
+ iw * output.stride(2)
|
||||
+ channel * output.stride(3);
|
||||
|
||||
output[index_output / output.line_size()] = grad_acc;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn loop_ranges(
|
||||
ih: i32,
|
||||
iw: i32,
|
||||
grad_h: u32,
|
||||
grad_w: u32,
|
||||
args: &PoolBackwardArgs,
|
||||
#[comptime] kernel_size_0: i32,
|
||||
#[comptime] kernel_size_1: i32,
|
||||
) -> (u32, u32, u32, u32) {
|
||||
let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0;
|
||||
let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1;
|
||||
|
||||
let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32;
|
||||
let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32;
|
||||
let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1;
|
||||
let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1;
|
||||
|
||||
(oh_start, oh_end, ow_start, ow_end)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn max_pool2d_with_indices_backward<R: CubeRuntime>(
|
||||
x: CubeTensor<R>,
|
||||
grad: CubeTensor<R>,
|
||||
indices: CubeTensor<R>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
_ceil_mode: bool,
|
||||
) -> CubeTensor<R> {
|
||||
let [batches, channels, height, width] = x.meta.shape().dims();
|
||||
|
||||
let grad = into_contiguous_aligned(permute_nchw_to_nhwc(grad));
|
||||
let indices = into_contiguous_aligned(permute_nchw_to_nhwc(indices));
|
||||
|
||||
let line_size = if grad.meta.strides()[3] == indices.meta.strides()[3] {
|
||||
max_line_size(&grad)
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
let out_shape = Shape::new([batches, height, width, channels]);
|
||||
let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);
|
||||
|
||||
let working_units = output.meta.num_elements() / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&x.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
max_pool2d_with_indices_backward_kernel::launch_unchecked(
|
||||
&x.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(grad, indices, output),
|
||||
grad.as_tensor_arg(line_size),
|
||||
indices.as_tensor_arg(line_size),
|
||||
output.as_tensor_arg(line_size),
|
||||
shape_divmod(&output),
|
||||
ScalarArg::new(working_units),
|
||||
PoolBackwardArgsLaunch::new(
|
||||
ScalarArg::new(stride[0] as i32),
|
||||
ScalarArg::new(stride[1] as i32),
|
||||
ScalarArg::new(dilation[0] as i32),
|
||||
ScalarArg::new(dilation[1] as i32),
|
||||
ScalarArg::new(padding[0] as i32),
|
||||
ScalarArg::new(padding[1] as i32),
|
||||
),
|
||||
kernel_size[0] as i32,
|
||||
kernel_size[1] as i32,
|
||||
[x.dtype.into(), indices.dtype.into()],
|
||||
)
|
||||
.expect("Kernel to never fail")
|
||||
};
|
||||
|
||||
permute_nhwc_to_nchw(output)
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
mod adaptive_avg_pool2d;
|
||||
mod adaptive_avg_pool2d_backward;
|
||||
mod avg_pool2d;
|
||||
mod avg_pool2d_backward;
|
||||
mod max_pool2d;
|
||||
mod max_pool2d_backward;
|
||||
|
||||
pub(super) mod pool2d;
|
||||
|
||||
pub(crate) use adaptive_avg_pool2d::*;
|
||||
pub(crate) use adaptive_avg_pool2d_backward::*;
|
||||
pub(crate) use avg_pool2d::*;
|
||||
pub(crate) use avg_pool2d_backward::*;
|
||||
pub(crate) use max_pool2d::*;
|
||||
pub(crate) use max_pool2d_backward::*;
|
||||
@@ -0,0 +1,155 @@
|
||||
use core::hash::Hash;
|
||||
use cubecl::{
|
||||
prelude::*,
|
||||
std::{
|
||||
FastDivmod,
|
||||
tensor::{
|
||||
View,
|
||||
launch::ViewArg,
|
||||
layout::fixed_dim::{FixedDimLayout, FixedDimLayoutLaunch},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{CubeRuntime, kernel::utils::decompose_linear, tensor::CubeTensor};
|
||||
|
||||
pub trait Pool2dDirectStrategyFamily: Send + Sync + 'static {
|
||||
type Indices: LaunchArg;
|
||||
type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;
|
||||
type Pool2d<N: Numeric>: Pool2dDirectStrategy<N, Config = Self::Config, Indices = Self::Indices>;
|
||||
}
|
||||
|
||||
pub(super) type Position = (usize, usize, usize, usize);
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait Pool2dDirectStrategy<N: Numeric>: Send + Sync + 'static {
|
||||
type Accumulator: CubeType;
|
||||
type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;
|
||||
|
||||
type Indices: LaunchArg;
|
||||
|
||||
fn initialize(
|
||||
#[comptime] config: &Self::Config,
|
||||
#[comptime] line_size: LineSize,
|
||||
) -> Self::Accumulator;
|
||||
|
||||
fn accumulate(
|
||||
#[comptime] config: &Self::Config,
|
||||
accumulator: &mut Self::Accumulator,
|
||||
index: usize,
|
||||
result: Line<N>,
|
||||
);
|
||||
|
||||
/// Count a position within the kernel window (for avg_pool count_include_pad).
|
||||
/// Called for each position in the kernel window with the current ih/iw coordinates.
|
||||
/// Only avg_pool uses this; max_pool implements as no-op.
|
||||
fn count_position(
|
||||
#[comptime] config: &Self::Config,
|
||||
accumulator: &mut Self::Accumulator,
|
||||
ih: u32,
|
||||
iw: u32,
|
||||
);
|
||||
|
||||
fn store(
|
||||
#[comptime] config: &Self::Config,
|
||||
position: Position,
|
||||
output: &mut View<Line<N>, Position, ReadWrite>,
|
||||
output_indices: &mut Self::Indices,
|
||||
accumulator: Self::Accumulator,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
pub struct Pool2dDirectArgs {
|
||||
pub strides_0: u32,
|
||||
pub strides_1: u32,
|
||||
pub dilation_0: u32,
|
||||
pub dilation_1: u32,
|
||||
pub padding_0: u32,
|
||||
pub padding_1: u32,
|
||||
}
|
||||
|
||||
#[cube(launch, address_type = "dynamic")]
|
||||
pub fn pool2d_direct<E: Numeric, S: Pool2dDirectStrategyFamily>(
|
||||
input: &Tensor<Line<E>>,
|
||||
output: &mut View<Line<E>, Position, ReadWrite>,
|
||||
indices: &mut S::Indices,
|
||||
out_shape: Sequence<FastDivmod<usize>>,
|
||||
working_units: usize,
|
||||
args: &Pool2dDirectArgs,
|
||||
#[comptime] kernel_size: (u32, u32),
|
||||
#[comptime] config: &S::Config,
|
||||
#[define(E)] _dtype: StorageType,
|
||||
) {
|
||||
if ABSOLUTE_POS >= working_units {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let (_, pos) = decompose_linear(ABSOLUTE_POS * output.line_size(), &out_shape);
|
||||
let [b, oh, ow, c] = *pos else { unreachable!() };
|
||||
|
||||
let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));
|
||||
let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32);
|
||||
|
||||
let mut accumulator = S::Pool2d::<E>::initialize(config, input.line_size());
|
||||
|
||||
let in_b_off = b * input.stride(0);
|
||||
let in_c_off = c * input.stride(3);
|
||||
|
||||
let border_bottom = in_h + args.padding_0;
|
||||
let border_right = in_w + args.padding_1;
|
||||
|
||||
for kh in 0..kernel_size.0 {
|
||||
let ih = oh as u32 * args.strides_0 + kh * args.dilation_0;
|
||||
let within_padding_h = ih >= args.padding_0 && ih < border_bottom;
|
||||
|
||||
for kw in 0..kernel_size.1 {
|
||||
let iw = ow as u32 * args.strides_1 + kw * args.dilation_1;
|
||||
let within_padding_w = iw >= args.padding_1 && iw < border_right;
|
||||
|
||||
// Let strategy handle position counting (only used by avg_pool)
|
||||
S::Pool2d::<E>::count_position(config, &mut accumulator, ih, iw);
|
||||
|
||||
// Only accumulate values from valid input positions
|
||||
if within_padding_h && within_padding_w {
|
||||
let ih_pad = ih - args.padding_0;
|
||||
let iw_pad = iw - args.padding_1;
|
||||
|
||||
let in_h_off = ih_pad as usize * in_stride_h;
|
||||
let in_w_off = iw_pad as usize * in_stride_w;
|
||||
|
||||
let index_input = in_b_off + in_c_off + in_h_off + in_w_off;
|
||||
|
||||
S::Pool2d::<E>::accumulate(
|
||||
config,
|
||||
&mut accumulator,
|
||||
ih_pad as usize * in_w as usize + iw_pad as usize,
|
||||
input[index_input / input.line_size()],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
S::Pool2d::<E>::store(config, (b, oh, ow, c), output, indices, accumulator);
|
||||
}
|
||||
|
||||
pub(super) fn view4d<R: CubeRuntime>(
|
||||
tensor: &CubeTensor<R>,
|
||||
line_size: LineSize,
|
||||
) -> ViewArg<'_, Position, R> {
|
||||
let shape = tensor.meta.shape();
|
||||
let shape = (
|
||||
ScalarArg::new(shape[0]),
|
||||
ScalarArg::new(shape[1]),
|
||||
ScalarArg::new(shape[2]),
|
||||
ScalarArg::new(shape[3]),
|
||||
);
|
||||
let handle = tensor.as_handle_ref();
|
||||
let len = handle.shape.iter().product::<usize>();
|
||||
let layout =
|
||||
FixedDimLayoutLaunch::<Position, R>::from_shape_handle_unchecked(&handle, shape, line_size);
|
||||
let buffer = unsafe {
|
||||
ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
|
||||
};
|
||||
ViewArg::new::<FixedDimLayout<Position>>(buffer, layout)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
use burn_backend::{DType, Shape};
|
||||
|
||||
/// Pseudo-random generator with bernoulli distribution
|
||||
pub fn random_bernoulli<R: CubeRuntime>(
|
||||
shape: Shape,
|
||||
device: &R::Device,
|
||||
probability: f32,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let client = R::client(device);
|
||||
let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
|
||||
|
||||
cubek::random::random_bernoulli(&client, probability, output.as_handle_ref(), dtype.into())
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod bernoulli;
|
||||
mod normal;
|
||||
mod uniform;
|
||||
|
||||
pub use bernoulli::*;
|
||||
pub use normal::*;
|
||||
pub use uniform::*;
|
||||
@@ -0,0 +1,20 @@
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
use burn_backend::{DType, Shape};
|
||||
|
||||
/// Pseudo-random generator with uniform distribution
|
||||
pub fn random_normal<R: CubeRuntime>(
|
||||
shape: Shape,
|
||||
device: &R::Device,
|
||||
mean: f32,
|
||||
std: f32,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let client = R::client(device);
|
||||
let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
|
||||
let output_handle = output.as_handle_ref();
|
||||
|
||||
cubek::random::random_normal(&client, mean, std, output_handle, dtype.into())
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
use burn_backend::{DType, Shape, TensorMetadata};
|
||||
|
||||
/// Pseudo-random generator with uniform distribution
|
||||
pub fn random_uniform<R: CubeRuntime>(
|
||||
shape: Shape,
|
||||
device: &R::Device,
|
||||
lower_bound: f32,
|
||||
upper_bound: f32,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let client = R::client(device);
|
||||
let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
|
||||
let output_handle = output.as_handle_ref();
|
||||
|
||||
cubek::random::random_uniform(
|
||||
&client,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
output_handle,
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Pseudo-random generator for uniform distribution, based on
|
||||
/// another tensor.
|
||||
pub fn random_like_uniform<R: CubeRuntime>(
|
||||
tensor: &CubeTensor<R>,
|
||||
lower_bound: f32,
|
||||
upper_bound: f32,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
random_uniform(
|
||||
tensor.shape(),
|
||||
&tensor.device,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
use crate::tensor::CubeTensor;
|
||||
use crate::{CubeRuntime, ops::numeric::empty_device_dtype};
|
||||
use burn_backend::{DType, TensorMetadata};
|
||||
|
||||
/// Convert the tensor back to a higher precision data type.
|
||||
pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
{
|
||||
let scheme = match tensor.dtype {
|
||||
DType::QFloat(scheme) => scheme,
|
||||
_ => return tensor,
|
||||
};
|
||||
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
dtype,
|
||||
);
|
||||
let (values, params) = tensor.quantized_handles().unwrap();
|
||||
|
||||
cubek::quantization::dequantize::launch_ref(
|
||||
&values.client,
|
||||
&values.as_handle_ref(),
|
||||
&output.as_handle_ref(),
|
||||
¶ms.as_handle_ref(),
|
||||
&scheme,
|
||||
dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod dequantize;
|
||||
mod quantize;
|
||||
|
||||
pub use dequantize::*;
|
||||
pub use quantize::*;
|
||||
@@ -0,0 +1,29 @@
|
||||
use crate::CubeRuntime;
|
||||
use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
|
||||
use burn_backend::{TensorMetadata, quantization::QuantScheme};
|
||||
|
||||
/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
|
||||
pub fn quantize<R>(
|
||||
tensor: CubeTensor<R>,
|
||||
scheme: &QuantScheme,
|
||||
scale: CubeTensor<R>,
|
||||
) -> CubeTensor<R>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
{
|
||||
let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device);
|
||||
let (out_values, out_params) = output.clone().quantized_handles().unwrap();
|
||||
|
||||
cubek::quantization::quantize::launch_ref(
|
||||
&tensor.client,
|
||||
&tensor.as_handle_ref(),
|
||||
&out_values.as_handle_ref(),
|
||||
&scale.as_handle_ref(),
|
||||
&out_params.as_handle_ref(),
|
||||
scheme,
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::{autotune_reduce, autotune_sum};
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
ops::numeric::{empty_device_contiguous_dtype, zeros_client},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::{DType, TensorMetadata};
|
||||
use burn_std::Metadata;
|
||||
use cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType};
|
||||
use cubek::reduce::{
|
||||
ReduceDtypes, ReduceError, ReduceStrategy,
|
||||
components::instructions::ReduceOperationConfig,
|
||||
launch::{LineSizeStrategy, RoutineStrategy},
|
||||
routines::{BlueprintStrategy, unit::UnitStrategy},
|
||||
shared_sum,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
/// Autotune key representative of sum versions
|
||||
pub struct SumAutotuneKey {
|
||||
/// The type of the tensor
|
||||
dtype: burn_backend::DType,
|
||||
/// The anchored length of the tensor
|
||||
#[autotune(anchor)]
|
||||
length: usize,
|
||||
}
|
||||
|
||||
/// Check if the client supports atomic add for the given element type.
|
||||
fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
|
||||
client
|
||||
.properties()
|
||||
.type_usage(StorageType::Atomic(dtype.into()))
|
||||
.contains(TypeUsage::AtomicAdd)
|
||||
}
|
||||
|
||||
/// [Sum](sum) with fallback when `client` doesn't support atomic add for the type `E`.
|
||||
pub fn sum_fallback<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
mut strategy: SumStrategy,
|
||||
) -> Result<CubeTensor<R>, ReduceError> {
|
||||
// Early check before creating output and fallback
|
||||
if matches!(strategy, SumStrategy::OneShot(_))
|
||||
&& !supports_atomic_add(&tensor.client, tensor.dtype)
|
||||
{
|
||||
strategy = SumStrategy::Chained(Default::default());
|
||||
}
|
||||
sum(tensor, strategy)
|
||||
}
|
||||
|
||||
/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return
|
||||
/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`.
|
||||
///
|
||||
/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction.
|
||||
///
|
||||
/// Return an error if the `client` doesn't support atomic add for the type `E`.
|
||||
pub fn sum<Run: CubeRuntime>(
|
||||
tensor: CubeTensor<Run>,
|
||||
strategy: SumStrategy,
|
||||
) -> Result<CubeTensor<Run>, ReduceError> {
|
||||
let client = tensor.client.clone();
|
||||
let device = tensor.device.clone();
|
||||
|
||||
match strategy {
|
||||
SumStrategy::OneShot(cube_count) => {
|
||||
let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);
|
||||
shared_sum::<Run>(
|
||||
&client,
|
||||
tensor.as_handle_ref(),
|
||||
output.as_handle_ref(),
|
||||
cube_count,
|
||||
tensor.dtype.into(),
|
||||
)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
SumStrategy::Chained(strategy) => {
|
||||
reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select a strategy to perform a sum.
|
||||
pub enum SumStrategy {
|
||||
/// Run a single kernel with many cubes working in parallel to sum all elements.
|
||||
/// The provided value is the number of elements summed per unit (up-to-rounding )
|
||||
OneShot(u32),
|
||||
/// Use multiple kernels
|
||||
Chained(KernelReduceStrategy),
|
||||
/// Use autotune to find the best cube count given the hardware and the input.
|
||||
#[cfg(feature = "autotune")]
|
||||
Autotune,
|
||||
}
|
||||
|
||||
impl Default for SumStrategy {
|
||||
fn default() -> Self {
|
||||
#[cfg(feature = "autotune")]
|
||||
return Self::Autotune;
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
return Self::OneShot(4);
|
||||
}
|
||||
}
|
||||
|
||||
/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
|
||||
///
|
||||
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
|
||||
///
|
||||
/// If there is no error, the output is a tensor with decreasing strides
|
||||
/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
|
||||
pub fn reduce<Run: CubeRuntime>(
|
||||
mut tensor: CubeTensor<Run>,
|
||||
output_dtype: Option<DType>,
|
||||
strategy: KernelReduceStrategy,
|
||||
config: ReduceOperationConfig,
|
||||
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
|
||||
// In practice, it looks like starting by the axis with the smallest shape
|
||||
// and going in increasing order lead to the fastest calculation.
|
||||
let sorted_axis = argsort(tensor.meta.shape());
|
||||
for axis in sorted_axis {
|
||||
tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;
|
||||
}
|
||||
// reshape to scalar tensor
|
||||
*tensor.meta = Metadata::new([1], [1]);
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
fn argsort(shape: &[usize]) -> Vec<usize> {
|
||||
let mut indices = (0..shape.len()).collect::<Vec<_>>();
|
||||
indices.sort_by_key(|&i| &shape[i]);
|
||||
indices
|
||||
}
|
||||
|
||||
/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
|
||||
///
|
||||
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
|
||||
/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
|
||||
///
|
||||
/// If there is no error, the output is a tensor with decreasing strides
|
||||
/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
|
||||
pub fn reduce_dim<Run: CubeRuntime>(
|
||||
input: CubeTensor<Run>,
|
||||
output_dtype: Option<DType>,
|
||||
dim: usize,
|
||||
strategy: KernelReduceStrategy,
|
||||
config: ReduceOperationConfig,
|
||||
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
|
||||
debug_assert!(
|
||||
!matches!(
|
||||
config,
|
||||
ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin
|
||||
) || output_dtype.is_some(),
|
||||
"The `output_dtype` has to be `Some` only when the `config` is `ArgMax` or `ArgMin`.
|
||||
"
|
||||
);
|
||||
|
||||
let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
|
||||
let client = input.client.clone();
|
||||
let output = init_reduce_output::<Run>(&input, dim, &dtypes).ok_or(
|
||||
cubek::reduce::ReduceError::InvalidAxis {
|
||||
axis: dim,
|
||||
rank: input.meta.num_dims(),
|
||||
},
|
||||
)?;
|
||||
|
||||
let result = match strategy {
|
||||
KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
|
||||
&client,
|
||||
input.as_handle_ref(),
|
||||
output.as_handle_ref(),
|
||||
dim,
|
||||
ReduceStrategy {
|
||||
routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
|
||||
line_size: LineSizeStrategy {
|
||||
parallel_output_vectorization: false,
|
||||
},
|
||||
},
|
||||
config,
|
||||
dtypes,
|
||||
),
|
||||
KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
|
||||
&client,
|
||||
input.as_handle_ref(),
|
||||
output.as_handle_ref(),
|
||||
dim,
|
||||
strategy,
|
||||
config,
|
||||
dtypes,
|
||||
),
|
||||
#[cfg(feature = "autotune")]
|
||||
KernelReduceStrategy::Autotune => {
|
||||
autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
result.map(|_| output)
|
||||
}
|
||||
|
||||
/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input`
|
||||
/// or return `None` if `axis` is out-of-bound.
|
||||
pub fn init_reduce_output<Run: CubeRuntime>(
|
||||
input: &CubeTensor<Run>,
|
||||
dim: usize,
|
||||
dtypes: &ReduceDtypes,
|
||||
) -> Option<CubeTensor<Run>> {
|
||||
(dim < input.meta.num_dims()).then(|| {
|
||||
let mut shape_out = input.shape();
|
||||
shape_out[dim] = 1;
|
||||
empty_device_contiguous_dtype(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape_out,
|
||||
dtypes.output.elem_type().into(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Select a strategy to perform a reduction.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum KernelReduceStrategy {
|
||||
/// Use a best-effort strategy based on the hardware capacity.
|
||||
/// This differs from Autotune as it doesn't try and compare many strategies to select the best.
|
||||
Unspecified,
|
||||
/// Fix the exact strategy for the reduction.
|
||||
Specific(cubek::reduce::launch::ReduceStrategy),
|
||||
/// Use autotune to find the best strategy given the hardware and the inputs.
|
||||
#[cfg(feature = "autotune")]
|
||||
Autotune,
|
||||
}
|
||||
|
||||
impl Default for KernelReduceStrategy {
|
||||
fn default() -> Self {
|
||||
#[cfg(feature = "autotune")]
|
||||
return Self::Autotune;
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
return Self::Unspecified;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod base;
|
||||
#[cfg(feature = "autotune")]
|
||||
mod tune;
|
||||
|
||||
pub use base::*;
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use tune::*;
|
||||
@@ -0,0 +1,286 @@
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use super::SumAutotuneKey;
|
||||
use crate::{CubeAutotuneKey, CubeRuntime, CubeTuneId, tensor::CubeTensor};
|
||||
use cubecl::{
|
||||
client::ComputeClient,
|
||||
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
|
||||
};
|
||||
use cubek::reduce::{
|
||||
ReduceDtypes, ReduceStrategy,
|
||||
components::instructions::ReduceOperationConfig,
|
||||
launch::{LineSizeStrategy, RoutineStrategy, tune_key::ReduceAutotuneKey},
|
||||
routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy},
|
||||
};
|
||||
|
||||
/// Executes autotune on reduce operations.
|
||||
pub fn autotune_reduce<R: CubeRuntime>(
|
||||
client: &ComputeClient<R>,
|
||||
input: CubeTensor<R>,
|
||||
output: CubeTensor<R>,
|
||||
axis: usize,
|
||||
config: ReduceOperationConfig,
|
||||
dtypes: ReduceDtypes,
|
||||
) {
|
||||
use reduce_ops::*;
|
||||
|
||||
static TUNER: LocalTuner<ReduceAutotuneKey, CubeTuneId> = local_tuner!("reduce-dim");
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
const PRIORITY_MAX: i8 = 2;
|
||||
const PRIORITY_MIN: i8 = 1;
|
||||
const PRIORITY_SKIP: i8 = -1;
|
||||
|
||||
let mut set = TunableSet::new(create_key::<R>, reduce_input_gen::<R>);
|
||||
|
||||
let default_group =
|
||||
TuneGroup::<ReduceAutotuneKey>::new("default_reduce", |_key| PRIORITY_MAX);
|
||||
let vectorized_parallel_group =
|
||||
TuneGroup::<ReduceAutotuneKey>::new("vectorized_parallel_reduce", |key| {
|
||||
if key.axis_is_contiguous {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
// We disable the tunable with the setting [line_size.parallel_output_vectorization]
|
||||
// when the reduce isn't parallel, since it would duplicate tunables.
|
||||
PRIORITY_SKIP
|
||||
}
|
||||
});
|
||||
|
||||
enum ReduceProps {
|
||||
GreatWithLowReduceCount,
|
||||
GreatWithHighReduceCount,
|
||||
Balanced,
|
||||
}
|
||||
|
||||
for (line_size, line_size_ident) in [
|
||||
(
|
||||
LineSizeStrategy {
|
||||
parallel_output_vectorization: true,
|
||||
},
|
||||
"_vectorized_parallel_reduce",
|
||||
),
|
||||
(
|
||||
LineSizeStrategy {
|
||||
parallel_output_vectorization: false,
|
||||
},
|
||||
"",
|
||||
),
|
||||
] {
|
||||
for (name, routine, props) in [
|
||||
(
|
||||
"unit",
|
||||
RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
|
||||
ReduceProps::GreatWithHighReduceCount,
|
||||
),
|
||||
(
|
||||
"plane",
|
||||
RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy {
|
||||
independent: true,
|
||||
})),
|
||||
ReduceProps::Balanced,
|
||||
),
|
||||
(
|
||||
"cube",
|
||||
RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy {
|
||||
use_planes: true,
|
||||
})),
|
||||
ReduceProps::GreatWithLowReduceCount,
|
||||
),
|
||||
] {
|
||||
let name = format!("{name}{line_size_ident}");
|
||||
let mut tunable = Tunable::new(
|
||||
name,
|
||||
move |(input, output, axis, config, dtypes): (
|
||||
CubeTensor<R>,
|
||||
CubeTensor<R>,
|
||||
usize,
|
||||
ReduceOperationConfig,
|
||||
ReduceDtypes,
|
||||
)| {
|
||||
let strategy = ReduceStrategy {
|
||||
routine: routine.clone(),
|
||||
line_size,
|
||||
};
|
||||
cubek::reduce::reduce::<R>(
|
||||
&input.client,
|
||||
input.as_handle_ref(),
|
||||
output.as_handle_ref(),
|
||||
axis,
|
||||
strategy,
|
||||
config,
|
||||
dtypes,
|
||||
)
|
||||
.map_err(|e| format!("{e}"))
|
||||
},
|
||||
);
|
||||
if line_size.parallel_output_vectorization {
|
||||
tunable = tunable.group(&vectorized_parallel_group, |_| PRIORITY_MAX);
|
||||
}
|
||||
|
||||
tunable = tunable.group(&default_group, move |key| match props {
|
||||
ReduceProps::GreatWithLowReduceCount => {
|
||||
if key.vector_count < 128 {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
// When you have a high level of vector to reduce, it is normally
|
||||
// better to use another routine.
|
||||
PRIORITY_MIN
|
||||
}
|
||||
}
|
||||
ReduceProps::GreatWithHighReduceCount => {
|
||||
if key.vector_count > 64 {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
// Bellow 64 it is normally better to use another routine
|
||||
PRIORITY_MIN
|
||||
}
|
||||
}
|
||||
ReduceProps::Balanced => PRIORITY_MAX,
|
||||
});
|
||||
set = set.with(tunable);
|
||||
}
|
||||
}
|
||||
|
||||
set
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&input.client, &input.device),
|
||||
client,
|
||||
tunables,
|
||||
(input, output, axis, config, dtypes),
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn create_key<Run: CubeRuntime>(
|
||||
input: &CubeTensor<Run>,
|
||||
output: &CubeTensor<Run>,
|
||||
axis: &usize,
|
||||
_config: &ReduceOperationConfig,
|
||||
dtypes: &ReduceDtypes,
|
||||
) -> ReduceAutotuneKey {
|
||||
let elem_input = input.dtype.into();
|
||||
let elem_output = output.dtype.into();
|
||||
let elem_acc = dtypes.accumulation.elem_type();
|
||||
|
||||
ReduceAutotuneKey::generate(
|
||||
elem_input,
|
||||
elem_output,
|
||||
elem_acc,
|
||||
input.meta.shape(),
|
||||
input.meta.strides()[*axis] == 1,
|
||||
*axis,
|
||||
)
|
||||
}
|
||||
|
||||
mod reduce_ops {
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use cubek::reduce::ReduceDtypes;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(crate) fn reduce_input_gen<Run: CubeRuntime>(
|
||||
_key: &ReduceAutotuneKey,
|
||||
input: &CubeTensor<Run>,
|
||||
output: &CubeTensor<Run>,
|
||||
dim: &usize,
|
||||
config: &ReduceOperationConfig,
|
||||
dtypes: &ReduceDtypes,
|
||||
) -> (
|
||||
CubeTensor<Run>,
|
||||
CubeTensor<Run>,
|
||||
usize,
|
||||
ReduceOperationConfig,
|
||||
ReduceDtypes,
|
||||
) {
|
||||
(input.clone(), output.copy(), *dim, *config, *dtypes)
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes autotune on reduce operations.
|
||||
#[cfg(feature = "autotune")]
|
||||
pub fn autotune_sum<R: CubeRuntime>(
|
||||
client: &ComputeClient<R>,
|
||||
input: CubeTensor<R>,
|
||||
) -> CubeTensor<R> {
|
||||
use sum_ops::*;
|
||||
|
||||
static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!("autotune-sum");
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
TunableSet::new(create_key_sum::<R>, sum_input_gen::<R>)
|
||||
.with(Tunable::new("sum_chained", sum_chained::<R>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 1>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 2>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 4>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 8>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 16>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 32>))
|
||||
.with(Tunable::new("sum_one_shot", sum_one_shot::<R, 64>))
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&input.client, &input.device),
|
||||
client,
|
||||
tunables,
|
||||
input,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn create_key_sum<Run: CubeRuntime>(input: &CubeTensor<Run>) -> CubeAutotuneKey {
|
||||
CubeAutotuneKey::Sum(SumAutotuneKey::generate(input))
|
||||
}
|
||||
|
||||
impl SumAutotuneKey {
|
||||
#[allow(unused)]
|
||||
pub(crate) fn generate<Run: CubeRuntime>(input: &CubeTensor<Run>) -> Self {
|
||||
let dtype = input.dtype;
|
||||
let length = input.meta.num_elements();
|
||||
Self::new(dtype, length)
|
||||
}
|
||||
}
|
||||
mod sum_ops {
|
||||
#![allow(missing_docs)]
|
||||
use crate::ops::numeric::zeros_client;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(crate) fn sum_input_gen<Run: CubeRuntime>(
|
||||
_key: &CubeAutotuneKey,
|
||||
input: &CubeTensor<Run>,
|
||||
) -> CubeTensor<Run> {
|
||||
input.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn sum_one_shot<Run: CubeRuntime, const C: u32>(
|
||||
input: CubeTensor<Run>,
|
||||
) -> Result<CubeTensor<Run>, String> {
|
||||
let client = input.client.clone();
|
||||
let device = input.device.clone();
|
||||
let output = zeros_client(client.clone(), device, [1].into(), input.dtype);
|
||||
|
||||
cubek::reduce::shared_sum::<Run>(
|
||||
&input.client,
|
||||
input.as_handle_ref(),
|
||||
output.as_handle_ref(),
|
||||
C,
|
||||
input.dtype.into(),
|
||||
)
|
||||
.map_err(|e| e.to_string())
|
||||
.map(|_| output)
|
||||
}
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) fn sum_chained<Run: CubeRuntime>(
|
||||
input: CubeTensor<Run>,
|
||||
) -> Result<CubeTensor<Run>, String> {
|
||||
crate::kernel::reduce::reduce::<Run>(
|
||||
input,
|
||||
None,
|
||||
crate::kernel::reduce::KernelReduceStrategy::Autotune,
|
||||
cubek::reduce::components::instructions::ReduceOperationConfig::Sum,
|
||||
)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, linear_view_alias},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
type Unary<F: Float>: FloatUnaryOp<F, Options = Self::Options>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait FloatUnaryOp<F: Float>: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
|
||||
fn execute(input: Line<F>, options: &Self::Options) -> Line<F>;
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn unary_float<F: Float, O: FloatUnaryOpFamily>(
|
||||
input: &LinearView<Line<F>>,
|
||||
output: &mut LinearView<Line<F>, ReadWrite>,
|
||||
options: &O::Options,
|
||||
#[define(F)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = O::Unary::<F>::execute(input[ABSOLUTE_POS], options);
|
||||
}
|
||||
|
||||
pub(crate) fn launch_unary_float<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
|
||||
where
|
||||
// Magic fix for lifetime, the closure is supposed to capture everything required to create the
|
||||
// argument.
|
||||
for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>,
|
||||
R: CubeRuntime,
|
||||
O: FloatUnaryOpFamily,
|
||||
{
|
||||
let line_size = max_line_size(&tensor);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
unary_float::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view_alias(&tensor, line_size, 0),
|
||||
args(&()),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
tensor
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
unary_float::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view(&output, line_size),
|
||||
args(&()),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Use comptime enum to implement all unary operations that don't have any input argument in the
|
||||
/// kernel definition.
|
||||
pub(crate) mod unary_basic {
|
||||
use super::*;
|
||||
|
||||
pub(crate) fn launch<R, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
for<'a> Args: FnOnce(&'a ()) -> BasicFloatUnaryKind,
|
||||
{
|
||||
launch_unary_float::<R, BasicFloatUnary, _>(tensor, |input| {
|
||||
BasicFloatUnaryOptionsLaunch::new(args(input))
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
pub enum BasicFloatUnaryKind {
|
||||
Exp,
|
||||
Log,
|
||||
Log1p,
|
||||
Sqrt,
|
||||
Abs,
|
||||
Sign,
|
||||
ArcCos,
|
||||
ArcCosh,
|
||||
ArcSin,
|
||||
ArcSinh,
|
||||
ArcTan,
|
||||
ArcTanh,
|
||||
Cos,
|
||||
Cosh,
|
||||
Sin,
|
||||
Sinh,
|
||||
Tan,
|
||||
Tanh,
|
||||
Round,
|
||||
Floor,
|
||||
Ceil,
|
||||
Trunc,
|
||||
Erf,
|
||||
Recip,
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct BasicFloatUnaryOptions {
|
||||
#[cube(comptime)]
|
||||
kind: BasicFloatUnaryKind,
|
||||
}
|
||||
struct BasicFloatUnary;
|
||||
|
||||
#[cube]
|
||||
impl<F: Float> FloatUnaryOp<F> for BasicFloatUnary {
|
||||
type Options = BasicFloatUnaryOptions;
|
||||
|
||||
fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
|
||||
match comptime![options.kind] {
|
||||
BasicFloatUnaryKind::Exp => Line::exp(input),
|
||||
BasicFloatUnaryKind::Log => Line::ln(input),
|
||||
BasicFloatUnaryKind::Log1p => Line::log1p(input),
|
||||
BasicFloatUnaryKind::Sqrt => Line::sqrt(input),
|
||||
BasicFloatUnaryKind::Abs => Line::abs(input),
|
||||
BasicFloatUnaryKind::Sign => {
|
||||
let zero = Line::new(F::new(0.0));
|
||||
let one = Line::new(F::new(1.0));
|
||||
let minus_one = Line::new(F::new(-1.0));
|
||||
|
||||
let is_positive = input.greater_than(zero);
|
||||
let is_negative = input.less_than(zero);
|
||||
let sign = select_many(is_negative, minus_one, zero);
|
||||
|
||||
select_many(is_positive, one, sign)
|
||||
}
|
||||
BasicFloatUnaryKind::Cos => Line::cos(input),
|
||||
BasicFloatUnaryKind::Sin => Line::sin(input),
|
||||
BasicFloatUnaryKind::Tan => Line::tan(input),
|
||||
BasicFloatUnaryKind::Cosh => Line::cosh(input),
|
||||
BasicFloatUnaryKind::Sinh => Line::sinh(input),
|
||||
BasicFloatUnaryKind::Tanh => Line::tanh(input),
|
||||
BasicFloatUnaryKind::Round => Line::round(input),
|
||||
BasicFloatUnaryKind::Floor => Line::floor(input),
|
||||
BasicFloatUnaryKind::Ceil => Line::ceil(input),
|
||||
BasicFloatUnaryKind::Trunc => Line::trunc(input),
|
||||
BasicFloatUnaryKind::Erf => Line::erf(input),
|
||||
BasicFloatUnaryKind::Recip => Line::recip(input),
|
||||
BasicFloatUnaryKind::ArcCos => Line::acos(input),
|
||||
BasicFloatUnaryKind::ArcCosh => Line::acosh(input),
|
||||
BasicFloatUnaryKind::ArcSin => Line::asin(input),
|
||||
BasicFloatUnaryKind::ArcSinh => Line::asinh(input),
|
||||
BasicFloatUnaryKind::ArcTan => Line::atan(input),
|
||||
BasicFloatUnaryKind::ArcTanh => Line::atanh(input),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FloatUnaryOpFamily for BasicFloatUnary {
|
||||
type Options = BasicFloatUnaryOptions;
|
||||
type Unary<F: Float> = Self;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, linear_view_alias},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
type Unary<I: Int>: IntUnaryOp<I, Options = Self::Options>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait IntUnaryOp<I: CubePrimitive>: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
|
||||
fn execute(input: Line<I>, options: &Self::Options) -> Line<I>;
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn unary_int<I: Int, O: IntUnaryOpFamily>(
|
||||
input: &LinearView<Line<I>>,
|
||||
output: &mut LinearView<Line<I>, ReadWrite>,
|
||||
options: &O::Options,
|
||||
#[define(I)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = O::Unary::<I>::execute(input[ABSOLUTE_POS], options);
|
||||
}
|
||||
|
||||
pub(crate) fn launch_unary_int<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
|
||||
where
|
||||
for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>,
|
||||
R: CubeRuntime,
|
||||
O: IntUnaryOpFamily,
|
||||
{
|
||||
let line_size = max_line_size(&tensor);
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
unary_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view_alias(&tensor, line_size, 0),
|
||||
args(&()),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
tensor
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
unary_int::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view(&output, line_size),
|
||||
args(&()),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod unary_basic_int {
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(crate) fn launch<R, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
for<'a> Args: FnOnce(&'a ()) -> BasicIntUnaryKind,
|
||||
{
|
||||
launch_unary_int::<R, BasicIntUnary, _>(tensor, |input| {
|
||||
BasicIntUnaryOptionsLaunch::new(args(input))
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
pub enum BasicIntUnaryKind {
|
||||
BitwiseNot,
|
||||
Sign,
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
struct BasicIntUnaryOptions {
|
||||
#[cube(comptime)]
|
||||
kind: BasicIntUnaryKind,
|
||||
}
|
||||
struct BasicIntUnary;
|
||||
|
||||
#[cube]
|
||||
impl<I: Int> IntUnaryOp<I> for BasicIntUnary {
|
||||
type Options = BasicIntUnaryOptions;
|
||||
|
||||
fn execute(input: Line<I>, options: &Self::Options) -> Line<I> {
|
||||
match comptime![options.kind] {
|
||||
BasicIntUnaryKind::BitwiseNot => !input,
|
||||
BasicIntUnaryKind::Sign => {
|
||||
let zero = Line::new(I::new(0));
|
||||
let one = Line::new(I::new(1));
|
||||
let minus_one = Line::new(I::new(-1));
|
||||
|
||||
let is_positive = input.greater_than(zero);
|
||||
let is_negative = input.less_than(zero);
|
||||
let sign = select_many(is_negative, minus_one, zero);
|
||||
|
||||
select_many(is_positive, one, sign)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntUnaryOpFamily for BasicIntUnary {
|
||||
type Options = BasicIntUnaryOptions;
|
||||
type Unary<I: Int> = Self;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
use crate::{
|
||||
CubeRuntime,
|
||||
kernel::utils::{address_type, linear_view, linear_view_alias},
|
||||
ops::{max_line_size, numeric::empty_device_dtype},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_backend::TensorMetadata;
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
|
||||
|
||||
pub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
type Unary<N: Numeric>: NumericUnaryOp<N, Options = Self::Options>;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait NumericUnaryOp<N: CubePrimitive>: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
|
||||
fn execute(input: Line<N>, options: &Self::Options) -> Line<N>;
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub(crate) fn unary_numeric<N: Numeric, O: NumericUnaryOpFamily>(
|
||||
input: &LinearView<Line<N>>,
|
||||
output: &mut LinearView<Line<N>, ReadWrite>,
|
||||
options: &O::Options,
|
||||
#[define(N)] _dtype: StorageType,
|
||||
) {
|
||||
if !output.is_in_bounds(ABSOLUTE_POS) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
output[ABSOLUTE_POS] = O::Unary::<N>::execute(input[ABSOLUTE_POS], options);
|
||||
}
|
||||
|
||||
pub(crate) fn launch_unary_numeric<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>
|
||||
where
|
||||
// Magic fix for lifetime, the closure is supposed to capture everything required to create the
|
||||
// argument.
|
||||
for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>,
|
||||
R: CubeRuntime,
|
||||
O: NumericUnaryOpFamily,
|
||||
{
|
||||
let line_size = max_line_size(&tensor);
|
||||
let client = tensor.client.clone();
|
||||
let num_elems = tensor.meta.num_elements();
|
||||
|
||||
let working_units = num_elems / line_size as usize;
|
||||
let cube_dim = CubeDim::new(&tensor.client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
|
||||
|
||||
unsafe {
|
||||
if tensor.can_mut() && tensor.is_nonoverlapping() {
|
||||
unary_numeric::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view_alias(&tensor, line_size, 0),
|
||||
args(&()),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
tensor
|
||||
} else {
|
||||
let output = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape(),
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
unary_numeric::launch_unchecked::<O, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type!(tensor, output),
|
||||
linear_view(&tensor, line_size),
|
||||
linear_view(&output, line_size),
|
||||
args(&()),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
use burn_backend::Shape;
|
||||
use cubecl::{
|
||||
ir::LineSize,
|
||||
prelude::*,
|
||||
std::{
|
||||
FastDivmod, FastDivmodArgs, FastDivmodInt,
|
||||
tensor::layout::linear::{LinearLayoutArgs, LinearViewLaunch},
|
||||
},
|
||||
};
|
||||
use cubecl::{prelude::SequenceArg, std::tensor::layout::linear::LinearLayout};
|
||||
|
||||
use crate::{CubeRuntime, tensor::CubeTensor};
|
||||
|
||||
pub fn shape_divmod<'a, R: CubeRuntime>(
|
||||
tensor: &CubeTensor<R>,
|
||||
) -> SequenceArg<'a, R, FastDivmod<usize>> {
|
||||
let mut arg = SequenceArg::new();
|
||||
for dim in tensor.meta.shape().iter() {
|
||||
arg.push(FastDivmodArgs::<usize>::new(&tensor.client, *dim));
|
||||
}
|
||||
arg
|
||||
}
|
||||
|
||||
pub fn linear_layout<'a, R: CubeRuntime>(
|
||||
tensor: &'a CubeTensor<R>,
|
||||
line_size: LineSize,
|
||||
) -> LinearLayoutArgs<'a, R> {
|
||||
LinearLayoutArgs::from_shape_strides(
|
||||
&tensor.client,
|
||||
tensor.meta.shape(),
|
||||
tensor.meta.strides(),
|
||||
line_size,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn linear_layout_ref<'a, R: CubeRuntime>(
|
||||
tensor: &'a CubeTensor<R>,
|
||||
reference: &'a CubeTensor<R>,
|
||||
line_size: LineSize,
|
||||
) -> LinearLayoutArgs<'a, R> {
|
||||
LinearLayoutArgs::from_shape_strides_with_reference(
|
||||
&tensor.client,
|
||||
tensor.meta.shape(),
|
||||
reference.meta.shape(),
|
||||
tensor.meta.strides(),
|
||||
line_size,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn linear_view<'a, R: CubeRuntime>(
|
||||
tensor: &'a CubeTensor<R>,
|
||||
line_size: LineSize,
|
||||
) -> LinearViewLaunch<'a, R> {
|
||||
let len = tensor.meta.num_elements();
|
||||
let layout = linear_layout(tensor, line_size);
|
||||
let buffer = unsafe {
|
||||
ArrayArg::from_raw_parts_and_size(&tensor.handle, len, line_size, tensor.elem_size())
|
||||
};
|
||||
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
|
||||
}
|
||||
|
||||
pub fn linear_view_ref<'a, R: CubeRuntime>(
|
||||
tensor: &'a CubeTensor<R>,
|
||||
reference: &'a CubeTensor<R>,
|
||||
line_size: LineSize,
|
||||
) -> LinearViewLaunch<'a, R> {
|
||||
let len = tensor.meta.num_elements();
|
||||
let layout = linear_layout_ref(tensor, reference, line_size);
|
||||
let buffer = unsafe {
|
||||
ArrayArg::from_raw_parts_and_size(&tensor.handle, len, line_size, tensor.elem_size())
|
||||
};
|
||||
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
|
||||
}
|
||||
|
||||
pub fn linear_view_alias<'a, R: CubeRuntime>(
|
||||
tensor: &'a CubeTensor<R>,
|
||||
line_size: LineSize,
|
||||
pos: usize,
|
||||
) -> LinearViewLaunch<'a, R> {
|
||||
let layout = linear_layout(tensor, line_size);
|
||||
let buffer = ArrayArg::Alias { input_pos: pos };
|
||||
LinearViewLaunch::new::<LinearLayout>(buffer, layout)
|
||||
}
|
||||
|
||||
pub fn split_dim<R: CubeRuntime>(
|
||||
mut tensor: CubeTensor<R>,
|
||||
dim: usize,
|
||||
shape: &[usize],
|
||||
) -> CubeTensor<R> {
|
||||
let mut stride = tensor.meta.strides()[dim];
|
||||
tensor.meta.remove(dim);
|
||||
|
||||
for size in shape.iter().rev() {
|
||||
tensor.meta.insert(dim, *size, stride);
|
||||
stride *= size;
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
pub fn broadcast_shape<R: CubeRuntime>(tensors: &[&CubeTensor<R>]) -> Shape {
|
||||
let rank = tensors[0].meta.num_dims();
|
||||
debug_assert!(
|
||||
tensors.iter().all(|it| it.meta.num_dims() == rank),
|
||||
"Broadcast tensors must have the same rank"
|
||||
);
|
||||
|
||||
let dims = (0..rank).map(|dim| {
|
||||
let max = tensors.iter().map(|it| it.meta.shape()[dim]).max();
|
||||
let max = max.unwrap_or(1);
|
||||
debug_assert!(
|
||||
tensors
|
||||
.iter()
|
||||
.all(|it| it.meta.shape()[dim] == max || it.meta.shape()[dim] == 1),
|
||||
"Broadcast dims must be size 1"
|
||||
);
|
||||
max
|
||||
});
|
||||
|
||||
Shape::from(dims)
|
||||
}
|
||||
|
||||
pub fn broadcast_strides<'a, R: CubeRuntime>(
|
||||
reference: &CubeTensor<R>,
|
||||
tensor: &'a CubeTensor<R>,
|
||||
) -> SequenceArg<'a, R, usize> {
|
||||
if reference.meta.shape() != tensor.meta.shape() {
|
||||
tensor
|
||||
.meta
|
||||
.strides()
|
||||
.iter()
|
||||
.zip(
|
||||
tensor
|
||||
.meta
|
||||
.shape()
|
||||
.iter()
|
||||
.zip(reference.meta.shape().iter()),
|
||||
)
|
||||
.map(|(stride, (shape, ref_shape))| if *shape == *ref_shape { *stride } else { 0 })
|
||||
.map(ScalarArg::new)
|
||||
.collect()
|
||||
} else {
|
||||
tensor
|
||||
.meta
|
||||
.strides()
|
||||
.iter()
|
||||
.copied()
|
||||
.map(ScalarArg::new)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) fn decompose_linear<I: FastDivmodInt>(
|
||||
pos: I,
|
||||
shape: &Sequence<FastDivmod<I>>,
|
||||
) -> (I, Sequence<I>) {
|
||||
let rank = comptime![shape.len()];
|
||||
let mut offs = pos;
|
||||
let mut out = Sequence::new();
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
let dim = comptime![rank - i - 1];
|
||||
let (rem, offs_local) = shape.index(dim).div_mod(offs);
|
||||
out.push(offs_local);
|
||||
offs = rem;
|
||||
}
|
||||
|
||||
(offs, out.rev())
|
||||
}
|
||||
|
||||
pub(crate) trait RequiredAddrType {
|
||||
fn required_address_type(&self) -> AddressType;
|
||||
}
|
||||
|
||||
impl<R: CubeRuntime> RequiredAddrType for CubeTensor<R> {
|
||||
fn required_address_type(&self) -> AddressType {
|
||||
self.required_address_type()
|
||||
}
|
||||
}
|
||||
impl<R: CubeRuntime> RequiredAddrType for Option<CubeTensor<R>> {
|
||||
fn required_address_type(&self) -> AddressType {
|
||||
self.as_ref()
|
||||
.map(|it| it.required_address_type())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! address_type {
|
||||
($($tensor: tt),*) => {
|
||||
[$($crate::kernel::utils::RequiredAddrType::required_address_type(&$tensor)),*]
|
||||
.into_iter()
|
||||
.max()
|
||||
.unwrap_or_default()
|
||||
};
|
||||
}
|
||||
pub(crate) use address_type;
|
||||
@@ -0,0 +1,50 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! Burn JIT Backend
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
extern crate alloc;
|
||||
|
||||
/// Utilities for implementing JIT kernels
|
||||
pub mod ops;
|
||||
|
||||
/// Kernel module
|
||||
pub mod kernel;
|
||||
/// Tensor module.
|
||||
pub mod tensor;
|
||||
|
||||
/// Elements for JIT backend
|
||||
pub mod element;
|
||||
|
||||
use cubecl::{CubeTask, Runtime};
|
||||
pub use element::{BoolElement, CubeElement, FloatElement, IntElement};
|
||||
|
||||
mod backend;
|
||||
|
||||
pub use backend::*;
|
||||
|
||||
// Re-export cubecl.
|
||||
pub use cubecl;
|
||||
|
||||
mod tune_key;
|
||||
pub use tune_key::CubeAutotuneKey;
|
||||
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
/// Module for interacting with fusion
|
||||
pub mod fusion;
|
||||
|
||||
#[cfg(feature = "template")]
|
||||
/// Module for compiling custom non-jit kernels
|
||||
pub mod template;
|
||||
|
||||
/// Just-in-Time runtime extending the [cube runtime](Runtime).
|
||||
pub trait CubeRuntime: Runtime<Device = Self::CubeDevice, Server = Self::CubeServer> {
|
||||
/// The device that should also implement [burn_backend::backend::DeviceOps].
|
||||
type CubeDevice: burn_backend::DeviceOps;
|
||||
/// The cube server with the [CubeAutotuneKey].
|
||||
type CubeServer: cubecl::server::ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;
|
||||
}
|
||||
|
||||
pub use cubecl::CubeTuneId;
|
||||
@@ -0,0 +1,11 @@
|
||||
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
|
||||
use burn_backend::ops::ActivationOps;
|
||||
|
||||
impl<R, F, I, BT> ActivationOps<Self> for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
use crate::{CubeRuntime, kernel, ops::numeric::empty_device_dtype, tensor::CubeTensor};
|
||||
use burn_backend::{
|
||||
DType, ExecutionError, QTensorPrimitive, Shape, TensorData,
|
||||
quantization::{QuantLevel, QuantStore, params_shape},
|
||||
};
|
||||
use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
|
||||
use burn_std::{
|
||||
Metadata, strides,
|
||||
tensor::{ReshapeAction, contiguous_strides, reshape_action},
|
||||
};
|
||||
use cubecl::{ir::LineSize, server::CopyDescriptor};
|
||||
use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};
|
||||
|
||||
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
|
||||
let client = R::client(device);
|
||||
let alloc = client.create_tensor(data.bytes, &data.shape, data.dtype.size());
|
||||
let shape: Shape = (&data.shape).into();
|
||||
CubeTensor::new(
|
||||
client,
|
||||
alloc.handle,
|
||||
Metadata::new(shape, alloc.strides),
|
||||
device.clone(),
|
||||
data.dtype,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) async fn into_data<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
) -> Result<TensorData, ExecutionError> {
|
||||
let tensor = kernel::into_contiguous_aligned(tensor);
|
||||
|
||||
let elem_size = tensor.elem_size();
|
||||
let shape = tensor.meta.shape();
|
||||
let strides = tensor.meta.strides();
|
||||
let binding = CopyDescriptor::new(tensor.handle.binding(), shape, strides, elem_size);
|
||||
let bytes = tensor
|
||||
.client
|
||||
.read_one_tensor_async(binding)
|
||||
.await
|
||||
.map_err(|err| ExecutionError::WithContext {
|
||||
reason: format!("{err}"),
|
||||
})?;
|
||||
|
||||
Ok(TensorData::from_bytes(
|
||||
bytes,
|
||||
tensor.meta.shape,
|
||||
tensor.dtype,
|
||||
))
|
||||
}
|
||||
|
||||
/// Read data from a `CubeTensor` synchronously
|
||||
#[allow(unused, reason = "useful for debugging kernels")]
|
||||
pub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {
|
||||
burn_std::future::block_on(into_data(tensor)).unwrap()
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensor, device))
|
||||
)]
|
||||
pub(crate) fn to_device<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
device: &R::Device,
|
||||
) -> CubeTensor<R> {
|
||||
if &tensor.device == device {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
let tensor = kernel::into_contiguous_aligned(tensor);
|
||||
let client = R::client(device);
|
||||
tensor.to_client(client, device.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn empty<R: CubeRuntime>(
|
||||
shape: Shape,
|
||||
device: &R::Device,
|
||||
dtype: DType,
|
||||
) -> CubeTensor<R> {
|
||||
let client = R::client(device);
|
||||
let alloc = client.empty_tensor(&shape, dtype.size());
|
||||
|
||||
CubeTensor::new(
|
||||
client,
|
||||
alloc.handle,
|
||||
Metadata::new(shape, alloc.strides),
|
||||
device.clone(),
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn swap_dims<R: CubeRuntime>(
|
||||
mut tensor: CubeTensor<R>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> CubeTensor<R> {
|
||||
tensor.meta.swap(dim1, dim2);
|
||||
|
||||
if let DType::QFloat(scheme) = tensor.dtype
|
||||
&& let QuantLevel::Block(block_size) = scheme.level
|
||||
{
|
||||
let rank = tensor.rank();
|
||||
let qparams = tensor.qparams.as_mut().unwrap();
|
||||
let mut block_size = block_size.to_dim_vec(rank);
|
||||
block_size.swap(dim1, dim2);
|
||||
|
||||
// Truncate unit dims from the start
|
||||
let block_size = BlockSize::new_trim(block_size);
|
||||
if block_size.len() > BlockSize::MAX_DIMS {
|
||||
panic!("Swapped block size would exceed max dims");
|
||||
}
|
||||
|
||||
qparams.scales.metadata.swap(dim1, dim2);
|
||||
|
||||
tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))
|
||||
}
|
||||
|
||||
if let DType::QFloat(scheme) = &mut tensor.dtype
|
||||
&& let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
|
||||
&mut scheme.store
|
||||
{
|
||||
let rank = tensor.meta.num_dims();
|
||||
|
||||
if *packed_dim == rank - dim1 - 1 {
|
||||
*packed_dim = rank - dim2 - 1;
|
||||
} else if *packed_dim == rank - dim2 - 1 {
|
||||
*packed_dim = rank - dim1 - 1;
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Permute a tensor's dimensions
|
||||
pub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {
|
||||
tensor.meta.permute(axes).unwrap();
|
||||
|
||||
if let DType::QFloat(scheme) = tensor.dtype
|
||||
&& let QuantLevel::Block(block_size) = scheme.level
|
||||
{
|
||||
let rank = tensor.rank();
|
||||
let qparams = tensor.qparams.as_mut().unwrap();
|
||||
|
||||
let mut block_size = block_size.to_dim_vec(rank);
|
||||
block_size = axes.iter().map(|i| block_size[*i]).collect();
|
||||
|
||||
// Truncate unit dims from the start
|
||||
let block_size = block_size
|
||||
.into_iter()
|
||||
.skip_while(|it| *it == 1)
|
||||
.collect::<Vec<_>>();
|
||||
if block_size.len() > BlockSize::MAX_DIMS {
|
||||
panic!("Swapped block size would exceed max dims");
|
||||
}
|
||||
|
||||
qparams.scales.metadata.permute(axes).unwrap();
|
||||
|
||||
tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))
|
||||
}
|
||||
|
||||
if let DType::QFloat(scheme) = &mut tensor.dtype
|
||||
&& let QuantStore::PackedU32(packed_dim) = &mut scheme.store
|
||||
{
|
||||
let rank = tensor.meta.num_dims();
|
||||
let new_pos = axes
|
||||
.iter()
|
||||
.position(|axis| *axis == rank - *packed_dim - 1)
|
||||
.unwrap_or(0);
|
||||
*packed_dim = rank - new_pos - 1;
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Permute a tensor's dimensions from NCHW to NHWC, or the N-dimensional equivalent
|
||||
pub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
|
||||
let rank = tensor.meta.num_dims();
|
||||
let c_dim = 1;
|
||||
|
||||
let mut dims = vec![0];
|
||||
dims.extend(2..rank);
|
||||
dims.push(c_dim);
|
||||
|
||||
permute(tensor, &dims)
|
||||
}
|
||||
|
||||
/// Permute a shape's dimensions from NCHW to NHWC, or the N-dimensional equivalent
|
||||
pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {
|
||||
let rank = shape.num_dims();
|
||||
let c_dim = 1;
|
||||
|
||||
let mut dims = vec![0];
|
||||
dims.extend(2..rank);
|
||||
dims.push(c_dim);
|
||||
|
||||
shape.permuted(&dims).expect("Shape permute should succeed")
|
||||
}
|
||||
|
||||
/// Permute a tensor's dimensions from NHWC to NCHW, or the N-dimensional equivalent
|
||||
pub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
|
||||
let rank = tensor.meta.num_dims();
|
||||
let c_dim = rank - 1;
|
||||
|
||||
let mut dims = vec![0];
|
||||
dims.push(c_dim);
|
||||
dims.extend(1..c_dim);
|
||||
|
||||
permute(tensor, &dims)
|
||||
}
|
||||
|
||||
/// Permute a shape's dimensions from NHWC to NCHW, or the N-dimensional equivalent
|
||||
pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {
|
||||
let rank = shape.num_dims();
|
||||
let c_dim = rank - 1;
|
||||
|
||||
let mut dims = vec![0];
|
||||
dims.push(c_dim);
|
||||
dims.extend(1..c_dim);
|
||||
|
||||
shape.permuted(&dims).expect("Shape permute should succeed")
|
||||
}
|
||||
|
||||
pub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {
|
||||
let ndims_in = tensor.meta.shape().num_dims();
|
||||
let ndims_out = target_shape.num_dims();
|
||||
|
||||
// Initialize new strides with zeros
|
||||
let mut new_strides = strides![0usize; ndims_out];
|
||||
|
||||
// Calculate the difference in dimensions
|
||||
let dim_diff = ndims_out.saturating_sub(ndims_in);
|
||||
|
||||
// Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones
|
||||
let mut tensor_dim_iter = tensor.meta.shape().iter().rev();
|
||||
for i in (0..ndims_out).rev() {
|
||||
if i >= dim_diff {
|
||||
if let Some(&tensor_dim) = tensor_dim_iter.next() {
|
||||
if tensor_dim == target_shape[i] || tensor_dim == 1 {
|
||||
// Copy stride for non-broadcast dimensions or set to 0 for broadcast ones
|
||||
new_strides[i] = if tensor_dim == target_shape[i] {
|
||||
tensor.meta.strides()[i - dim_diff]
|
||||
} else {
|
||||
0
|
||||
};
|
||||
} else {
|
||||
// Error handling: Dimension mismatch for broadcasting
|
||||
panic!(
|
||||
"Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// If the input tensor has fewer dimensions, treat missing dimensions as 1
|
||||
// and set stride to 0 (broadcasting)
|
||||
new_strides[i] = 0;
|
||||
}
|
||||
} else {
|
||||
// For extra dimensions in the target shape, set stride to 0 (broadcasting)
|
||||
new_strides[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Extra check to ensure block scales must be properly handled once they're added
|
||||
if tensor.qparams.is_some() {
|
||||
match tensor.scheme().level {
|
||||
QuantLevel::Tensor => {}
|
||||
QuantLevel::Block(_) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
CubeTensor {
|
||||
client: tensor.client,
|
||||
device: tensor.device,
|
||||
meta: Box::new(Metadata::new(target_shape, new_strides)),
|
||||
handle: tensor.handle,
|
||||
dtype: tensor.dtype,
|
||||
qparams: tensor.qparams,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reshape a jit tensor to a new shape
|
||||
pub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
|
||||
let analysis = reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape);
|
||||
|
||||
match analysis {
|
||||
ReshapeAction::UpdateStrides { strides } => {
|
||||
*tensor.meta = Metadata::new(shape, strides);
|
||||
return tensor;
|
||||
}
|
||||
ReshapeAction::NoChange => return tensor,
|
||||
ReshapeAction::Recompute => (),
|
||||
}
|
||||
|
||||
let out = empty_device_dtype(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
shape,
|
||||
tensor.dtype,
|
||||
);
|
||||
|
||||
cubecl::std::tensor::copy_into(
|
||||
&tensor.client,
|
||||
&tensor.as_handle_ref(),
|
||||
&out.as_handle_ref(),
|
||||
tensor.dtype.into(),
|
||||
)
|
||||
.expect("Kernel should not fail");
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Reshape a jit tensor to a new shape
|
||||
pub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
|
||||
let scheme = *tensor.scheme();
|
||||
|
||||
let shape_values = {
|
||||
let rank = shape.num_dims();
|
||||
let mut shape = shape.clone();
|
||||
shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants());
|
||||
shape
|
||||
};
|
||||
let shape_scales = params_shape(&shape, scheme.level);
|
||||
let (values, scales) = tensor.quantized_handles().unwrap();
|
||||
|
||||
let analysis_values = reshape_action(values.meta.shape(), values.meta.strides(), &shape_values);
|
||||
let analysis_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales);
|
||||
|
||||
match (analysis_values, analysis_scales) {
|
||||
(
|
||||
ReshapeAction::UpdateStrides { strides },
|
||||
ReshapeAction::UpdateStrides {
|
||||
strides: scales_strides,
|
||||
},
|
||||
) => {
|
||||
let qparams = tensor.qparams.as_mut().unwrap();
|
||||
|
||||
*tensor.meta = Metadata::new(shape, strides);
|
||||
qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
|
||||
}
|
||||
(ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {
|
||||
*tensor.meta = Metadata::new(shape, strides);
|
||||
}
|
||||
(
|
||||
ReshapeAction::NoChange,
|
||||
ReshapeAction::UpdateStrides {
|
||||
strides: scales_strides,
|
||||
},
|
||||
) => {
|
||||
let qparams = tensor.qparams.as_mut().unwrap();
|
||||
|
||||
qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
|
||||
}
|
||||
(ReshapeAction::NoChange, ReshapeAction::NoChange) => {}
|
||||
_ => {
|
||||
tensor = kernel::into_contiguous(tensor);
|
||||
*tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values));
|
||||
|
||||
let qparams = tensor.qparams.as_mut().unwrap();
|
||||
|
||||
let strides = contiguous_strides(&shape_scales);
|
||||
qparams.scales.metadata = Metadata::new(shape_scales, strides);
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
|
||||
tensor_line_size_parallel(
|
||||
tensor.client.io_optimized_line_sizes(tensor.dtype.size()),
|
||||
tensor.meta.shape(),
|
||||
tensor.meta.strides(),
|
||||
tensor.meta.num_dims() - 1,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn max_line_size_many<R: CubeRuntime>(
|
||||
tensors: &[&CubeTensor<R>],
|
||||
axis: usize,
|
||||
) -> LineSize {
|
||||
let vec = tensors
|
||||
.iter()
|
||||
.map(|tensor| {
|
||||
tensor_line_size_parallel(
|
||||
tensor.client.io_optimized_line_sizes(tensor.dtype.size()),
|
||||
tensor.meta.shape(),
|
||||
tensor.meta.strides(),
|
||||
axis,
|
||||
)
|
||||
})
|
||||
.min();
|
||||
|
||||
vec.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Unfold windows along a dimension.
|
||||
///
|
||||
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
|
||||
/// where windows are advanced by `step` at each index.
|
||||
///
|
||||
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
|
||||
///
|
||||
/// The new view will have the unfolded dimension replaced by two dimensions;
|
||||
/// one in the position of the original dimension, with size equal to the number of windows,
|
||||
/// and one appended to the right-most position, with size equal to `size`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
|
||||
/// * `dim` - the dimension to unfold.
|
||||
/// * `size` - the size of each unfolded window.
|
||||
/// * `step` - the step between each window.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
|
||||
pub fn unfold<R: CubeRuntime>(
|
||||
tensor: CubeTensor<R>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> CubeTensor<R> {
|
||||
let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
|
||||
|
||||
let d_stride = tensor.meta.strides()[dim];
|
||||
let mut strides = tensor.meta.strides.clone();
|
||||
strides[dim] = step * d_stride;
|
||||
strides.push(d_stride);
|
||||
|
||||
CubeTensor {
|
||||
meta: Box::new(Metadata::new(shape, strides)),
|
||||
..tensor
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user