feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,165 @@
use burn_fusion::stream::Context;
use burn_std::{DType, Strides, quantization::QParamTensor, strides};
use cubecl::{
CubeElement, Runtime,
client::ComputeClient,
ir::{AddressType, ElemType},
prelude::{TensorArg, TensorHandleRef},
};
use cubecl::{
ir::LineSize,
quant::scheme::{QuantParam, QuantScheme},
};
use std::marker::PhantomData;
/// Defines a fallback operation when fusion isn't possible.
///
/// Implementors must be `Send + Sync`, as required by the trait's supertrait bounds.
pub trait FallbackOperation<R: Runtime>: Send + Sync {
/// Executes the fallback procedure.
///
/// `context` is the fusion stream context over [`CubeFusionHandle`]s; it is borrowed mutably
/// for the duration of the call.
fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
}
/// Runtime parameters for quantization. Can be used to construct a scales handle from the base
/// tensor handle.
///
/// This is `burn_std`'s `QParams` specialized to [`QParamTensor`]; see
/// [`CubeFusionHandle::params`] for how the scales location is turned into a handle.
pub type QParams = burn_std::quantization::QParams<QParamTensor>;
/// Handle to be used when fusing operations.
///
/// Bundles the raw buffer handle with its metadata (element type, strides, optional quantization
/// parameters) and the client/device it belongs to. The shape is not stored here; methods that
/// need it take it as an argument.
pub struct CubeFusionHandle<R: Runtime> {
/// Compute client for jit.
pub client: ComputeClient<R>,
/// The buffer where the data are stored.
pub handle: cubecl::server::Handle,
/// The device of the current tensor.
pub device: R::Device,
/// The element type of the tensor.
pub dtype: DType,
/// The strides of the tensor.
pub strides: Strides,
/// Quantization runtime parameters, if applicable (`None` for non-quantized tensors).
pub qparams: Option<QParams>,
}
impl<R: Runtime> core::fmt::Debug for CubeFusionHandle<R> {
    /// Formats the handle as its device and the runtime name only; the buffer handle, strides
    /// and quantization parameters are intentionally left out of the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let runtime = R::name(&self.client);
        write!(
            f,
            "CubeFusionHandle {{ device: {:?}, runtime: {}}}",
            self.device, runtime,
        )
    }
}
impl<R: Runtime> Clone for CubeFusionHandle<R> {
    /// Clones every field of the handle; implemented by hand so that no `R: Clone` bound is
    /// needed on the runtime itself.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct forces this impl to be updated.
        let Self {
            client,
            handle,
            device,
            dtype,
            strides,
            qparams,
        } = self;

        Self {
            client: client.clone(),
            handle: handle.clone(),
            device: device.clone(),
            strides: strides.clone(),
            dtype: *dtype,
            qparams: qparams.clone(),
        }
    }
}
// SAFETY: NOTE(review): no justification is recorded for these impls. They assert that every
// field of `CubeFusionHandle` (client, server handle, device, strides, qparams) may be sent to
// and shared between threads for any `R: Runtime` — confirm against the cubecl/burn types.
unsafe impl<R: Runtime> Send for CubeFusionHandle<R> {}
unsafe impl<R: Runtime> Sync for CubeFusionHandle<R> {}
impl<R: Runtime> CubeFusionHandle<R> {
    /// Borrows this handle as a [`TensorHandleRef`] for a tensor of the given `shape`.
    ///
    /// The strides come from the handle itself and the element size from its dtype.
    pub fn as_handle_ref<'a>(&'a self, shape: &'a [usize]) -> TensorHandleRef<'a, R> {
        let elem_size = self.dtype.size();
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape,
            runtime: PhantomData,
            elem_size,
        }
    }

    /// Returns the [`AddressType`] needed to address every element stored in the buffer.
    ///
    /// The element count is derived from the buffer size in bytes: for quantized floats it is
    /// the number of bits divided by the scheme's bits per value, otherwise the number of bytes
    /// divided by the dtype size.
    pub fn required_address_type(&self) -> AddressType {
        let num_bytes = self.handle.size() as usize;
        let num_elems = match self.dtype {
            DType::QFloat(scheme) => num_bytes * 8 / scheme.size_bits_value(),
            _ => num_bytes / self.dtype.size(),
        };
        AddressType::from_len(num_elems)
    }

    /// Builds a [`TensorArg`] for this handle with the given `shape` and `line_size`.
    pub fn as_tensor_arg<'a>(
        &'a self,
        shape: &'a [usize],
        line_size: LineSize,
    ) -> TensorArg<'a, R> {
        let tensor_ref: TensorHandleRef<'a, R> = self.as_handle_ref(shape);
        let elem_size = self.dtype.size();
        // SAFETY: NOTE(review): the raw parts all come from `self` and the caller's `shape`,
        // borrowed for `'a`; whether `shape`/`line_size` are consistent with the buffer is the
        // caller's responsibility — confirm against `TensorArg::from_raw_parts_and_size`.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                tensor_ref.handle,
                tensor_ref.strides,
                tensor_ref.shape,
                line_size,
                elem_size,
            )
        }
    }

    /// Constructs a separate handle for the quantization scales, if present.
    ///
    /// Returns `None` when the handle carries no quantization parameters. The returned handle
    /// shares the same buffer, restricted to the scales' offset range, and uses the dtype that
    /// matches `scheme.param`.
    ///
    /// # Panics
    ///
    /// Panics for `QuantParam::UE8M0` and `QuantParam::UE4M3`, which are not supported yet.
    pub fn params(&self, scheme: QuantScheme) -> Option<Self> {
        let scales = &self.qparams.as_ref()?.scales;

        // Same underlying buffer, narrowed to the region where the scales are stored.
        let mut scales_handle = self.handle.clone();
        scales_handle.offset_start = Some(scales.offset_start as u64);
        scales_handle.offset_end = Some(scales.offset_end as u64);

        Some(Self {
            client: self.client.clone(),
            handle: scales_handle,
            device: self.device.clone(),
            dtype: match scheme.param {
                QuantParam::F32 => DType::F32,
                QuantParam::F16 => DType::F16,
                QuantParam::BF16 => DType::BF16,
                QuantParam::UE8M0 | QuantParam::UE4M3 => unimplemented!("Not yet supported"),
            },
            strides: scales.metadata.strides().clone(),
            qparams: None,
        })
    }
}
/// Computes contiguous strides for `shape`: the last axis gets stride 1 and every earlier axis
/// gets the product of the sizes of all axes after it.
pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Strides {
    let rank = shape.len();
    let mut result = strides![0; rank];
    let mut running = 1;
    // Walk the axes from last to first, carrying the running product of the visited sizes.
    for axis in (0..rank).rev() {
        result[axis] = running;
        running *= shape[axis];
    }
    result
}
/// Maps the cube element type of `E` to the corresponding [`DType`].
///
/// # Panics
///
/// Panics (via `todo!`) for float kinds other than `F64`, `F32`, `F16` and `BF16`.
pub(crate) fn elem_dtype<E: CubeElement>() -> DType {
    use cubecl::ir::{FloatKind, IntKind, UIntKind};

    match E::cube_type().elem_type() {
        ElemType::Float(FloatKind::F64) => DType::F64,
        ElemType::Float(FloatKind::F16) => DType::F16,
        ElemType::Float(FloatKind::BF16) => DType::BF16,
        ElemType::Float(FloatKind::F32) => DType::F32,
        // Remaining float kinds have no mapping yet.
        ElemType::Float(_) => todo!(),
        ElemType::Int(IntKind::I64) => DType::I64,
        ElemType::Int(IntKind::I32) => DType::I32,
        ElemType::Int(IntKind::I16) => DType::I16,
        ElemType::Int(IntKind::I8) => DType::I8,
        ElemType::UInt(UIntKind::U64) => DType::U64,
        ElemType::UInt(UIntKind::U32) => DType::U32,
        ElemType::UInt(UIntKind::U16) => DType::U16,
        ElemType::UInt(UIntKind::U8) => DType::U8,
        ElemType::Bool => DType::Bool,
    }
}

View File

@@ -0,0 +1,6 @@
// The three dynamic element IDs below count down from `u8::MAX`, so they are pairwise distinct.
/// The element type ID to be used for dynamic element type while expanding a fused kernel.
pub(crate) const DYN_ELEM_ID: u8 = u8::MAX;
/// The element type ID to be used for the quantization store element type while expanding a fused kernel.
pub(crate) const Q_STORE_DYN_ELEM_ID: u8 = u8::MAX - 1;
/// The element type ID to be used for the quantization param element type while expanding a fused kernel.
pub(crate) const Q_PARAM_DYN_ELEM_ID: u8 = u8::MAX - 2;

View File

@@ -0,0 +1,792 @@
//! This module declares input-output primitives to read and write values during kernel expansion.
use super::{DYN_ELEM_ID, ir::*, tensor::GlobalTensor};
use burn_std::quantization::QuantScheme;
use cubecl::quant::scheme::QuantLevel;
use cubecl::{
intrinsic,
ir::{ExpandElement, Variable},
prelude::*,
std::{FastDivmod, tensor::View},
};
use cubek::quantization::layout::{BlockScaledLayout, PerTensorLayout, ScalesLayout};
use serde::{Deserialize, Serialize};
/// Defines how a tensor might be transformed at runtime.
///
/// A transform is passed to the input readers so that offsets are computed as if the tensor had
/// the transformed layout.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub enum Transform {
/// A reshape operation has been registered on a tensor.
///
/// This enum entry contains a sequence of [arguments](FuseArg) that point to global scalars
/// representing the new shape for the current tensor.
Reshape(Vec<FuseArg>),
/// Two axes have been swapped on a tensor.
///
/// The enum entry contains those two axes.
SwapDims(usize, usize),
}
/// Reads the value from the [arg](FuseArg) and casts it to the generic cube primitive.
///
/// # Notes
///
/// The [global arguments](GlobalArgs) for both inputs and outputs as well as the
/// [local arguments](LocalArgs) need to be passed to this function.
///
/// This is because the [argument](FuseArg) might point to a global input, output or local variable
/// created during kernel expansion.
///
/// # Panics
///
/// Panics when a reshaped or dim-swapped argument wraps anything other than [`FuseArg::Input`].
#[cube]
pub fn read<C: CubePrimitive>(
inputs: &GlobalArgs,
outputs: &GlobalArgs,
locals: &LocalArgs,
ref_pos: usize,
#[comptime] arg: FuseArg,
#[comptime] config: &FuseBlockConfig,
) -> Line<C> {
match arg {
FuseArg::Input(pos, _precision, layout) => {
let global = inputs.tensors.index(pos);
let line_size = global.tensor.line_size();
// A non-broadcasted input whose line size differs from the block width is read through
// the aligned path, which fills the output line one element at a time.
if comptime![!global.broadcasted && line_size != config.width] {
read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None)
} else {
read_input(inputs, locals, pos, ref_pos, layout, config, None)
}
}
// Variables shared between blocks are stored alongside the outputs.
FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => {
Line::cast_from(outputs.variables.read(key))
}
FuseArg::Output(pos, _precision, layout) => {
read_output(inputs, outputs, locals, pos, ref_pos, layout, config)
}
// Block-local variables live in one typed store per element type.
FuseArg::BlockLocal { pos, ty } => match comptime![ty] {
FuseType::F64 => Line::cast_from(locals.l_f64.find(pos)),
FuseType::F32 | FuseType::Flex32 => Line::cast_from(locals.l_f32.find(pos)),
FuseType::F16 => Line::cast_from(locals.l_f16.find(pos)),
FuseType::BF16 => Line::cast_from(locals.l_bf16.find(pos)),
FuseType::U64 => Line::cast_from(locals.l_u64.find(pos)),
FuseType::U32 => Line::cast_from(locals.l_u32.find(pos)),
FuseType::U16 => Line::cast_from(locals.l_u16.find(pos)),
FuseType::U8 => Line::cast_from(locals.l_u8.find(pos)),
FuseType::I64 => Line::cast_from(locals.l_i64.find(pos)),
FuseType::I32 => Line::cast_from(locals.l_i32.find(pos)),
FuseType::I16 => Line::cast_from(locals.l_i16.find(pos)),
FuseType::I8 => Line::cast_from(locals.l_i8.find(pos)),
FuseType::Bool => Line::cast_from(locals.l_bool.find(pos)),
},
FuseArg::Scalar(..) => {
let scalar = read_scalar::<C>(inputs, arg);
Line::new(scalar)
}
FuseArg::ScalarShape(_) => {
let scalar = read_scalar_shape(inputs, arg);
Line::cast_from(scalar)
}
FuseArg::Literal(val, _precision) => Line::new(from_const_int::<C>(val)),
// Reshaped input: same dispatch as a plain input, but with a reshape transform attached.
FuseArg::InputReshaped {
original,
shape,
broadcasted,
} => match comptime![original.as_ref().clone()] {
FuseArg::Input(pos, _precision, layout) => {
let global = inputs.tensors.index(pos);
let line_size = global.tensor.line_size();
if comptime![!broadcasted && line_size != config.width] {
read_input_aligned(
inputs,
locals,
pos,
ref_pos,
layout,
config,
comptime![Some(Transform::Reshape(shape))],
)
} else {
read_input(
inputs,
locals,
pos,
ref_pos,
layout,
config,
comptime![Some(Transform::Reshape(shape))],
)
}
}
_ => comptime![panic!("Only input can be reshaped")],
},
// Dim-swapped input: same dispatch as a plain input, but with a swap-dims transform attached.
FuseArg::InputSwapDims {
original,
dims,
broadcasted,
} => match comptime![original.as_ref().clone()] {
FuseArg::Input(pos, _precision, layout) => {
let global = inputs.tensors.index(pos);
let line_size = global.tensor.line_size();
if comptime![!broadcasted && line_size != config.width] {
read_input_aligned(
inputs,
locals,
pos,
ref_pos,
layout,
config,
comptime![Some(Transform::SwapDims(dims.0, dims.1))],
)
} else {
read_input(
inputs,
locals,
pos,
ref_pos,
layout,
config,
comptime![Some(Transform::SwapDims(dims.0, dims.1))],
)
}
}
_ => comptime![panic!("Only input can be swapped dims")],
},
}
}
/// Computes the offset for the current global tensor with a quantized layout.
///
/// The offset can be used to fetch the correct data from the quantized tensor as if it was in a
/// linear contiguous format.
///
/// `index` is expressed in reference lines (it is multiplied by the reference line size), and
/// the returned offset is expressed in lines of `tensor`.
#[cube]
fn index_offset_with_quant_layout(
tensor: &GlobalTensor,
locals: &LocalArgs,
index: usize,
#[comptime] rank: usize,
#[comptime] scheme: QuantScheme,
) -> usize {
let (start, end) = (0, rank - 1);
let num_quants = scheme.num_quants();
let offset_ref = index * locals.ref_line_size;
let mut offset = 0;
// Every axis except the last one is handled like a regular strided layout.
#[unroll]
for i in start..end {
let ogwl = offset_ref / locals.ref_strides[i];
offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i);
}
// Handle packed representation in last dim
// Both the last-axis shape and the coordinate along it are divided by `num_quants`,
// rounding up.
let ogwl = offset_ref / locals.ref_strides[end];
let shape_last = tensor.tensor.shape(end).div_ceil(num_quants);
let stride_last = tensor.tensor.stride(end);
offset += (ogwl.div_ceil(num_quants)) % shape_last * stride_last;
offset / tensor.tensor.line_size()
}
/// Reads a global quantized tensor at the given position.
///
/// Only [`FuseArg::Input`] is supported.
///
/// # Notes
///
/// The values returned in the [Line] are not dequantized.
///
/// # Panics
///
/// Panics with "Not supported" for any argument other than [`FuseArg::Input`].
#[cube]
pub fn read_quantized<C: CubePrimitive>(
inputs: &GlobalArgs,
locals: &LocalArgs,
ref_pos: usize,
#[comptime] arg: FuseArg,
#[comptime] config: &FuseBlockConfig,
#[comptime] scheme: QuantScheme,
) -> Line<C> {
match arg {
FuseArg::Input(pos, _precision, _layout) => {
let global = inputs.tensors.index(pos);
// The offset takes the packed last dimension of the quantized layout into account.
let offset =
index_offset_with_quant_layout(global, locals, ref_pos, config.rank, scheme);
let val = global.tensor[offset];
Line::cast_from(val)
}
_ => panic!("Not supported"),
}
}
/// Reads a global scalar.
///
/// # Panics
///
/// Panics with "Not a scalar" when `arg` is not [`FuseArg::Scalar`].
#[cube]
pub fn read_scalar<C: CubePrimitive>(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> C {
match arg {
FuseArg::Scalar(pos, _precision) => {
let scalar = inputs.scalars.index(pos);
scalar.get::<C>()
}
_ => comptime![panic!("Not a scalar")],
}
}
/// Reads a global scalar that is used as a reshape position.
///
/// The value is looked up in `inputs.reshapes` at the position carried by the argument.
///
/// # Panics
///
/// Panics with "Not a scalar shape" when `arg` is not [`FuseArg::ScalarShape`].
#[cube]
pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> usize {
match arg {
FuseArg::ScalarShape(pos) => inputs.reshapes[pos],
_ => comptime![panic!("Not a scalar shape")],
}
}
/// Reads an input tensor.
///
/// Reads one line of the input at position `pos`, at the offset matching the reference position
/// `ref_pos`, and casts it to `C`.
#[cube]
pub fn read_input<C: CubePrimitive>(
inputs: &GlobalArgs,
locals: &LocalArgs,
#[comptime] pos: usize,
ref_pos: usize,
#[comptime] layout: LayoutInfo,
#[comptime] config: &FuseBlockConfig,
#[comptime] transform: Option<Transform>,
) -> Line<C> {
let tensor = inputs.tensors.index(pos);
// When the layout matches the reference, the reference position is used as-is; otherwise
// the offset is recomputed from the reference strides (applying `transform`, if any).
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, transform),
};
Line::cast_from(tensor.tensor[offset])
}
/// Returns a slice of data in the asked precision of the input tensor at the given position.
///
/// The window is obtained with `slice(start, end)` on the input tensor and then downcast to `C`.
#[cube]
pub fn read_input_window<C: CubePrimitive>(
inputs: &GlobalArgs,
#[comptime] pos: usize,
start: usize,
end: usize,
) -> Slice<C> {
let tensor = inputs.tensors.index(pos);
let slice = tensor.tensor.slice(start, end);
slice.downcast()
}
/// Returns the input as a slice.
///
/// The whole input tensor at position `pos` is converted to a slice and downcast to `C`.
#[cube]
pub fn input_as_slice<C: CubePrimitive>(inputs: &GlobalArgs, #[comptime] pos: usize) -> Slice<C> {
let tensor = inputs.tensors.index(pos);
let slice = tensor.tensor.to_slice();
slice.downcast()
}
/// Returns the input tensor as a quantized scale view.
///
/// * `pos` - position of the scales tensor in `inputs`.
/// * `tensor_pos` - position of the tensor the scales belong to in `inputs`.
/// * `level` - quantization level, which selects the per-tensor or block-scaled layout.
#[cube]
pub fn input_as_scales_view<C: CubePrimitive>(
inputs: &GlobalArgs,
#[comptime] pos: usize,
#[comptime] tensor_pos: usize,
#[comptime] level: QuantLevel,
#[comptime] config: &FuseBlockConfig,
) -> View<C, usize> {
// Register `C` as the element type behind the dynamic element ID.
set_polyfill_typed::<C, NumericExpand<DYN_ELEM_ID>>();
let tensor = inputs.tensors.index(tensor_pos);
let scales = inputs.tensors.index(pos);
let tensor_len = tensor.tensor.len();
let rank = config.rank;
let layout = match level {
QuantLevel::Tensor => ScalesLayout::new_PerTensor(PerTensorLayout::new(tensor_len)),
QuantLevel::Block(block_size) => {
let block_size = comptime![block_size.to_dim_vec(rank)];
// The block-scaled layout is built from the shape of the values tensor and the
// strides of the scales tensor.
let mut tensor_shape = Sequence::new();
let mut scales_strides = Sequence::new();
#[unroll]
for i in 0..rank {
tensor_shape.push(FastDivmod::new_Fallback(tensor.tensor.shape(i)));
scales_strides.push(scales.tensor.stride(i));
}
let line_size = scales.tensor.line_size();
let layout = BlockScaledLayout::new(
tensor_shape,
tensor_len,
scales_strides,
block_size,
line_size,
);
ScalesLayout::new_BlockScaled(layout)
}
};
View::new::<Slice<C>, usize>(&scales.tensor.to_slice().downcast(), layout)
}
/// Reads the input tensor aligned.
///
/// Builds a line of `config.width` elements by reading the input one element at a time, which
/// allows the input's own line size to differ from the block width.
#[cube]
pub fn read_input_aligned<C: CubePrimitive>(
inputs: &GlobalArgs,
locals: &LocalArgs,
#[comptime] pos: usize,
ref_pos: usize,
#[comptime] layout: LayoutInfo,
#[comptime] config: &FuseBlockConfig,
#[comptime] transform: Option<Transform>,
) -> Line<C> {
let mut result: Line<C> = Line::<C>::empty(config.width);
let tensor = inputs.tensors.index(pos);
match transform.clone() {
Some(Transform::Reshape(shape)) => {
// Very brute force, not really efficient, but not easy to optimize and not a very
// frequent workflow.
// Each of the `config.width` elements gets its own reshaped index, which is then
// mapped back to an offset in the original tensor.
let ref_pos = ref_pos * config.width;
#[unroll]
for i in 0..config.width {
let index = reshaped_index(
inputs,
locals,
ref_pos + i,
config.rank,
comptime![shape.clone()],
);
let index = reshaped_index_to_original_index(&tensor.tensor, index, config.rank);
result[i] = C::cast_from(tensor.tensor[index][0])
}
}
Some(Transform::SwapDims(dim1, dim2)) => {
let offset =
get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform);
// Consecutive elements are spaced by the stride of the axis that the last reference
// axis maps to after the swap.
let i = comptime![swap_dims_transform(config.rank - 1, (dim1, dim2))];
let stride = tensor.tensor.stride(i);
#[unroll]
for i in 0..config.width {
let index = offset + i * stride;
result[i] = C::cast_from(tensor.tensor[index][0])
}
}
None => {
let offset =
get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform);
// Consecutive elements are spaced by the stride of the last axis.
let stride = tensor.tensor.stride(config.rank - 1);
#[unroll]
for i in 0..config.width {
let index = offset + i * stride;
result[i] = C::cast_from(tensor.tensor[index][0])
}
}
}
result
}
/// Computes the offset of the given [GlobalTensor] at the reference position with a linear
/// layout.
///
/// When the layout matches the reference, the reference position is converted from reference
/// lines to lines of `tensor`; otherwise the offset is computed by [`get_offset`].
#[cube]
pub fn get_offset_aligned(
inputs: &GlobalArgs,
locals: &LocalArgs,
tensor: &GlobalTensor,
ref_pos: usize,
#[comptime] layout: LayoutInfo,
#[comptime] config: &FuseBlockConfig,
#[comptime] transform: Option<Transform>,
) -> usize {
match layout {
LayoutInfo::SameAsRef | LayoutInfo::IsRef => {
(ref_pos * locals.ref_line_size) / tensor.tensor.line_size()
}
LayoutInfo::Unknown => get_offset(
inputs,
locals,
tensor,
ref_pos,
None,
config,
comptime!(transform.clone()),
),
}
}
/// Reads an output tensor.
///
/// Reads one line of the output at position `pos`, at the offset matching the reference
/// position `ref_pos`, and casts it to `C`. No transform is applied to outputs.
#[cube]
pub fn read_output<C: CubePrimitive>(
inputs: &GlobalArgs,
outputs: &GlobalArgs,
locals: &LocalArgs,
#[comptime] pos: usize,
ref_pos: usize,
#[comptime] layout: LayoutInfo,
#[comptime] config: &FuseBlockConfig,
) -> Line<C> {
let tensor = outputs.tensors.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, None),
};
Line::cast_from(tensor.tensor[offset])
}
#[cube]
/// Writes the given value at the [arg](FuseArg) position.
///
/// Supported targets are output tensors, block-local variables and multi-block variables.
///
/// # Panics
///
/// Panics with "Can't write into inputs and scalars" for any other argument.
pub fn write<C: CubePrimitive>(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
ref_pos: usize,
value: Line<C>,
#[comptime] arg: FuseArg,
#[comptime] config: &FuseBlockConfig,
) {
match arg {
FuseArg::Output(pos, precision, layout) => {
// The offset is computed through a shared borrow first; the tensor is re-borrowed
// mutably below for the actual store.
let tensor = outputs.tensors.index(pos);
let offset = match layout {
LayoutInfo::SameAsRef => ref_pos,
LayoutInfo::IsRef => ref_pos,
LayoutInfo::Unknown => {
get_offset(inputs, locals, tensor, ref_pos, None, config, None)
}
};
let tensor = outputs.tensors.index_mut(pos);
// Bind the dynamic element ID to the output's precision before casting the value.
set_polyfill::<NumericExpand<DYN_ELEM_ID>>(comptime![precision.into_type()]);
tensor.tensor[offset] = Line::cast_from(value);
}
FuseArg::BlockLocal { .. } => write_scalar::<C>(locals, value, arg),
FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => {
outputs.variables.write(key, Line::cast_from(value))
}
_ => comptime![panic!("Can't write into inputs and scalars")],
}
}
#[cube]
/// Writes the given value into the block-local variable designated by [arg](FuseArg).
///
/// The value is cast and inserted into the local store that matches the variable's type.
///
/// # Panics
///
/// Panics with "Can't write into something else than scalars" when `arg` is not
/// [`FuseArg::BlockLocal`].
pub fn write_scalar<C: CubePrimitive>(
locals: &mut LocalArgs,
value: Line<C>,
#[comptime] arg: FuseArg,
) {
match arg {
FuseArg::BlockLocal { pos, ty } => match comptime![ty] {
FuseType::F64 => locals.l_f64.insert(pos, Line::cast_from(value)),
// `Flex32` shares the `f32` local store.
FuseType::F32 | FuseType::Flex32 => locals.l_f32.insert(pos, Line::cast_from(value)),
FuseType::F16 => locals.l_f16.insert(pos, Line::cast_from(value)),
FuseType::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)),
FuseType::U64 => locals.l_u64.insert(pos, Line::cast_from(value)),
FuseType::U32 => locals.l_u32.insert(pos, Line::cast_from(value)),
FuseType::U16 => locals.l_u16.insert(pos, Line::cast_from(value)),
FuseType::U8 => locals.l_u8.insert(pos, Line::cast_from(value)),
FuseType::I64 => locals.l_i64.insert(pos, Line::cast_from(value)),
FuseType::I32 => locals.l_i32.insert(pos, Line::cast_from(value)),
FuseType::I16 => locals.l_i16.insert(pos, Line::cast_from(value)),
FuseType::I8 => locals.l_i8.insert(pos, Line::cast_from(value)),
FuseType::Bool => locals.l_bool.insert(pos, Line::cast_from(value)),
},
_ => comptime![panic!("Can't write into something else than scalars")],
}
}
/// Computes the offset of an input or output tensor for the reference position `index`.
///
/// `range`, when provided, restricts the axes taken into account to `start..end`.
///
/// # Panics
///
/// Panics when `arg` is neither [`FuseArg::Input`] nor [`FuseArg::Output`].
#[cube]
pub(crate) fn global_offset(
inputs: &GlobalArgs,
outputs: &GlobalArgs,
locals: &LocalArgs,
index: usize,
#[comptime] arg: FuseArg,
#[comptime] range: Option<(usize, usize)>,
#[comptime] config: &FuseBlockConfig,
) -> usize {
match arg {
FuseArg::Input(pos, _precision, _layout) => {
let tensor = inputs.tensors.index(pos);
get_offset(inputs, locals, tensor, index, range, config, None)
}
FuseArg::Output(pos, _precision, _layout) => {
let tensor = outputs.tensors.index(pos);
get_offset(inputs, locals, tensor, index, range, config, None)
}
_ => panic!("Only input and output tensors have global offset."),
}
}
/// Computes the offset of `tensor` for the reference position `ref_pos`.
///
/// Thin wrapper around [`index_offset_with_layout`] that takes the rank from `config`.
#[cube]
fn get_offset(
inputs: &GlobalArgs,
locals: &LocalArgs,
tensor: &GlobalTensor,
ref_pos: usize,
#[comptime] range: Option<(usize, usize)>,
#[comptime] config: &FuseBlockConfig,
#[comptime] transform: Option<Transform>,
) -> usize {
index_offset_with_layout(
inputs,
tensor,
locals,
ref_pos,
range,
config.rank,
transform,
)
}
#[cube]
/// Gets the line size of the global tensor at position `pos`.
pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: usize) -> comptime_type!(LineSize) {
let tensor = global.tensors.index(pos);
tensor.tensor.line_size()
}
#[cube]
/// Gets the rank of the global tensor at position `pos`.
pub fn global_rank(global: &GlobalArgs, #[comptime] pos: usize) -> usize {
let tensor = global.tensors.index(pos);
tensor.tensor.rank()
}
#[cube]
/// Gets the length of the global tensor at position `pos`, as reported by `Tensor::len`.
pub fn global_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize {
let tensor = global.tensors.index(pos);
tensor.tensor.len()
}
#[cube]
/// Gets the buffer length of the global tensor at position `pos`, as reported by
/// `Tensor::buffer_len`.
pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize {
let tensor = global.tensors.index(pos);
tensor.tensor.buffer_len()
}
#[cube]
/// Gets the reference tensor length.
///
/// For a concrete reference layout this is the length of the referenced input or output tensor;
/// for a virtual layout it is the number of elements of the reference shape.
///
/// # Panics
///
/// Panics when a concrete reference layout points to something other than an input or output.
pub fn ref_len(
inputs: &GlobalArgs,
outputs: &GlobalArgs,
locals: &LocalArgs,
#[comptime] config: &FuseBlockConfig,
) -> usize {
match config.ref_layout.clone() {
RefLayout::Concrete(arg) => match comptime![arg] {
FuseArg::Input(index, _, _) => global_len(inputs, index),
FuseArg::Output(index, _, _) => global_len(outputs, index),
_ => panic!("Invalid concrete ref layout."),
},
RefLayout::Virtual(..) => num_elements(locals, config),
}
}
#[cube]
/// Gets the reference buffer tensor length.
///
/// * Concrete layouts and virtual swap-dims layouts return the buffer length of the input or
///   output tensor they point to.
/// * All other virtual layouts return the number of elements of the reference shape.
///
/// # Panics
///
/// Panics when a concrete or swap-dims layout points to something other than an input or
/// output tensor.
pub fn ref_buffer_len(
    inputs: &GlobalArgs,
    outputs: &GlobalArgs,
    locals: &LocalArgs,
    #[comptime] config: &FuseBlockConfig,
) -> usize {
    match config.ref_layout.clone() {
        RefLayout::Concrete(arg) => match comptime![arg] {
            FuseArg::Input(index, _, _) => global_buffer_len(inputs, index),
            FuseArg::Output(index, _, _) => global_buffer_len(outputs, index),
            _ => panic!("Invalid concrete ref layout."),
        },
        // A swapped-dims view reuses the buffer of the tensor it wraps.
        RefLayout::Virtual(VirtualLayout::SwapDims(arg, ..)) => match arg {
            FuseArg::Input(index, _, _) => global_buffer_len(inputs, index),
            FuseArg::Output(index, _, _) => global_buffer_len(outputs, index),
            // This branch handles a virtual layout, so the message must not claim the
            // concrete layout was at fault.
            _ => panic!("Invalid swap-dims virtual ref layout."),
        },
        RefLayout::Virtual(VirtualLayout::Reshaped { .. }) => num_elements(locals, config),
        RefLayout::Virtual(VirtualLayout::Shape(..)) => num_elements(locals, config),
        RefLayout::Virtual(VirtualLayout::Runtime { .. }) => num_elements(locals, config),
    }
}
#[cube]
/// Gets the reference number of elements.
///
/// This is the product of the reference shape over the first `config.rank` axes.
pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> usize {
let mut length = 1;
for i in 0..config.rank {
length *= locals.ref_shape[i];
}
length
}
#[cube]
/// Gets the reference shape along `axis`.
pub fn ref_shape(locals: &LocalArgs, axis: usize) -> usize {
locals.ref_shape[axis]
}
#[cube]
/// Gets the reference stride along `axis`.
pub fn ref_stride(locals: &LocalArgs, axis: usize) -> usize {
locals.ref_strides[axis]
}
#[cube]
/// Gets the reference line size as a comptime value.
pub fn ref_line_size(locals: &LocalArgs) -> comptime_type!(LineSize) {
comptime![locals.ref_line_size]
}
#[cube]
/// Gets the shape along `axis` of the global tensor at position `pos`.
pub fn global_shape(global: &GlobalArgs, axis: usize, #[comptime] pos: usize) -> usize {
let tensor = global.tensors.index(pos);
tensor.tensor.shape(axis)
}
#[cube]
/// Gets the stride along `dim` of the global tensor at position `pos`.
pub fn global_stride(global: &GlobalArgs, dim: usize, #[comptime] pos: usize) -> usize {
let tensor = global.tensors.index(pos);
tensor.tensor.stride(dim)
}
/// Computes the offset in `tensor` (in lines of `tensor`) matching the reference position
/// `index` (in reference lines), optionally applying a [`Transform`].
///
/// `range`, when provided, restricts the axes taken into account to `start..end`; it is not
/// allowed together with a reshape transform.
#[cube]
fn index_offset_with_layout(
inputs: &GlobalArgs,
tensor: &GlobalTensor,
locals: &LocalArgs,
index: usize,
#[comptime] range: Option<(usize, usize)>,
#[comptime] rank: usize,
#[comptime] transform: Option<Transform>,
) -> usize {
match comptime![transform.clone()] {
Some(Transform::Reshape(shape)) => {
comptime![assert!(
range.is_none(),
"Can't get a range on a reshaped tensor."
)];
// Two steps: compute the index in the reshaped (contiguous) view, then map it back to
// an offset in the original tensor.
let index = index * locals.ref_line_size;
let index = reshaped_index(inputs, locals, index, rank, shape);
reshaped_index_to_original_index(&tensor.tensor, index, rank)
}
Some(Transform::SwapDims(dim1, dim2)) => {
let (start, end) = comptime! {match range {
Some(range) => range,
None => (0, rank),
}};
let offset_ref = index * locals.ref_line_size;
let mut offset = 0;
#[unroll]
for i in start..end {
// Reference axis `i` reads the tensor axis it is swapped with (if any).
let index = comptime![swap_dims_transform(i, (dim1, dim2))];
let ogwl = offset_ref / locals.ref_strides[i];
offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index);
}
offset / tensor.tensor.line_size()
}
None => {
let (start, end) = comptime! {match range {
Some(range) => range,
None => (0, rank),
}};
let offset_ref = index * locals.ref_line_size;
let mut offset = 0;
#[unroll]
for i in start..end {
// Coordinate along axis `i` in the reference, wrapped by the tensor's own shape.
let ogwl = offset_ref / locals.ref_strides[i];
offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i);
}
offset / tensor.tensor.line_size()
}
}
}
/// Maps axis `i` through a swap of the two axes in `dims`.
///
/// Returns the other axis of the pair when `i` is one of them, and `i` unchanged otherwise.
pub(crate) fn swap_dims_transform(i: usize, dims: (usize, usize)) -> usize {
    let (first, second) = dims;
    match i {
        axis if axis == first => second,
        axis if axis == second => first,
        axis => axis,
    }
}
#[cube]
#[allow(clippy::clone_on_copy)]
/// The index the input tensor would be at if it was contiguous.
///
/// Axes are visited from last to first; the coordinate along each axis is derived from the
/// reference strides and accumulated with contiguous strides built from the reshape's shape
/// scalars.
fn reshaped_index(
inputs: &GlobalArgs,
locals: &LocalArgs,
index: usize,
#[comptime] rank: usize,
#[comptime] shape: Vec<FuseArg>,
) -> usize {
let mut offset = 0;
let mut stride_curr = 1;
#[unroll]
for r in 0..rank {
let i = reverse_index(rank, r).comptime();
let arg = shape[i].clone();
let shape_i = read_scalar_shape(inputs, arg);
let ogwl = index / locals.ref_strides[i];
offset += ogwl % shape_i * stride_curr;
// The stride of the next (more significant) axis is the product of the sizes seen so far.
stride_curr *= shape_i;
}
offset
}
/// Converts a linear index in the contiguous (reshaped) view into an offset in `original`,
/// expressed in lines of `original`.
///
/// The linear index is decomposed into coordinates from the last axis to the first using the
/// original shape, and each coordinate is weighted by the original stride.
#[allow(unreachable_code)]
#[cube]
#[allow(clippy::clone_on_copy)]
fn reshaped_index_to_original_index<C: CubePrimitive>(
original: &Tensor<Line<C>>,
index_reshaped: usize,
#[comptime] rank: usize,
) -> usize {
let mut remaining = index_reshaped;
let mut offset = 0;
#[unroll]
for r in 0..rank {
let i = reverse_index(rank, r);
let shape = original.shape(i);
let stride = original.stride(i);
let coordinate = remaining % shape;
remaining /= shape;
offset += coordinate * stride;
}
offset / original.line_size()
}
/// Returns the axis visited at iteration `iter` when walking `rank` axes from last to first,
/// i.e. `rank - iter - 1`.
#[cube]
#[allow(unused_variables)]
pub(crate) fn reverse_index(
#[comptime] rank: usize,
#[comptime] iter: usize,
) -> comptime_type!(usize) {
rank - iter - 1
}
/// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion.
///
/// The value is emitted as a constant variable whose type is the cube type of `C`.
#[allow(unused_variables)]
#[cube]
fn from_const_int<C: CubePrimitive>(#[comptime] value: usize) -> C {
intrinsic!(|scope| {
ExpandElement::Plain(Variable::constant(value.into(), C::as_type(scope))).into()
})
}
/// Calls `set_polyfill` for `Dyn` with the element type of `C`.
#[cube]
#[allow(clippy::extra_unused_type_parameters)]
fn set_polyfill_typed<C: CubePrimitive, Dyn: CubePrimitive>() {
intrinsic!(|scope| {
let elem_type = C::as_type(scope);
set_polyfill::expand::<Dyn>(scope, elem_type);
})
}

View File

@@ -0,0 +1,917 @@
use super::tensor::GlobalTensor;
use crate::engine::codegen::DYN_ELEM_ID;
use burn_std::{
DType, Shape, Strides, bf16, f16,
quantization::{QuantScheme, QuantStore, QuantValue},
strides,
};
use core::fmt::Display;
use cubecl::{
ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind},
prelude::*,
};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
/// Argument to a [fuse operation](FuseOp).
pub enum FuseArg {
/// A readonly input tensor.
Input(usize, FuseType, LayoutInfo),
/// A readwrite output tensor.
Output(usize, FuseType, LayoutInfo),
/// A temporary local variable within a single [block](FuseBlockConfig).
BlockLocal {
/// The position of the current variable relative to all local variables within a single block.
pos: usize,
/// The type of the current variable.
ty: FuseType,
},
/// A variable shared between multiple [blocks](FuseBlockConfig) that must have a compatible
/// scope.
MultiBlockLocal(MultiBlockPos, FuseType),
/// A variable shared between multiple [blocks](FuseBlockConfig) within a global accessible
/// scope.
MultiBlockGlobal(MultiBlockPos, FuseType),
/// A global scalar.
Scalar(usize, FuseType),
/// A global scalar used in a reshape operation.
///
/// This is not a scalar defined by a user for computation, but a scalar defined as part of
/// a reshape operation.
ScalarShape(usize),
/// Only constants that can be encoded into a `u32` can be used as literals.
Literal(usize, FuseType),
/// A readonly input tensor that is reshaped.
InputReshaped {
/// The argument being reshaped.
original: Box<FuseArg>,
/// The arguments holding the new shape, one per axis.
shape: Vec<FuseArg>,
/// Whether the input is broadcasted.
broadcasted: bool,
},
/// A readonly input tensor with swapped dimensions.
InputSwapDims {
/// The argument whose dimensions are swapped.
original: Box<FuseArg>,
/// The two axes that are swapped.
dims: (usize, usize),
/// Whether the input is broadcasted.
broadcasted: bool,
},
}
/// Metadata of a variable shared between blocks.
///
/// Used as the payload of [`FuseArg::MultiBlockLocal`] and [`FuseArg::MultiBlockGlobal`].
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct MultiBlockPos {
/// The block position in all blocks included in a fused trace.
pub block_pos: usize,
/// The [FuseArg::BlockLocal] position in the block where the variable is first initialized.
pub block_local_pos: usize,
}
#[derive(
CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord,
)]
/// Layout information of a tensor relative to the reference layout.
pub enum LayoutInfo {
/// The layout is the same as the reference.
SameAsRef,
/// The reference layout.
IsRef,
/// The layout is unknown.
Unknown,
}
impl FuseArg {
    /// Returns the [`FuseType`] associated with this argument.
    ///
    /// Reshape scalars are always `U32`; reshaped and dim-swapped inputs report the precision
    /// of the argument they wrap.
    pub fn precision(&self) -> FuseType {
        match self {
            FuseArg::Input(_, ty, _)
            | FuseArg::Output(_, ty, _)
            | FuseArg::BlockLocal { ty, .. }
            | FuseArg::MultiBlockLocal(_, ty)
            | FuseArg::MultiBlockGlobal(_, ty)
            | FuseArg::Scalar(_, ty)
            | FuseArg::Literal(_, ty) => *ty,
            FuseArg::ScalarShape(_) => FuseType::U32,
            FuseArg::InputReshaped { original, .. } | FuseArg::InputSwapDims { original, .. } => {
                original.precision()
            }
        }
    }
}
// `FuseArg` is its own expand type: no separate expansion-time representation is introduced.
impl CubeType for FuseArg {
type ExpandType = Self;
}
impl IntoMut for FuseArg {
// Returns the argument unchanged; the scope is not used.
fn into_mut(self, _context: &mut Scope) -> Self {
self
}
}
impl IntoRuntime for FuseArg {
// Returns the argument unchanged, since its expand type is `Self`; the scope is not used.
fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType {
self
}
}
// Empty impl: only the trait's provided behavior is used.
impl CubeDebug for FuseArg {}
#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
/// Operations that can be executed and fused automatically using a fuse-on-read and/or
/// fuse-on-write strategy.
pub enum FuseOp {
/// `out = lhs + rhs`
Add(BinaryFuseArgs),
/// `out = lhs - rhs`
Sub(BinaryFuseArgs),
/// `out = lhs * rhs`
Mul(BinaryFuseArgs),
/// `out = lhs / rhs`
Div(BinaryFuseArgs),
/// `out = powf(lhs, rhs)`
Powf(BinaryFuseArgs),
/// `out = abs(input)`
Abs(UnaryFuseArgs),
/// `out = exp(input)`
Exp(UnaryFuseArgs),
/// `out = log(input)`
Log(UnaryFuseArgs),
/// `out = log1p(input)`
Log1p(UnaryFuseArgs),
/// `out = cos(input)`
Cos(UnaryFuseArgs),
/// `out = sin(input)`
Sin(UnaryFuseArgs),
/// `out = tanh(input)`
Tanh(UnaryFuseArgs),
/// `out = erf(input)`
Erf(UnaryFuseArgs),
/// `out = sqrt(input)`
Sqrt(UnaryFuseArgs),
/// `out = recip(input)`
Recip(UnaryFuseArgs),
/// `out = input`
Assign(UnaryFuseArgs),
/// `out = lhs == rhs`
Equal(BinaryFuseArgs),
/// `out = lhs < rhs`
Lower(BinaryFuseArgs),
/// `out = lhs > rhs`
Greater(BinaryFuseArgs),
/// `out = lhs <= rhs`
LowerEqual(BinaryFuseArgs),
/// `out = lhs % rhs`
Rem(BinaryFuseArgs),
/// `out = lhs >= rhs`
GreaterEqual(BinaryFuseArgs),
/// `out = clamp(input, min, max)`
Clamp {
input: FuseArg,
min: FuseArg,
max: FuseArg,
out: FuseArg,
},
/// Writes either `lhs` or `rhs` into `out` depending on `cond`.
ConditionalAssign {
cond: FuseArg,
lhs: FuseArg,
rhs: FuseArg,
out: FuseArg,
},
/// Gather of `input` at `indices` along `dim`, written to `output`.
// NOTE(review): exact indexing semantics are not visible in this file — confirm.
Gather {
input: FuseArg,
indices: FuseArg,
output: FuseArg,
dim: usize,
},
/// Select of `input` at `indices` along `dim`, written to `output`.
// NOTE(review): exact indexing semantics are not visible in this file — confirm.
Select {
input: FuseArg,
indices: FuseArg,
output: FuseArg,
dim: usize,
},
/// Dequantizes `values` using `params` under the given `scheme`, written to `output`.
Dequantize {
values: FuseArg,
params: FuseArg,
output: FuseArg,
scheme: QuantSchemeFuse,
},
}
impl Display for FuseOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FuseOp::Add(args) => write!(f, "{} = {} + {}", args.out, args.lhs, args.rhs),
FuseOp::Sub(args) => write!(f, "{} = {} - {}", args.out, args.lhs, args.rhs),
FuseOp::Mul(args) => write!(f, "{} = {} * {}", args.out, args.lhs, args.rhs),
FuseOp::Div(args) => write!(f, "{} = {} / {}", args.out, args.lhs, args.rhs),
FuseOp::Powf(args) => write!(f, "{} = powf({}, {})", args.out, args.lhs, args.rhs),
FuseOp::Abs(args) => write!(f, "{} = abs({})", args.out, args.input),
FuseOp::Exp(args) => write!(f, "{} = exp({})", args.out, args.input),
FuseOp::Log(args) => write!(f, "{} = log({})", args.out, args.input),
FuseOp::Log1p(args) => write!(f, "{} = log1p({})", args.out, args.input),
FuseOp::Cos(args) => write!(f, "{} = cos({})", args.out, args.input),
FuseOp::Sin(args) => write!(f, "{} = sin({})", args.out, args.input),
FuseOp::Tanh(args) => write!(f, "{} = tanh({})", args.out, args.input),
FuseOp::Erf(args) => write!(f, "{} = erf({})", args.out, args.input),
FuseOp::Sqrt(args) => write!(f, "{} = sqrt({})", args.out, args.input),
FuseOp::Recip(args) => write!(f, "{} = recip({})", args.out, args.input),
FuseOp::Assign(args) => write!(f, "{} = {}", args.out, args.input),
FuseOp::Equal(args) => write!(f, "{} = {} == {}", args.out, args.lhs, args.rhs),
FuseOp::Lower(args) => write!(f, "{} = {} < {}", args.out, args.lhs, args.rhs),
FuseOp::Greater(args) => write!(f, "{} = {} > {}", args.out, args.lhs, args.rhs),
FuseOp::LowerEqual(args) => write!(f, "{} = {} <= {}", args.out, args.lhs, args.rhs),
FuseOp::Rem(args) => write!(f, "{} = {} % {}", args.out, args.lhs, args.rhs),
FuseOp::GreaterEqual(args) => write!(f, "{} = {} >= {}", args.out, args.lhs, args.rhs),
FuseOp::Clamp {
input,
min,
max,
out,
} => write!(f, "{} = clamp({}, min={}, max={})", out, input, min, max),
FuseOp::ConditionalAssign {
cond,
lhs,
rhs,
out,
} => write!(
f,
"{} = select(cond={}, lhs={}, rhs={})",
out, cond, lhs, rhs
),
FuseOp::Gather {
input,
indices,
output,
dim,
} => write!(
f,
"{} = gather(input={}, indices={}, dim={})",
output, input, indices, dim
),
FuseOp::Select {
input,
indices,
output,
dim,
} => write!(
f,
"{} = select(input={}, indices={}, dim={})",
output, input, indices, dim
),
FuseOp::Dequantize {
values,
params,
output,
scheme: _,
} => write!(
f,
"{} = dequantize(values={}, params={})",
output, values, params
),
}
}
}
#[derive(
CubeType, CubeLaunch, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord,
)]
/// Wrapper around a [`QuantScheme`] used by [`FuseOp::Dequantize`].
///
/// The wrapped scheme is marked `#[cube(comptime)]`.
pub struct QuantSchemeFuse {
#[cube(comptime)]
pub(crate) scheme: QuantScheme,
}
impl FuseOp {
/// Element type used for the computation.
pub(crate) fn cmp_elem(&self) -> ElemType {
match self {
FuseOp::Add(op) => op.lhs.precision().into_elem(),
FuseOp::Sub(op) => op.lhs.precision().into_elem(),
FuseOp::Mul(op) => op.lhs.precision().into_elem(),
FuseOp::Div(op) => op.lhs.precision().into_elem(),
FuseOp::Powf(op) => op.lhs.precision().into_elem(),
FuseOp::Abs(op) => op.out.precision().into_elem(),
FuseOp::Exp(op) => op.out.precision().into_elem(),
FuseOp::Log(op) => op.out.precision().into_elem(),
FuseOp::Log1p(op) => op.out.precision().into_elem(),
FuseOp::Cos(op) => op.out.precision().into_elem(),
FuseOp::Sin(op) => op.out.precision().into_elem(),
FuseOp::Tanh(op) => op.out.precision().into_elem(),
FuseOp::Erf(op) => op.out.precision().into_elem(),
FuseOp::Recip(op) => op.out.precision().into_elem(),
FuseOp::Sqrt(op) => op.out.precision().into_elem(),
FuseOp::Assign(op) => op.out.precision().into_elem(),
FuseOp::Equal(op) => op.lhs.precision().into_elem(),
FuseOp::Lower(op) => op.lhs.precision().into_elem(),
FuseOp::Greater(op) => op.lhs.precision().into_elem(),
FuseOp::LowerEqual(op) => op.lhs.precision().into_elem(),
FuseOp::GreaterEqual(op) => op.lhs.precision().into_elem(),
FuseOp::ConditionalAssign { out, .. } => out.precision().into_elem(),
FuseOp::Gather { output, .. } => output.precision().into_elem(),
FuseOp::Select { output, .. } => output.precision().into_elem(),
FuseOp::Dequantize { output, .. } => output.precision().into_elem(),
FuseOp::Rem(op) => op.out.precision().into_elem(),
FuseOp::Clamp { out, .. } => out.precision().into_elem(),
}
}
/// Element type used for the computation.
pub(crate) fn cmp_type(&self) -> StorageType {
self.cmp_elem().into()
}
}
#[derive(CubeType, CubeLaunch, Default, Clone)]
/// Global arguments that are used for fusing [element wise operations](FuseOp).
pub struct GlobalArgs {
/// Tensors that are stored in global memory.
pub tensors: Sequence<GlobalTensor>,
/// Scalars that are stored in global memory.
pub scalars: Sequence<InputScalar>,
/// To be used to perform reshape inside a fused kernel.
pub reshapes: Sequence<usize>,
/// When there are no metadata as a reference layout, we provide runtime shape/strides in this
/// sequence instead.
pub runtime_layouts: Sequence<usize>,
/// Variables shared between blocks.
pub variables: MultiBlockVariables,
}
impl<'a, R: Runtime> GlobalArgsLaunch<'a, R> {
    /// Returns the greatest [`AddressType`] among the registered tensors, or the default
    /// address type when no tensor is registered.
    pub fn required_address_type(&self) -> AddressType {
        let greatest = self
            .tensors
            .values
            .iter()
            .map(|tensor| tensor.address_type)
            .max();
        match greatest {
            Some(address_type) => address_type,
            None => Default::default(),
        }
    }
}
/// Variables shared between blocks.
///
/// Cells are stored in a two-level registry: the outer key is looked up with
/// `MultiBlockPos::block_pos` and the inner key with `MultiBlockPos::block_local_pos`.
#[derive(CubeType, Default, Clone)]
pub struct MultiBlockVariables {
variables: Registry<usize, Registry<usize, RuntimeCell<Line<NumericExpand<DYN_ELEM_ID>>>>>,
}
#[cube]
impl MultiBlockVariables {
/// Initializes the variable with the given key and line size.
///
/// The inner registry for `key.block_pos` is fetched (or created with its default value) and a
/// new cell holding an empty line of `line_size` is inserted under `key.block_local_pos`.
///
/// # Notes
///
/// The type of [`NumericExpand<DYN_ELEM_ID>`] must be set before calling this function.
pub fn init(&mut self, #[comptime] key: MultiBlockPos, #[comptime] line_size: usize) {
let mut registers = Registry::<
usize,
Registry<usize, RuntimeCell<Line<NumericExpand<DYN_ELEM_ID>>>>,
>::find_or_default::<usize>(&mut self.variables, key.block_pos);
let cell = RuntimeCell::new(Line::empty(line_size));
registers.insert(key.block_local_pos, cell);
}
/// Read the variable using the provided key.
///
/// # Notes
///
/// The variable must be initialized.
pub fn read(&self, #[comptime] key: MultiBlockPos) -> Line<NumericExpand<DYN_ELEM_ID>> {
// Two-step lookup: first the block, then the variable inside that block.
let registers = self.variables.find(key.block_pos);
let cell = registers.find(key.block_local_pos);
cell.read()
}
/// Write to the variable using the provided key and value.
///
/// # Notes
///
/// The variable must be initialized.
pub fn write(
&mut self,
#[comptime] key: MultiBlockPos,
value: Line<NumericExpand<DYN_ELEM_ID>>,
) {
let registers = self.variables.find(key.block_pos);
// Try find for local(visibility) registers.
let cell = registers.find(key.block_local_pos);
cell.store(value);
}
}
// SAFETY: Because we only create it DURING compilation, not as a real launch arg.
// The `LaunchArg` implementation of this type uses `()` for both its runtime and compilation
// arguments, so no value of this type is carried by a launch.
// NOTE(review): the soundness of these impls rests on that usage pattern; confirm that no
// populated `MultiBlockVariables` is ever shared or sent across threads.
unsafe impl Send for MultiBlockVariables {}
unsafe impl Sync for MultiBlockVariables {}
impl LaunchArg for MultiBlockVariables {
// Neither runtime data nor compilation data is associated with this argument.
type RuntimeArg<'a, R: Runtime> = ();
type CompilationArg = ();
// Nothing to compute: the compilation argument is the unit value.
fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
}
// Expands to a `MultiBlockVariablesExpand` whose `variables` field starts from its default
// value; the kernel builder is not used.
fn expand(
_arg: &Self::CompilationArg,
_builder: &mut KernelBuilder,
) -> <Self as CubeType>::ExpandType {
MultiBlockVariablesExpand {
variables: Default::default(),
}
}
}
impl<R: Runtime> Default for GlobalArgsLaunch<'_, R> {
// Builds launch arguments where every field takes its own default value.
fn default() -> Self {
Self {
tensors: Default::default(),
scalars: Default::default(),
reshapes: Default::default(),
variables: Default::default(),
runtime_layouts: Default::default(),
_phantom_runtime: std::marker::PhantomData,
_phantom_a: std::marker::PhantomData,
}
}
}
impl<R: Runtime> core::fmt::Debug for GlobalArgsLaunch<'_, R> {
// Only the tensor arguments are printed (wrapped in parentheses); scalars, reshapes, runtime
// layouts and variables are omitted.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({:?})", self.tensors.values)
}
}
impl<R: Runtime> GlobalArgsLaunch<'_, R> {
    /// Get the shape of the given [argument](FuseArg).
    ///
    /// # Panics
    ///
    /// If the argument doesn't have an handle.
    pub fn shape(&self, arg: &FuseArg) -> Shape {
        let TensorArg::Handle { handle, .. } = self.resolve_arg(arg) else {
            panic!("Unsupported yet")
        };
        handle.shape.into()
    }

    /// Shape used by the reference tensor.
    ///
    /// Concrete, swap-dims and shape layouts derive it from a tensor argument; reshaped and
    /// runtime layouts read `rank` values from the `reshapes` / `runtime_layouts` sequences.
    pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Shape {
        match ref_layout {
            RefLayout::Concrete(arg) => self.shape(arg),
            RefLayout::Virtual(VirtualLayout::SwapDims(original, dims)) => {
                let mut swapped = self.shape(original);
                swapped.swap(dims.0, dims.1);
                swapped
            }
            RefLayout::Virtual(VirtualLayout::Reshaped { reshape_pos, .. }) => {
                let offset = *reshape_pos * rank;
                self.reshapes.values[offset..offset + rank]
                    .iter()
                    .map(|dim| dim.elem)
                    .collect()
            }
            RefLayout::Virtual(VirtualLayout::Shape(original, _)) => self.shape(original),
            RefLayout::Virtual(VirtualLayout::Runtime { pos }) => {
                // Each runtime layout occupies `2 * rank` entries; the shape comes first.
                let offset = (*pos * 2) * rank;
                self.runtime_layouts.values[offset..offset + rank]
                    .iter()
                    .map(|dim| dim.elem)
                    .collect()
            }
        }
    }

    /// Get the strides of the given [argument](FuseArg).
    ///
    /// # Panics
    ///
    /// If the argument doesn't have an handle.
    pub fn strides(&self, arg: &FuseArg) -> Strides {
        let TensorArg::Handle { handle, .. } = self.resolve_arg(arg) else {
            panic!("Unsupported yet")
        };
        handle.strides.into()
    }

    /// Strides used by the reference tensor.
    ///
    /// A concrete layout reuses the strides of its tensor; any other layout gets contiguous
    /// strides computed from [`Self::shape_ref`].
    pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Strides {
        if let RefLayout::Concrete(arg) = ref_layout {
            return self.strides(arg);
        }

        // When not concrete, we operate on the contiguous layout.
        let shape = self.shape_ref(ref_layout, rank);
        let mut strides = strides![0; shape.len()];
        let mut current = 1;
        // Walk from the innermost dimension outwards, accumulating the running product.
        for (index, val) in shape.iter().enumerate().rev() {
            strides[index] = current;
            current *= val;
        }
        strides
    }

    /// Get the line size of the given [argument](FuseArg).
    ///
    /// # Panics
    ///
    /// If the argument doesn't have an handle.
    pub fn line_size(&self, arg: &FuseArg) -> LineSize {
        let TensorArg::Handle { line_size, .. } = self.resolve_arg(arg) else {
            panic!("Unsupported yet")
        };
        *line_size
    }

    /// Resolve the [argument](FuseArg) to a [tensor argument](TensorArg).
    ///
    /// # Panics
    ///
    /// If the argument isn't a global input or output tensor.
    pub fn resolve_arg(&self, arg: &FuseArg) -> &TensorArg<'_, R> {
        match arg {
            FuseArg::Input(pos, _, _) | FuseArg::Output(pos, _, _) => {
                &self.tensors.values[*pos].tensor
            }
            other => panic!("Arg not found: {other:?}"),
        }
    }
}
#[derive(CubeType, Clone)]
/// Keep track of all local variables that are used as argument in fused
/// [element wise operations](FuseOp).
///
/// There is one registry per supported element type, each keyed by `usize`, plus the cached
/// shape, strides and line size of the reference layout.
pub struct LocalArgs {
pub l_f64: Registry<usize, Line<f64>>,
pub l_f32: Registry<usize, Line<f32>>,
pub l_f16: Registry<usize, Line<f16>>,
pub l_bf16: Registry<usize, Line<bf16>>,
pub l_i64: Registry<usize, Line<i64>>,
pub l_i32: Registry<usize, Line<i32>>,
pub l_i16: Registry<usize, Line<i16>>,
pub l_i8: Registry<usize, Line<i8>>,
pub l_u64: Registry<usize, Line<u64>>,
pub l_u32: Registry<usize, Line<u32>>,
pub l_u16: Registry<usize, Line<u16>>,
pub l_u8: Registry<usize, Line<u8>>,
pub l_bool: Registry<usize, Line<bool>>,
// Shape of the reference layout.
pub ref_shape: Slice<usize>,
// Strides of the reference layout.
pub ref_strides: Slice<usize>,
// Line size of the reference layout, known at compile time.
#[cube(comptime)]
pub ref_line_size: LineSize,
}
#[cube]
impl LocalArgs {
/// Creates a new [LocalArgs] container.
///
/// Every typed registry starts out as a new, separate registry; the reference shape, strides
/// and line size are stored as given.
pub fn new(
ref_shape: Slice<usize>,
ref_strides: Slice<usize>,
#[comptime] ref_line_size: LineSize,
) -> LocalArgs {
LocalArgs {
l_f64: Registry::<usize, Line<f64>>::new(),
l_f32: Registry::<usize, Line<f32>>::new(),
l_f16: Registry::<usize, Line<f16>>::new(),
l_bf16: Registry::<usize, Line<bf16>>::new(),
l_i64: Registry::<usize, Line<i64>>::new(),
l_i32: Registry::<usize, Line<i32>>::new(),
l_i16: Registry::<usize, Line<i16>>::new(),
l_i8: Registry::<usize, Line<i8>>::new(),
l_u64: Registry::<usize, Line<u64>>::new(),
l_u32: Registry::<usize, Line<u32>>::new(),
l_u16: Registry::<usize, Line<u16>>::new(),
l_u8: Registry::<usize, Line<u8>>::new(),
l_bool: Registry::<usize, Line<bool>>::new(),
ref_shape,
ref_strides,
ref_line_size,
}
}
}
#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
/// Unary [element wise operation](FuseOp) arguments.
pub struct UnaryFuseArgs {
/// The argument the operation reads from.
pub input: FuseArg,
/// The argument the result is written to.
pub out: FuseArg,
}
#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
/// Binary [element wise operation](FuseOp) arguments.
pub struct BinaryFuseArgs {
/// The left-hand side operand.
pub lhs: FuseArg,
/// The right-hand side operand.
pub rhs: FuseArg,
/// The argument the result is written to.
pub out: FuseArg,
}
#[derive(
CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
)]
/// Precisions supported by [element wise operations](FuseOp).
///
/// This is a custom type instead of [ElemType] so it can implement [CubeType]
/// and restricts the supported types for fusion.
///
/// Note that `PartialOrd`/`Ord` are derived, so the ordering between precisions follows the
/// declaration order of the variants below.
pub enum FuseType {
// Floating point precisions.
F64,
F32,
Flex32,
F16,
BF16,
// Signed integer precisions.
I64,
I32,
I16,
I8,
// Unsigned integer precisions.
U64,
U32,
U16,
U8,
// Boolean.
Bool,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
/// Configuration that encapsulates all comptime information necessary for element wise fusion.
pub struct FuseBlockConfig {
/// Number of dimensions of the reference layout.
pub rank: usize,
/// The reference layout used to resolve the reference shape and strides.
pub ref_layout: RefLayout,
/// The operations of the block, expanded in order.
pub ops: Vec<FuseOp>,
/// The line size used by the block.
pub width: LineSize,
}
impl FuseBlockConfig {
    /// Appends to `registers` the multi-block variables referenced by every operation of the
    /// block, visiting the operations in order.
    pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {
        self.ops
            .iter()
            .for_each(|op| op.multi_block_variables(registers));
    }
}
impl FuseArg {
    /// Appends this argument to `registers` when it is a multi-block variable (global or
    /// local), paired with the storage type of its precision; other arguments are ignored.
    pub fn multi_block_variable(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {
        // TODO: we need to init the multi-block local, but at some point we could avoid
        // that for performance (easier for the underlying compiler).
        if let FuseArg::MultiBlockGlobal(pos, fuse_type)
        | FuseArg::MultiBlockLocal(pos, fuse_type) = self
        {
            registers.push((pos.clone(), fuse_type.into_type()));
        }
    }
}
impl FuseOp {
pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {
match self {
FuseOp::Add(binary_fuse_args)
| FuseOp::Sub(binary_fuse_args)
| FuseOp::Mul(binary_fuse_args)
| FuseOp::Div(binary_fuse_args)
| FuseOp::Powf(binary_fuse_args)
| FuseOp::Equal(binary_fuse_args)
| FuseOp::Lower(binary_fuse_args)
| FuseOp::Greater(binary_fuse_args)
| FuseOp::LowerEqual(binary_fuse_args)
| FuseOp::Rem(binary_fuse_args)
| FuseOp::GreaterEqual(binary_fuse_args) => {
binary_fuse_args.lhs.multi_block_variable(registers);
binary_fuse_args.rhs.multi_block_variable(registers);
binary_fuse_args.out.multi_block_variable(registers);
}
FuseOp::Abs(unary_fuse_args)
| FuseOp::Exp(unary_fuse_args)
| FuseOp::Log(unary_fuse_args)
| FuseOp::Log1p(unary_fuse_args)
| FuseOp::Cos(unary_fuse_args)
| FuseOp::Sin(unary_fuse_args)
| FuseOp::Tanh(unary_fuse_args)
| FuseOp::Erf(unary_fuse_args)
| FuseOp::Sqrt(unary_fuse_args)
| FuseOp::Recip(unary_fuse_args)
| FuseOp::Assign(unary_fuse_args) => {
unary_fuse_args.input.multi_block_variable(registers);
unary_fuse_args.out.multi_block_variable(registers);
}
FuseOp::Clamp {
input,
min,
max,
out,
} => {
input.multi_block_variable(registers);
min.multi_block_variable(registers);
max.multi_block_variable(registers);
out.multi_block_variable(registers);
}
FuseOp::ConditionalAssign {
cond,
lhs,
rhs,
out,
} => {
cond.multi_block_variable(registers);
lhs.multi_block_variable(registers);
rhs.multi_block_variable(registers);
out.multi_block_variable(registers);
}
FuseOp::Gather {
input,
indices,
output,
dim: _,
} => {
input.multi_block_variable(registers);
indices.multi_block_variable(registers);
output.multi_block_variable(registers);
}
FuseOp::Select {
input,
indices,
output,
dim: _,
} => {
input.multi_block_variable(registers);
indices.multi_block_variable(registers);
output.multi_block_variable(registers);
}
FuseOp::Dequantize {
values,
params,
output,
scheme: _,
} => {
values.multi_block_variable(registers);
params.multi_block_variable(registers);
output.multi_block_variable(registers);
}
}
}
}
#[cube]
/// Initializes block variables, both globals and locals.
///
/// The list of variables is collected at compile time from the block configuration; each one is
/// then initialized with the block's line size (`block.width`).
pub fn multi_block_variables_init(
#[comptime] block: &FuseBlockConfig,
variables: &mut MultiBlockVariables,
) {
let output = comptime! {
let mut output = Vec::<(MultiBlockPos, StorageType)>::new();
block.multi_block_variables(&mut output);
output
};
#[unroll]
for i in 0..comptime!(output.len()) {
let (key, dtype) = comptime!(output.get(i).unwrap().clone());
// The dynamic element type must be set before `init` is called (see `MultiBlockVariables::init`).
set_polyfill::<NumericExpand<DYN_ELEM_ID>>(dtype);
variables.init(key, block.width);
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
/// A reference layout determines how a fuse execution will access elements in tensors.
///
/// It can either follow the same layout as a concrete tensor, or follow a virtual layout.
pub enum RefLayout {
/// Follows the shape and strides of the given tensor argument.
Concrete(FuseArg),
/// Follows a layout that is not backed by a single tensor's own shape and strides.
Virtual(VirtualLayout),
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
/// A virtual layout retrieves its shape from a reshaped tensor, a tensor with swap dimensions,
/// another tensor's shape, or values provided at runtime.
///
/// NOTE(review): the runtime variant reads its strides from runtime values rather than computing
/// contiguous ones during kernel expansion; confirm whether "always contiguous" holds for it.
pub enum VirtualLayout {
/// Virtual tensor with the provided shape id and contiguous strides.
Reshaped {
/// Position of the shape: its values start at `reshape_pos * rank`.
reshape_pos: usize,
/// Line size used for the reference layout.
line_size: LineSize,
},
/// Virtual tensor with the same shape as the given input, but with swap dims and contiguous
/// strides.
SwapDims(FuseArg, (usize, usize)),
/// Virtual tensor with the same shape as the given input, but with contiguous strides.
///
/// The second field is the line size used for the reference layout.
Shape(FuseArg, usize),
/// We don't have access to global metadata, they are passed as runtime values.
Runtime { pos: usize },
}
impl FuseArg {
    /// Adds layout information.
    ///
    /// It's going to impact how the input or output is read and written to. Only global
    /// inputs and outputs carry layout information; any other argument is left untouched.
    pub fn add_layout_info(&mut self, layout: LayoutInfo) {
        if let FuseArg::Input(_, _, slot) | FuseArg::Output(_, _, slot) = self {
            *slot = layout;
        }
    }
}
// Empty marker impl. Presumably this is what lets a `FuseArg` be used as the lookup key of a
// `Registry<FuseArg, _>` (as done by the fuse-on-write entry point); confirm against cubecl.
impl RegistryQuery<Self> for FuseArg {}
impl From<ElemType> for FuseType {
    /// Converts a [cubecl element type](ElemType) into the matching [`FuseType`].
    ///
    /// This is the inverse of [`FuseType::into_elem`] for every precision supported by
    /// fusion, including `F64`.
    ///
    /// # Panics
    ///
    /// If the float kind has no [`FuseType`] counterpart.
    fn from(value: ElemType) -> Self {
        match value {
            ElemType::Float(kind) => match kind {
                // `F64` is a supported fusion precision (see `FuseType::into_elem`), so it
                // must not fall through to the panic arm.
                FloatKind::F64 => Self::F64,
                FloatKind::F32 => Self::F32,
                FloatKind::Flex32 => Self::Flex32,
                FloatKind::F16 => Self::F16,
                FloatKind::BF16 => Self::BF16,
                _ => panic!("Unsupported precision for fusion: {value}"),
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => Self::I64,
                IntKind::I32 => Self::I32,
                IntKind::I16 => Self::I16,
                IntKind::I8 => Self::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => Self::U64,
                UIntKind::U32 => Self::U32,
                UIntKind::U16 => Self::U16,
                UIntKind::U8 => Self::U8,
            },
            ElemType::Bool => Self::Bool,
        }
    }
}
impl From<StorageType> for FuseType {
    /// Converts through the element type of the storage type.
    fn from(value: StorageType) -> Self {
        let elem = value.elem_type();
        elem.into()
    }
}
impl FuseType {
/// Converts the [fused element type](FuseType) into the [cubecl element type](ElemType).
pub fn into_elem(self) -> ElemType {
match self {
FuseType::F32 => ElemType::Float(FloatKind::F32),
FuseType::Flex32 => ElemType::Float(FloatKind::Flex32),
FuseType::F16 => ElemType::Float(FloatKind::F16),
FuseType::BF16 => ElemType::Float(FloatKind::BF16),
FuseType::I64 => ElemType::Int(IntKind::I64),
FuseType::I32 => ElemType::Int(IntKind::I32),
FuseType::I16 => ElemType::Int(IntKind::I16),
FuseType::I8 => ElemType::Int(IntKind::I8),
FuseType::U64 => ElemType::UInt(UIntKind::U64),
FuseType::U32 => ElemType::UInt(UIntKind::U32),
FuseType::U16 => ElemType::UInt(UIntKind::U16),
FuseType::U8 => ElemType::UInt(UIntKind::U8),
FuseType::Bool => ElemType::Bool,
FuseType::F64 => ElemType::Float(FloatKind::F64),
}
}
/// Convert the [fused element type](FuseType) into the [cubecl storage type](StorageType).
pub fn into_type(self) -> StorageType {
self.into_elem().into()
}
}
impl From<DType> for FuseType {
    /// Converts a tensor data type into the precision used for fusion.
    ///
    /// Quantized floats map to the precision of their storage: `I8` for native 8-bit values
    /// and `U32` for values packed in `u32`.
    ///
    /// # Panics
    ///
    /// For quantized floats whose storage/value combination is not supported.
    fn from(value: DType) -> Self {
        match value {
            DType::F64 => Self::F64,
            DType::F32 => Self::F32,
            DType::Flex32 => Self::Flex32,
            DType::F16 => Self::F16,
            DType::BF16 => Self::BF16,
            DType::I64 => Self::I64,
            DType::I32 => Self::I32,
            DType::I16 => Self::I16,
            DType::I8 => Self::I8,
            DType::U64 => Self::U64,
            DType::U32 => Self::U32,
            DType::U16 => Self::U16,
            DType::U8 => Self::U8,
            DType::Bool => Self::Bool,
            DType::QFloat(scheme) => match (scheme.store, scheme.value) {
                (QuantStore::Native, QuantValue::Q8F | QuantValue::Q8S) => Self::I8,
                (QuantStore::Native, QuantValue::E4M3 | QuantValue::E5M2) => {
                    unimplemented!("Unsupported precision for fusion")
                }
                (
                    QuantStore::Native,
                    QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S
                    | QuantValue::E2M1,
                ) => {
                    panic!("Can't store native sub-byte values")
                }
                (QuantStore::PackedU32(_), _) => Self::U32,
                (QuantStore::PackedNative(_), QuantValue::E2M1) => {
                    unimplemented!("Unsupported precision for fusion")
                }
                (QuantStore::PackedNative(_), other) => {
                    panic!("{other:?} doesn't support native packing")
                }
            },
        }
    }
}
impl Display for FuseArg {
    /// Writes a short identifier for the argument, as used in the rendering of fused
    /// operations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Global tensors.
            Self::Input(pos, ..) => write!(f, "input({pos})"),
            Self::Output(pos, ..) => write!(f, "output({pos})"),
            // Views over an input: the wrapped argument is rendered after the prefix.
            Self::InputReshaped { original, .. } => write!(f, "input_reshaped_{original}"),
            Self::InputSwapDims { original, .. } => write!(f, "input_swap_dims_{original}"),
            // Local and multi-block variables.
            Self::BlockLocal { pos, ty } => write!(f, "local({pos}, {ty:?})"),
            Self::MultiBlockLocal(mbp, ..) => write!(f, "{mbp}"),
            Self::MultiBlockGlobal(mbp, ..) => write!(f, "global_{mbp}"),
            // Scalars and literals.
            Self::Scalar(pos, ..) => write!(f, "scalar({pos})"),
            Self::ScalarShape(pos) => write!(f, "scalar_shape({pos})"),
            Self::Literal(val, ..) => write!(f, "literal_{val}"),
        }
    }
}
impl Display for MultiBlockPos {
    /// Writes the position as `block_local(<block_pos>-<block_local_pos>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let block = &self.block_pos;
        let local = &self.block_local_pos;
        write!(f, "block_local({}-{})", block, local)
    }
}

View File

@@ -0,0 +1,927 @@
use super::{DYN_ELEM_ID, Q_PARAM_DYN_ELEM_ID, Q_STORE_DYN_ELEM_ID, io::*, ir::*};
use burn_std::quantization::{QuantScheme, QuantStore, QuantValue};
use cubecl::{
ir::{ElemType, FloatKind, StorageType, UIntKind},
prelude::*,
};
use cubek::quantization::{dequantize::dequantize_symmetric_packed_value_at, scheme::QuantMode};
#[cube]
/// Fuse element-wise operations at the given write position.
///
/// # Arguments
///
/// - `inputs`: Contains all readonly global kernel arguments.
/// - `outputs`: Contains all readwrite global kernel arguments.
/// - `locals`: Contains all local variables defined during kernel expansion.
/// - `write_pos`: The logical position the values are written to.
/// - `write_values`: The explicit values to write at the given position.
/// - `write_args`: The arguments associated to the `write_values`.
/// - `config`: The current [fuse block configuration](FuseBlockConfig).
///
/// # Notes
///
/// The function will start by writing `write_values`; the operations of the block are expanded
/// afterwards.
pub fn fuse_on_write<E: CubePrimitive>(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
write_pos: usize,
write_values: Registry<FuseArg, Line<E>>,
#[comptime] write_args: Vec<FuseArg>,
#[comptime] config: &FuseBlockConfig,
) {
comment!("Fuse on write begin");
// Write the values given as arguments.
#[unroll]
for i in 0..write_args.len() {
let arg = comptime![write_args.get(i).unwrap().clone()];
// Each argument is also the key under which its value is registered.
let val = write_values.find(arg.clone());
write::<E>(inputs, outputs, locals, write_pos, val, arg, config);
}
// Expand the operations of the block at the same position.
fuse(inputs, outputs, locals, write_pos, config);
comment!("Fuse on write end");
}
#[cube]
/// Fuse element-wise operations at the given read position.
///
/// # Arguments
///
/// - `inputs`: Contains all readonly global kernel arguments.
/// - `outputs`: Contains all readwrite global kernel arguments.
/// - `locals`: Contains all local variables defined during kernel expansion.
/// - `read_pos`: The logical position the values are read from.
/// - `read_args`: The arguments associated to the `read_pos`.
/// - `config`: The current [fuse block configuration](FuseBlockConfig).
///
/// # Returns
///
/// - A sequence of values associated to the given `read_args`, each with a line size equal to
///   `config.width`.
pub fn fuse_on_read<E: CubePrimitive>(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
read_pos: usize,
#[comptime] read_args: Sequence<FuseArg>,
#[comptime] config: &FuseBlockConfig,
) -> Sequence<Line<E>> {
comment!("Fuse on read begin");
// Expand the operations of the block first, then read back the requested arguments.
fuse(inputs, outputs, locals, read_pos, config);
let mut output = Sequence::new();
#[unroll]
for i in 0..read_args.len() {
let arg = comptime![read_args.index(i).clone()];
let value = read::<E>(inputs, outputs, locals, read_pos, arg, config);
let value_line_size = value.line_size();
let output_line_size = config.width;
// We currently don't support broadcasting __across__ blocks.
if comptime!(value_line_size != output_line_size) {
// The value has a line size of 1 (asserted below): repeat its single element over a
// line of the block's width.
let mut tmp = Line::<E>::empty(config.width);
comptime!(
assert_eq!(value_line_size, 1, "The input line_size must be 1 or the same as the config width.");
);
let val = value[0];
#[unroll]
for i in 0..config.width {
tmp[i] = val;
}
output.push(tmp);
} else {
output.push(value);
}
}
comment!("Fuse on read end");
output
}
#[cube]
/// Initializes [LocalArgs] given the input and output [arguments](GlobalArgs) with the [FuseBlockConfig].
///
/// # Notes
///
/// The goal is to resolve and cache the reference shape and strides, as they are used in many
/// different functions during kernel expansion.
pub fn init_locals(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
#[comptime] config: &FuseBlockConfig,
) -> LocalArgs {
comment!("Init locals begin");
let mut ref_shape = Array::new(config.rank);
let mut ref_strides = Array::new(config.rank);
let locals = match config.ref_layout.clone() {
// Concrete layout: copy the shape and strides of the referenced tensor as they are.
RefLayout::Concrete(arg) => match comptime![arg] {
FuseArg::Input(index, ..) => {
let layout = inputs.tensors.index(index);
#[unroll]
for i in 0..config.rank {
ref_shape[i] = layout.tensor.shape(i);
ref_strides[i] = layout.tensor.stride(i);
}
LocalArgs::new(
ref_shape.to_slice(),
ref_strides.to_slice(),
layout.tensor.line_size(),
)
}
FuseArg::Output(index, ..) => {
let layout = outputs.tensors.index(index);
#[unroll]
for i in 0..config.rank {
ref_shape[i] = layout.tensor.shape(i);
ref_strides[i] = layout.tensor.stride(i);
}
LocalArgs::new(
ref_shape.to_slice(),
ref_strides.to_slice(),
layout.tensor.line_size(),
)
}
_ => comptime![panic!("Invalid concrete ref layout.")],
},
RefLayout::Virtual(layout) => match layout {
// Shape of the original tensor with two dimensions swapped; strides are contiguous.
VirtualLayout::SwapDims(original, dims) => {
let layout = match original.clone() {
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
FuseArg::Output(pos, ..) => outputs.tensors.index(pos),
_ => comptime![panic!("Unsupported")],
};
let mut stride_curr = 1;
#[unroll]
#[allow(clippy::clone_on_copy)]
for i in 0..config.rank {
// Dimensions are visited in the order given by `reverse_index` while the running
// stride is accumulated.
let reverse = reverse_index(config.rank, i);
let swap = comptime![swap_dims_transform(reverse, dims)];
let shape = layout.tensor.shape(swap.clone());
ref_shape[reverse] = shape;
ref_strides[reverse] = stride_curr;
stride_curr *= ref_shape[comptime![reverse]];
}
LocalArgs::new(
ref_shape.to_slice(),
ref_strides.to_slice(),
layout.tensor.line_size(),
)
}
// Shape read from the scalar shapes starting at `reshape_pos * rank`; strides are
// contiguous.
VirtualLayout::Reshaped {
reshape_pos,
line_size,
} => {
let mut stride_curr = 1;
let start = reshape_pos * config.rank;
#[unroll]
#[allow(clippy::clone_on_copy)]
for i in 0..config.rank {
let reverse = reverse_index(config.rank, i);
let arg = comptime![FuseArg::ScalarShape(start + reverse)];
let shape = read_scalar_shape(inputs, arg.clone());
ref_shape[comptime![reverse]] = shape;
ref_strides[comptime![reverse]] = stride_curr;
stride_curr *= ref_shape[comptime![reverse]];
}
LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), line_size)
}
// Shape and strides are both read from `runtime_layouts`: each layout occupies
// `2 * rank` entries, the shape first and the strides right after.
VirtualLayout::Runtime { pos } => {
let start_shape = (pos * 2) * config.rank;
let start_strides = start_shape + config.rank;
#[unroll]
for i in 0..config.rank {
let shape_index = start_shape + i;
let strides_index = start_strides + i;
ref_shape[i] = *inputs.runtime_layouts.index(shape_index);
ref_strides[i] = *inputs.runtime_layouts.index(strides_index);
}
LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), config.width)
}
// Shape of the original tensor with contiguous strides and an explicit line size.
VirtualLayout::Shape(original, line_size) => {
let layout = match original.clone() {
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
FuseArg::Output(pos, ..) => outputs.tensors.index(pos),
_ => comptime![panic!("Unsupported")],
};
let mut stride_curr = 1;
#[unroll]
#[allow(clippy::clone_on_copy)]
for i in 0..config.rank {
let reverse = reverse_index(config.rank, i);
let shape = layout.tensor.shape(reverse);
ref_shape[comptime![reverse]] = shape;
ref_strides[comptime![reverse]] = stride_curr;
stride_curr *= ref_shape[comptime![reverse]];
}
LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), line_size)
}
},
};
comment!("Init locals end");
locals
}
#[cube]
/// Expands all [operations](FuseOp) registered in the [block config](FuseBlockConfig).
///
/// Operations are expanded in the order they appear in `config.ops`, all at the same `pos`.
fn fuse(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
pos: usize,
#[comptime] config: &FuseBlockConfig,
) {
#[unroll]
for index in 0..config.ops.len() {
let op = config.ops[index].clone();
// Set the dynamic element type to the computation type of the operation before expanding it.
set_polyfill::<NumericExpand<DYN_ELEM_ID>>(op.cmp_type());
match op {
FuseOp::Add(op) => {
add::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Div(op) => {
div::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Sub(op) => {
sub::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Mul(op) => {
mul::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Powf(op) => {
powf::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Erf(op) => {
erf::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Sqrt(op) => {
sqrt::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Abs(op) => {
abs::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Log(op) => {
log::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Log1p(op) => {
log1p::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Recip(op) => {
recip::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Assign(op) => {
assign::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Exp(op) => {
exp::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Cos(op) => {
cos::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Sin(op) => {
sin::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Tanh(op) => {
tanh::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Equal(op) => {
equal::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Greater(op) => {
greater::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::GreaterEqual(op) => greater_equal::<NumericExpand<DYN_ELEM_ID>>(
inputs, outputs, locals, pos, op, config,
),
FuseOp::Lower(op) => {
lower::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::LowerEqual(op) => {
lower_equal::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::ConditionalAssign {
cond,
lhs,
rhs,
out,
} => conditional_assign::<NumericExpand<DYN_ELEM_ID>>(
inputs, outputs, locals, pos, cond, lhs, rhs, out, config,
),
FuseOp::Gather {
input,
indices,
output,
dim,
} => gather::<NumericExpand<DYN_ELEM_ID>>(
inputs, outputs, locals, pos, dim, input, indices, output, config,
),
FuseOp::Select {
input,
indices,
output,
dim,
} => select_indices::<NumericExpand<DYN_ELEM_ID>>(
inputs, outputs, locals, pos, dim, input, indices, output, config,
),
FuseOp::Dequantize {
values,
params,
output,
scheme,
} => dequantize::<NumericExpand<DYN_ELEM_ID>>(
inputs,
outputs,
locals,
pos,
values,
params,
output,
scheme.scheme,
config,
),
FuseOp::Rem(op) => {
rem::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
}
FuseOp::Clamp {
input,
min,
max,
out,
} => clamp::<NumericExpand<DYN_ELEM_ID>>(
inputs, outputs, locals, pos, input, min, max, out, config,
),
}
}
}
// Generates a `#[cube]` function named `$ident` that reads `op.lhs` and `op.rhs`, combines them
// with the infix operator `$op`, and writes the result to `op.out`.
macro_rules! binary_op {
($ident:ident, $op:tt) => {
#[cube]
fn $ident<C: Numeric>(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
write_pos: usize,
#[comptime] op: BinaryFuseArgs,
#[comptime] config: &FuseBlockConfig,
) {
// Both operands are read at the same position the result is written to.
let lhs = read::<C>(inputs, outputs, &locals, write_pos, op.lhs, config);
let rhs = read::<C>(inputs, outputs, &locals, write_pos, op.rhs, config);
let result = lhs $op rhs;
write::<C>(inputs, outputs, locals, write_pos, result, op.out, config);
}
};
}
/// Generates a `#[cube]` function named `$ident` that combines two fused arguments by calling
/// `$func(lhs, rhs)`.
///
/// `$c` is the trait bound put on the generic parameter `C` of the generated function. The
/// parameter has to keep the name `C`: the call sites pass expressions such as
/// `Line::<C>::powf` that refer to it.
macro_rules! binary_func {
    ($ident:ident, $func:expr, $c:tt) => {
        #[cube]
        fn $ident<C: $c>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: BinaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let left = read::<C>(inputs, outputs, &locals, write_pos, op.lhs, config);
            let right = read::<C>(inputs, outputs, &locals, write_pos, op.rhs, config);
            let combined = $func(left, right);
            write::<C>(inputs, outputs, locals, write_pos, combined, op.out, config);
        }
    };
}
/// Generates a `#[cube]` function named `$ident` that compares two fused arguments with the
/// infix comparison operator `$op`.
///
/// Both operands are read as `C`; the comparison outcome is wrapped with `Line::new` and written
/// to `op.out` as `bool`.
macro_rules! comparison_op {
    ($ident:ident, $op:tt) => {
        #[cube]
        fn $ident<C: CubePrimitive + core::cmp::PartialOrd>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: BinaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let left = read::<C>(inputs, outputs, &locals, write_pos, op.lhs, config);
            let right = read::<C>(inputs, outputs, &locals, write_pos, op.rhs, config);
            let outcome = Line::new(left $op right);
            write::<bool>(inputs, outputs, locals, write_pos, outcome, op.out, config);
        }
    };
}
/// Generates a `#[cube]` function named `$ident` that transforms one fused argument by calling
/// `$func(input)`.
///
/// `$c` is the trait bound put on the generic parameter `C` of the generated function. The
/// parameter has to keep the name `C`: the call sites pass expressions such as
/// `Line::<C>::exp` that refer to it.
macro_rules! unary_func {
    ($ident:ident, $func:expr, $c:tt) => {
        #[cube]
        fn $ident<C: $c>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: UnaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let operand = read::<C>(inputs, outputs, &locals, write_pos, op.input, config);
            let transformed = $func(operand);
            write::<C>(inputs, outputs, locals, write_pos, transformed, op.out, config);
        }
    };
}
/// Forwards a fused argument unchanged: reads `op.input` as `C` and writes that value to
/// `op.out` at `write_pos`.
#[cube]
fn assign<C: CubePrimitive>(
    inputs: &GlobalArgs,
    outputs: &mut GlobalArgs,
    locals: &mut LocalArgs,
    write_pos: usize,
    #[comptime] op: UnaryFuseArgs,
    #[comptime] config: &FuseBlockConfig,
) {
    let value = read::<C>(inputs, outputs, locals, write_pos, op.input, config);
    write::<C>(inputs, outputs, locals, write_pos, value, op.out, config);
}
/// Fused gather along `dim`.
///
/// For the line written at `write_pos`, the coordinate of `input` along `dim` is taken from
/// `indices`, while the offset contributed by the other dimensions is computed with
/// `global_offset`. The gathered line is then written to `output`.
///
/// Both `input` and `indices` must be `FuseArg::Input`; any other variant panics while the
/// kernel is being expanded.
#[cube]
fn gather<C: Numeric>(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
write_pos: usize,
#[comptime] dim: usize,
#[comptime] input: FuseArg,
#[comptime] indices: FuseArg,
#[comptime] output: FuseArg,
#[comptime] config: &FuseBlockConfig,
) {
let line_size = locals.ref_line_size;
// Position of the gathered tensor among the global inputs.
let pos_input = comptime! {
match input {
FuseArg::Input(pos, ..) => pos,
_ => panic!("Input tensor isn't an input"),
}
};
// Position of the indices tensor among the global inputs.
let pos_indices = comptime! {
match indices {
FuseArg::Input(pos, ..) => pos,
_ => panic!("Indices tensor isn't an input"),
}
};
// Stride of `input` along the gathered dimension.
let stride_input_dim = global_stride(inputs, dim, pos_input);
// `index` accumulates the offset into `input`; `result` is the line being assembled.
let mut index = 0;
let mut result = Line::empty(line_size);
// Offset contributed by the dimensions `0..dim`.
// NOTE(review): assumes `global_offset` with `Some((start, end))` only accounts for the
// dimensions in `start..end` — confirm in the `io` module.
if comptime![dim > 0] {
let index_before = global_offset(
inputs,
outputs,
locals,
write_pos,
input.clone(),
comptime![Some((0, dim))],
config,
);
index += index_before;
}
// Offset contributed by the dimensions `dim + 1..rank`.
if comptime![dim + 1 < config.rank] {
let index_after = global_offset(
inputs,
outputs,
locals,
write_pos,
input,
comptime![Some((dim + 1, config.rank))],
config,
);
index += index_after;
}
// Offset into `indices`, computed over every dimension.
let index_offset = global_offset(
inputs,
outputs,
locals,
write_pos,
indices,
comptime![Some((0, config.rank))],
config,
);
if comptime![dim == config.rank - 1] {
// Per-element indexing (along the dimension)
//
// Every element of the line reads its own index at `index_offset + i`.
#[unroll]
for i in 0..line_size {
let offset = read_input::<u32>(
inputs,
locals,
pos_indices,
index_offset + i,
LayoutInfo::IsRef,
config,
None,
);
let input = read_input::<C>(
inputs,
locals,
pos_input,
index + (offset[0] as usize * stride_input_dim),
LayoutInfo::IsRef,
config,
None,
);
result[i] = input[0];
}
} else {
// Shared index for whole line
//
// A single index is read once; the elements of the line are then fetched one by one,
// spaced by the stride of the last dimension of `input`.
let stride_input_line = global_stride(inputs, config.rank - 1, pos_input);
let offset = read_input::<u32>(
inputs,
locals,
pos_indices,
index_offset,
LayoutInfo::IsRef,
config,
None,
);
index += offset[0] as usize * stride_input_dim;
#[unroll]
for i in 0..line_size {
let input = read_input::<C>(
inputs,
locals,
pos_input,
index + i * stride_input_line,
LayoutInfo::IsRef,
config,
None,
);
result[i] = input[0];
}
}
write::<C>(inputs, outputs, locals, write_pos, result, output, config);
}
/// Fused select along `dim`.
///
/// The coordinate along `dim` of the reference layout is derived from `write_pos`, used to look
/// up an entry of `indices`, and that entry selects the coordinate of `input` along `dim`. The
/// selected line is written to `output`.
///
/// Both `input` and `indices` must be `FuseArg::Input`; any other variant panics.
#[cube]
fn select_indices<C: Numeric>(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
write_pos: usize,
#[comptime] dim: usize,
#[comptime] input: FuseArg,
#[comptime] indices: FuseArg,
#[comptime] output: FuseArg,
#[comptime] config: &FuseBlockConfig,
) {
// Line size, stride and shape of the reference layout along `dim`.
let (line_size_ref, stride_dim_ref, shape_dim_ref) = (
locals.ref_line_size,
locals.ref_strides[dim],
locals.ref_shape[dim],
);
// Position of the selected tensor among the global inputs.
let pos_input = comptime! {
match input {
FuseArg::Input(pos, ..) => pos,
_ => panic!("Input tensor isn't an input"),
}
};
// Position of the indices tensor among the global inputs.
// NOTE(review): unlike `pos_input` above, this match is not wrapped in `comptime!` — confirm
// whether that difference is intentional.
let pos_indices = match indices {
FuseArg::Input(pos, ..) => pos,
_ => panic!("Indices tensor isn't an input"),
};
// Stride of `input` along the selected dimension.
let stride_input_dim = global_stride(inputs, dim, pos_input);
// `index` accumulates the offset into `input`; `result` is the line being assembled.
let mut index = 0;
let mut result = Line::empty(line_size_ref);
if comptime![dim != config.rank - 1] {
// In this scenario the select is actually broadcasted along the axis we're working on.
//
// Therefore the same indices are used to fetch multiple entries in the input tensor.
// Offset contributed by the dimensions `0..dim`.
if comptime![dim > 0] {
let index_before = global_offset(
inputs,
outputs,
locals,
write_pos,
input.clone(),
comptime![Some((0, dim))],
config,
);
index += index_before;
}
// Offset contributed by the dimensions `dim + 1..rank`.
if comptime![dim + 1 < config.rank] {
let index_after = global_offset(
inputs,
outputs,
locals,
write_pos,
input.clone(),
comptime![Some((dim + 1, config.rank))],
config,
);
index += index_after;
}
let stride_input_line = global_stride(inputs, comptime![config.rank - 1], pos_input);
// Element position of the first element of the line, then its coordinate along `dim`.
let write_pos_input = write_pos * line_size_ref;
let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref;
let offset_dim = read_input::<u32>(
inputs,
locals,
pos_indices,
coordinate_dim,
LayoutInfo::IsRef,
config,
None,
);
index += offset_dim[0] as usize * stride_input_dim;
// The elements of the line are spaced by the stride of the last dimension of `input`.
#[unroll]
for i in 0..line_size_ref {
let input = read_input::<C>(
inputs,
locals,
pos_input,
index + i * stride_input_line,
LayoutInfo::IsRef,
config,
None,
);
result[i] = input[0];
}
} else {
// In this scenario the select is actually performed on the last dimension we're working on.
//
// Therefore we need to fetch multiple indices that correspond to different entries in the
// input tensor.
if comptime![dim > 0] {
let index_before = global_offset(
inputs,
outputs,
locals,
write_pos,
input.clone(),
comptime![Some((0, dim))],
config,
);
index += index_before;
}
// `dim` is the last dimension in this branch, so this condition is never true here.
if comptime![dim + 1 < config.rank] {
let index_after = global_offset(
inputs,
outputs,
locals,
write_pos,
input,
comptime![Some((dim + 1, config.rank))],
config,
);
index += index_after;
}
let write_pos_indices = write_pos * line_size_ref;
// Every element of the line gets its own coordinate along `dim`, hence its own index.
#[unroll]
for i in 0..line_size_ref {
let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref;
let offset_dim = read_input::<u32>(
inputs,
locals,
pos_indices,
coordinate_dim,
LayoutInfo::IsRef,
config,
None,
);
let input = read_input::<C>(
inputs,
locals,
pos_input,
index + (offset_dim[0] as usize * stride_input_dim),
LayoutInfo::IsRef,
config,
None,
);
result[i] = input[0];
}
}
write::<C>(inputs, outputs, locals, write_pos, result, output, config);
}
/// Fused conditional assignment.
///
/// Reads `cond` as `bool` and `lhs`/`rhs` as `C`, combines the three lines with
/// `select_many(cond, lhs, rhs)` and writes the outcome to `out` at `write_pos`.
#[cube]
fn conditional_assign<C: CubePrimitive>(
    inputs: &GlobalArgs,
    outputs: &mut GlobalArgs,
    locals: &mut LocalArgs,
    write_pos: usize,
    #[comptime] cond: FuseArg,
    #[comptime] lhs: FuseArg,
    #[comptime] rhs: FuseArg,
    #[comptime] out: FuseArg,
    #[comptime] config: &FuseBlockConfig,
) {
    let mask = read::<bool>(inputs, outputs, locals, write_pos, cond, config);
    let lhs_value = read::<C>(inputs, outputs, locals, write_pos, lhs, config);
    let rhs_value = read::<C>(inputs, outputs, locals, write_pos, rhs, config);
    let selected = select_many(mask, lhs_value, rhs_value);
    write::<C>(inputs, outputs, locals, write_pos, selected, out, config);
}
/// Fused clamp.
///
/// Reads `input`, `min` and `max` as `C`, calls `cubecl::prelude::clamp(input, min, max)` and
/// writes the outcome to `out` at `write_pos`.
#[cube]
fn clamp<C: Numeric>(
    inputs: &GlobalArgs,
    outputs: &mut GlobalArgs,
    locals: &mut LocalArgs,
    write_pos: usize,
    #[comptime] input: FuseArg,
    #[comptime] min: FuseArg,
    #[comptime] max: FuseArg,
    #[comptime] out: FuseArg,
    #[comptime] config: &FuseBlockConfig,
) {
    let value = read::<C>(inputs, outputs, locals, write_pos, input, config);
    let lower = read::<C>(inputs, outputs, locals, write_pos, min, config);
    let upper = read::<C>(inputs, outputs, locals, write_pos, max, config);
    let clamped = cubecl::prelude::clamp(value, lower, upper);
    write::<C>(inputs, outputs, locals, write_pos, clamped, out, config);
}
/// Fused dequantization of a quantized input tensor.
///
/// The quantized values of `input` are read with `read_quantized`, dequantized with the scales
/// stored in `scales` through `dequantize_symmetric_packed_value_at`, flattened into a single
/// line of `num_quants * line_size` elements and written to `output` at `write_pos`.
///
/// Before anything is read, the dynamic element types behind `Q_STORE_DYN_ELEM_ID` (stored
/// values) and `Q_PARAM_DYN_ELEM_ID` (scales) are registered from `scheme`.
///
/// # Panics
///
/// All of the following happen while the kernel is being expanded:
///
/// - `scheme.mode` is not `QuantMode::Symmetric`.
/// - `scheme.store` is `Native` with a sub-byte value type.
/// - `scheme.store` is `PackedNative` with a value type other than `E2M1`.
/// - `input` or `scales` is not a `FuseArg::Input`.
#[cube]
#[allow(clippy::explicit_counter_loop)]
fn dequantize<C: Float>(
    inputs: &GlobalArgs,
    outputs: &mut GlobalArgs,
    locals: &mut LocalArgs,
    write_pos: usize,
    #[comptime] input: FuseArg,
    #[comptime] scales: FuseArg,
    #[comptime] output: FuseArg,
    #[comptime] scheme: QuantScheme,
    #[comptime] config: &FuseBlockConfig,
) {
    comptime!(assert_eq!(
        scheme.mode,
        QuantMode::Symmetric,
        "Only symmetric quantization mode is supported."
    ));

    // Element type used to read the stored quantized values.
    set_polyfill::<NumericExpand<Q_STORE_DYN_ELEM_ID>>(comptime![match scheme.store {
        QuantStore::Native => match scheme.value {
            // 8-bit quantized values are signed: reading them as unsigned bytes would turn
            // every negative value into a large positive one.
            QuantValue::Q8F | QuantValue::Q8S => {
                StorageType::Scalar(ElemType::Int(cubecl::ir::IntKind::I8))
            }
            QuantValue::E4M3 => StorageType::Scalar(ElemType::Float(FloatKind::E4M3)),
            QuantValue::E5M2 => StorageType::Scalar(ElemType::Float(FloatKind::E5M2)),
            QuantValue::Q4F
            | QuantValue::Q4S
            | QuantValue::Q2F
            | QuantValue::Q2S
            | QuantValue::E2M1 => unreachable!("Can't store native sub-byte values"),
        },
        QuantStore::PackedU32(_) => ElemType::UInt(UIntKind::U32).into(),
        QuantStore::PackedNative(_) => match scheme.value {
            // Two 4-bit `E2M1` values are packed per storage element, so the packed element
            // type has to be `E2M1` itself.
            QuantValue::E2M1 => StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
            other => panic!("{other:?} doesn't support native packing"),
        },
    }]);

    // Element type used to read the quantization parameters (scales).
    set_polyfill::<NumericExpand<Q_PARAM_DYN_ELEM_ID>>(comptime![match scheme.param {
        cubecl::quant::scheme::QuantParam::F32 =>
            StorageType::Scalar(ElemType::Float(FloatKind::F32)),
        cubecl::quant::scheme::QuantParam::F16 =>
            StorageType::Scalar(ElemType::Float(FloatKind::F16)),
        cubecl::quant::scheme::QuantParam::BF16 =>
            StorageType::Scalar(ElemType::Float(FloatKind::BF16)),
        cubecl::quant::scheme::QuantParam::UE8M0 =>
            StorageType::Scalar(ElemType::Float(FloatKind::UE8M0)),
        cubecl::quant::scheme::QuantParam::UE4M3 =>
            StorageType::Scalar(ElemType::Float(FloatKind::E4M3)),
    }]);

    // Position of the quantized values among the global inputs.
    let tensor_pos = comptime!(match input {
        FuseArg::Input(pos, _, _) => pos,
        _ => panic!("Not supported"),
    });
    // Position of the scales among the global inputs.
    let pos = comptime!(match scales {
        FuseArg::Input(pos, ..) => pos,
        _ => unreachable!(""),
    });

    let input = read_quantized::<NumericExpand<Q_STORE_DYN_ELEM_ID>>(
        inputs, locals, write_pos, input, config, scheme,
    );
    let line_size = input.line_size();
    let num_quants = scheme.num_quants();
    let scales = input_as_scales_view::<NumericExpand<Q_PARAM_DYN_ELEM_ID>>(
        inputs,
        pos,
        tensor_pos,
        scheme.level,
        config,
    );
    let result = dequantize_symmetric_packed_value_at::<
        C,
        ElemExpand<Q_PARAM_DYN_ELEM_ID>,
        ElemExpand<Q_STORE_DYN_ELEM_ID>,
    >(write_pos * num_quants, input, &scales, scheme);

    // `result[i]` holds the `num_quants` values unpacked from the i-th stored element; they are
    // laid out contiguously in the output line.
    let line_size_result = comptime!(num_quants * line_size);
    let line = if comptime!(line_size == 1) {
        result[0]
    } else {
        let mut line = Line::empty(line_size_result);
        #[unroll]
        for i in 0..line_size {
            let value = result[i];
            #[unroll]
            for j in 0..num_quants {
                let index = i * num_quants + j;
                line[index] = value[j];
            }
        }
        line
    };
    write::<C>(inputs, outputs, locals, write_pos, line, output, config);
}
// Element-wise arithmetic generated from infix operators.
binary_op!(add, +);
binary_op!(mul, *);
binary_op!(div, /);
binary_op!(sub, -);
// Element-wise comparisons; the generated functions write `bool` lines.
comparison_op!(equal, ==);
comparison_op!(greater, >);
comparison_op!(greater_equal, >=);
comparison_op!(lower, <);
comparison_op!(lower_equal, <=);
// Binary functions. `C` in `Line::<C>::...` refers to the generic parameter declared by the
// function the macro generates.
binary_func!(powf, Line::<C>::powf, Float);
binary_func!(rem, Line::<C>::rem, Float);
// Unary functions. Note that `log` maps to the natural logarithm `Line::ln`.
unary_func!(exp, Line::<C>::exp, Float);
unary_func!(log, Line::<C>::ln, Float);
unary_func!(log1p, Line::<C>::log1p, Float);
unary_func!(sqrt, Line::<C>::sqrt, Float);
unary_func!(cos, Line::<C>::cos, Float);
unary_func!(sin, Line::<C>::sin, Float);
unary_func!(tanh, Line::<C>::tanh, Float);
unary_func!(erf, Line::<C>::erf, Float);
unary_func!(recip, Line::<C>::recip, Float);
// `abs` is the only unary function bounded by `Numeric` instead of `Float`.
unary_func!(abs, Line::<C>::abs, Numeric);

View File

@@ -0,0 +1,8 @@
pub(crate) mod io;
pub(crate) mod ir;
pub(crate) mod kernel;
pub(crate) mod tensor;
pub(crate) mod view;
mod base;
pub(crate) use base::*;

View File

@@ -0,0 +1,90 @@
use super::DYN_ELEM_ID;
use cubecl::{
ir::{ElemType, Type},
prelude::*,
};
use serde::{Deserialize, Serialize};
use std::hash::Hash;
/// Represents a global tensor with the given [element type](ElemType).
///
/// # Warning
///
/// The `tensor` field is declared as `Tensor<Line<NumericExpand<DYN_ELEM_ID>>>`: the dynamic
/// element type behind `DYN_ELEM_ID` must be set using polyfill before use.
#[derive(CubeType, Clone)]
pub struct GlobalTensor {
/// The underlying global tensor.
pub tensor: Tensor<Line<NumericExpand<DYN_ELEM_ID>>>,
/// The element type of the tensor.
#[cube(comptime)]
pub elem: ElemType,
/// Whether the current tensor is logically broadcasted.
#[cube(comptime)]
pub broadcasted: bool,
}
// Everything below is to implement [LaunchArg].
/// Compilation argument of a [`GlobalTensor`].
///
/// It is built from a [`GlobalTensorArg`] by `GlobalTensor::compilation_arg` and consumed by
/// `GlobalTensor::expand` / `GlobalTensor::expand_output`.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
pub struct GlobalTensorCompilationArg {
/// Compilation argument of the wrapped tensor.
tensor: TensorCompilationArg,
/// The element type of the tensor.
elem: ElemType,
/// Whether the tensor is logically broadcasted.
broadcasted: bool,
}
/// Runtime (launch) argument of a [`GlobalTensor`].
#[derive(new, Debug)]
pub struct GlobalTensorArg<'a, R: Runtime> {
/// Runtime argument of the wrapped tensor; this is what gets registered on the launcher.
pub tensor: <Tensor<Line<NumericExpand<DYN_ELEM_ID>>> as LaunchArg>::RuntimeArg<'a, R>,
/// The element type of the tensor.
pub elem: ElemType,
/// Whether the tensor is logically broadcasted.
pub broadcasted: bool,
/// Address type associated with the tensor.
// NOTE(review): not read by the `LaunchArg` / `ArgSettings` implementations in this file —
// confirm where it is consumed.
pub address_type: AddressType,
}
// Marker implementation: no method is overridden.
impl CompilationArg for GlobalTensorCompilationArg {}
impl LaunchArg for GlobalTensor {
    type RuntimeArg<'a, R: Runtime> = GlobalTensorArg<'a, R>;
    type CompilationArg = GlobalTensorCompilationArg;

    /// Derives the compilation argument from the runtime argument: the wrapped tensor's own
    /// compilation argument plus the element type and the broadcast flag.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        GlobalTensorCompilationArg {
            tensor: <Tensor<Line<NumericExpand<DYN_ELEM_ID>>> as LaunchArg>::compilation_arg(
                &runtime_arg.tensor,
            ),
            elem: runtime_arg.elem,
            broadcasted: runtime_arg.broadcasted,
        }
    }

    /// Declares the tensor as a kernel input whose type is a line of `arg.elem` with the line
    /// size recorded in the tensor compilation argument.
    fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> GlobalTensorExpand {
        let line_type = Type::scalar(arg.elem).line(arg.tensor.line_size);
        let declared = builder.input_tensor(line_type);

        GlobalTensorExpand {
            tensor: declared.into(),
            elem: arg.elem,
            broadcasted: arg.broadcasted,
        }
    }

    /// Declares the tensor as a kernel output.
    ///
    /// When the tensor compilation argument carries an `inplace` id, the output reuses that
    /// binding; otherwise a new output tensor is declared.
    fn expand_output(
        arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> GlobalTensorExpand {
        let declared = if let Some(id) = arg.tensor.inplace {
            builder.inplace_output(id)
        } else {
            builder.output_tensor(Type::scalar(arg.elem).line(arg.tensor.line_size))
        };

        GlobalTensorExpand {
            tensor: declared.into(),
            elem: arg.elem,
            broadcasted: arg.broadcasted,
        }
    }
}
impl<R: Runtime> ArgSettings<R> for GlobalTensorArg<'_, R> {
    /// Registers the wrapped tensor argument on the launcher; the other fields are not
    /// registered.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        let tensor_arg = &self.tensor;
        launcher.register_tensor(tensor_arg)
    }
}

View File

@@ -0,0 +1,358 @@
use super::{
DYN_ELEM_ID,
io::{
Transform, global_buffer_len, global_line_size, input_as_slice, read_input,
read_input_window, ref_buffer_len, ref_len,
},
ir::{FuseArg, FuseBlockConfig, GlobalArgs, LayoutInfo, LocalArgs},
kernel::fuse_on_write,
};
use cubecl::{
CubeType,
io::read_masked,
ir::StorageType,
prelude::{barrier::BarrierExpand, *},
std::tensor::{
ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
layout::Coords1d,
},
};
/// View over one global input tensor of a fused kernel.
///
/// The view operations are implemented by hand on the generated `GlobalInputExpand` type.
#[allow(dead_code, reason = "only used in expand")]
#[derive(CubeType)]
pub struct GlobalInput {
/// Global inputs of the kernel (cloned when the view is created).
inputs: GlobalArgs,
/// Local arguments of the kernel (cloned when the view is created).
locals: LocalArgs,
/// Position of the viewed tensor among the global inputs.
#[cube(comptime)]
pos: usize,
/// Storage type registered for `DYN_ELEM_ID` before the input is accessed as a slice.
#[cube(comptime)]
ty: StorageType,
/// Layout information forwarded to `read_input`.
#[cube(comptime)]
layout: LayoutInfo,
/// Configuration of the fused block, forwarded to `read_input`.
#[cube(comptime)]
config: FuseBlockConfig,
/// Optional transform forwarded to `read_input`.
#[cube(comptime)]
transform: Option<Transform>,
}
#[cube]
impl GlobalInput {
/// Creates a view over the global input described by `arg`.
///
/// `arg` must be a `FuseArg::Input`; its position, precision (converted with `into_type`) and
/// layout are extracted at comptime. Any other variant hits `unreachable!`.
/// `inputs` and `locals` are cloned into the view.
pub fn new(
inputs: &GlobalArgs,
locals: &LocalArgs,
#[comptime] arg: FuseArg,
#[comptime] config: FuseBlockConfig,
#[comptime] transform: Option<Transform>,
) -> GlobalInput {
let (pos, ty, layout) = comptime![match arg {
FuseArg::Input(pos, prec, layout) => (pos, prec.into_type(), layout),
_ => unreachable!("Must be concrete input"),
}];
GlobalInput {
inputs: inputs.clone(),
locals: locals.clone(),
pos,
ty,
layout,
config,
transform,
}
}
}
// Marker implementation; the behavior lives in the expand implementation below.
impl<E: CubePrimitive> ViewOperations<E, Coords1d> for GlobalInput {}
/// Read-only 1D view operations over a global input of a fused kernel.
impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand {
/// Plain reads are delegated to the unchecked read.
#[allow(clippy::too_many_arguments)]
fn __expand_read_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
) -> <E as CubeType>::ExpandType {
ViewOperationsExpand::<E, Coords1d>::__expand_read_unchecked_method(self, scope, pos)
}
/// Checked reads are masked reads whose fallback value is `0` cast to `E`.
#[allow(clippy::too_many_arguments)]
fn __expand_read_checked_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
) -> <E as CubeType>::ExpandType {
let zero = E::__expand_cast_from(scope, 0.into());
ViewOperationsExpand::<E, Coords1d>::__expand_read_masked_method(self, scope, pos, zero)
}
/// Reads the input as a slice through `read_masked`, passing the in-bounds flag and `value`
/// as the fallback.
// NOTE(review): unlike the unchecked read, this path does not use `self.layout` nor
// `self.transform` — confirm this is intended.
#[allow(clippy::too_many_arguments)]
fn __expand_read_masked_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
value: <E as CubeType>::ExpandType,
) -> <E as CubeType>::ExpandType {
let in_bounds = ViewOperationsExpand::<E, Coords1d>::__expand_is_in_bounds_method(
self,
scope,
pos.clone(),
);
// The dynamic element type must be registered before the input is viewed as a slice.
scope.register_type::<NumericExpand<DYN_ELEM_ID>>(self.ty);
let slice = input_as_slice::expand(scope, self.inputs.clone(), self.pos);
read_masked::expand::<E>(scope, in_bounds, slice, pos, value)
}
/// Reads through `read_input` (honoring the layout and the optional transform) and casts the
/// value to `E`.
#[allow(clippy::too_many_arguments)]
fn __expand_read_unchecked_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
) -> <E as CubeType>::ExpandType {
let value = read_input::expand::<E>(
scope,
self.inputs.clone(),
self.locals.clone(),
self.pos,
pos,
self.layout,
self.config.clone(),
self.transform.clone(),
);
E::__expand_cast_from(scope, value)
}
/// Returns a window of the input from `pos` to `end + 1`.
#[allow(clippy::too_many_arguments)]
fn __expand_to_linear_slice_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
end: ExpandElementTyped<usize>,
) -> SliceExpand<E, ReadOnly> {
scope.register_type::<NumericExpand<DYN_ELEM_ID>>(self.ty);
// NOTE(review): presumably `end` is inclusive here while `read_input_window` expects an
// exclusive end — confirm.
let end = add::expand(scope, end.clone(), 1.into());
read_input_window::expand(scope, self.inputs.clone(), self.pos, pos, end)
}
/// Always panics: a global input is not a tensor map.
#[allow(clippy::too_many_arguments)]
fn __expand_tensor_map_load_method(
&self,
_scope: &mut Scope,
_barrier: BarrierExpand,
_shared_memory: SliceExpand<E, ReadWrite>,
_pos: ExpandElementTyped<usize>,
) {
panic!("Not a tensor map")
}
/// The shape of the 1D view is the buffer length of the input.
#[allow(clippy::too_many_arguments)]
fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
global_buffer_len::expand(scope, self.inputs.clone(), self.pos)
}
/// A position is in bounds when it is strictly lower than the buffer length of the input.
#[allow(clippy::too_many_arguments)]
fn __expand_is_in_bounds_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
) -> ExpandElementTyped<bool> {
let buffer_len = global_buffer_len::expand(scope, self.inputs.clone(), self.pos);
lt::expand(scope, pos, buffer_len)
}
}
// Marker implementation; the behavior lives in the expand implementation below.
impl Lined for GlobalInput {}
impl LinedExpand for GlobalInputExpand {
/// Returns the line size of the viewed global input.
fn line_size(&self) -> LineSize {
// `global_line_size::expand` needs a scope, but none is available here, so a throwaway
// root scope is created for the call.
let mut temp_scope = Scope::root(false);
global_line_size::expand(&mut temp_scope, self.inputs.clone(), self.pos)
}
}
/// Write-only view that triggers the fused-on-write operations when a value is written.
///
/// The view operations are implemented by hand on the generated `FusedOutputExpand` type.
#[allow(dead_code, reason = "only used in expand")]
#[derive(CubeType)]
pub struct FusedOutput {
/// Global inputs of the kernel (cloned when the view is created).
inputs: GlobalArgs,
/// Global outputs of the kernel (cloned when the view is created).
outputs: GlobalArgs,
/// Local arguments of the kernel (cloned when the view is created).
locals: LocalArgs,
/// The argument under which written values are registered before calling `fuse_on_write`.
arg: FuseArg,
/// Configuration of the fused block.
#[cube(comptime)]
config: FuseBlockConfig,
}
#[cube]
impl FusedOutput {
/// Creates a fused output view for `arg`.
///
/// `inputs`, `outputs` and `locals` are cloned into the view.
pub fn new(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
arg: FuseArg,
#[comptime] config: FuseBlockConfig,
) -> Self {
FusedOutput {
inputs: inputs.clone(),
outputs: outputs.clone(),
locals: locals.clone(),
arg,
config,
}
}
}
// Marker implementation; the behavior lives in the expand implementation below.
impl<E: CubePrimitive> ViewOperations<Line<E>, Coords1d> for FusedOutput {}

/// Read-side view operations of a fused output.
///
/// A fused output can only be written to: every read operation panics during kernel expansion
/// with a message naming the unsupported operation. Only the shape and bounds queries are
/// implemented; both are expressed in terms of the reference layout.
impl<E: CubePrimitive> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputExpand {
    /// Not supported: panics during kernel expansion.
    #[allow(clippy::too_many_arguments)]
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _pos: ExpandElementTyped<usize>,
    ) -> <Line<E> as CubeType>::ExpandType {
        todo!("Reading from a fused output is not yet supported")
    }

    /// Not supported: panics during kernel expansion.
    #[allow(clippy::too_many_arguments)]
    fn __expand_read_checked_method(
        &self,
        _scope: &mut Scope,
        _pos: ExpandElementTyped<usize>,
    ) -> <Line<E> as CubeType>::ExpandType {
        todo!("Checked reads from a fused output are not yet supported")
    }

    /// Not supported: panics during kernel expansion.
    #[allow(clippy::too_many_arguments)]
    fn __expand_read_masked_method(
        &self,
        _scope: &mut Scope,
        _pos: ExpandElementTyped<usize>,
        _value: <Line<E> as CubeType>::ExpandType,
    ) -> <Line<E> as CubeType>::ExpandType {
        todo!("Masked reads from a fused output are not yet supported")
    }

    /// Not supported: panics during kernel expansion.
    #[allow(clippy::too_many_arguments)]
    fn __expand_read_unchecked_method(
        &self,
        _scope: &mut Scope,
        _pos: ExpandElementTyped<usize>,
    ) -> <Line<E> as CubeType>::ExpandType {
        todo!("Unchecked reads from a fused output are not yet supported")
    }

    /// Not supported: panics during kernel expansion.
    #[allow(clippy::too_many_arguments)]
    fn __expand_to_linear_slice_method(
        &self,
        _scope: &mut Scope,
        _pos: ExpandElementTyped<usize>,
        _size: ExpandElementTyped<usize>,
    ) -> SliceExpand<Line<E>, ReadOnly> {
        todo!("Linear slices of a fused output are not yet supported")
    }

    /// Always panics: a fused output is not a tensor map.
    #[allow(clippy::too_many_arguments)]
    fn __expand_tensor_map_load_method(
        &self,
        _scope: &mut Scope,
        _barrier: BarrierExpand,
        _shared_memory: SliceExpand<Line<E>, ReadWrite>,
        _pos: ExpandElementTyped<usize>,
    ) {
        panic!("Not a tensor map")
    }

    /// The shape of the 1D view is given by `ref_len`.
    #[allow(clippy::too_many_arguments)]
    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
        ref_len::expand(
            scope,
            self.inputs.clone(),
            self.outputs.clone(),
            self.locals.clone(),
            self.config.clone(),
        )
    }

    /// A position is in bounds when it is strictly lower than `ref_buffer_len`.
    #[allow(clippy::too_many_arguments)]
    fn __expand_is_in_bounds_method(
        &self,
        scope: &mut Scope,
        pos: ExpandElementTyped<usize>,
    ) -> ExpandElementTyped<bool> {
        let buffer_len = ref_buffer_len::expand(
            scope,
            self.inputs.clone(),
            self.outputs.clone(),
            self.locals.clone(),
            self.config.clone(),
        );
        lt::expand(scope, pos, buffer_len)
    }
}
// Marker implementation; the behavior lives in the expand implementation below.
impl<E: CubePrimitive> ViewOperationsMut<Line<E>, Coords1d> for FusedOutput {}
/// Write-side view operations of a fused output.
impl<E: CubePrimitive> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutputExpand {
/// Writes `value` at `pos` by running the fused-on-write operations.
///
/// The value is registered under `self.arg` in a fresh registry, and `fuse_on_write` is
/// expanded with that registry and the single-element argument list.
#[allow(clippy::too_many_arguments)]
fn __expand_write_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
value: <Line<E> as CubeType>::ExpandType,
) {
let values = Registry::<FuseArg, Line<E>>::__expand_new(scope);
let mut args = comptime![Vec::<FuseArg>::new()];
// The registry is cloned for the insertion so that `values` can still be handed to
// `fuse_on_write` below.
values
.clone()
.__expand_insert_method(scope, comptime![self.arg.clone()], value);
comptime![args.push(self.arg.clone())];
fuse_on_write::expand(
scope,
self.inputs.clone(),
self.outputs.clone(),
self.locals.clone(),
pos,
values,
args,
self.config.clone(),
);
}
/// Performs the write only when `pos` is in bounds.
#[allow(clippy::too_many_arguments)]
fn __expand_write_checked_method(
&self,
scope: &mut Scope,
pos: ExpandElementTyped<usize>,
value: <Line<E> as CubeType>::ExpandType,
) {
let in_bounds = ViewOperationsExpand::<Line<E>, Coords1d>::__expand_is_in_bounds_method(
self,
scope,
pos.clone(),
);
if_expand(scope, in_bounds.into(), |scope| {
self.__expand_write_method(scope, pos, value);
})
}
/// Not supported: panics during kernel expansion.
#[allow(clippy::too_many_arguments)]
fn __expand_to_linear_slice_mut_method(
&self,
_scope: &mut Scope,
_pos: ExpandElementTyped<usize>,
_size: ExpandElementTyped<usize>,
) -> SliceExpand<Line<E>, ReadWrite> {
todo!("Not yet supported")
}
/// Always panics: a fused output is not a tensor map.
#[allow(clippy::too_many_arguments)]
fn __expand_tensor_map_store_method(
&self,
_scope: &mut Scope,
_shared_memory: SliceExpand<Line<E>, ReadOnly>,
_pos: ExpandElementTyped<usize>,
) {
panic!("Not a tensor map")
}
}
// Marker implementation; the behavior lives in the expand implementation below.
impl Lined for FusedOutput {}
impl LinedExpand for FusedOutputExpand {
/// The line size of a fused output is the line size of the reference layout.
fn line_size(&self) -> LineSize {
self.locals.ref_line_size
}
}

View File

@@ -0,0 +1,769 @@
use super::{
codegen::ir::{BinaryFuseArgs, FuseArg, FuseOp, FuseType, UnaryFuseArgs},
settings::FuseSettings,
trace::{FuseTrace, TraceFuser, block::QuantInput},
};
use crate::engine::codegen::ir::QuantSchemeFuse;
use burn_fusion::{FuserProperties, FuserStatus, OperationFuser};
use burn_ir::{
BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr,
TensorIr, UnaryOpIr,
};
use burn_std::{DType, Shape};
use cubecl::ir::ElemType;
/// The base operation fuser that can be used to fuse [all supported fuse operations](FuseOp).
///
/// This fuser doesn't create a ready-to-execute kernel, but rather generates a
/// [trace](FuseTrace) that can be used with a [runner](super::trace::TraceRunner).
///
/// Since this fuser supports fusing multiple blocks, you can fuse any compute-bound operations
/// with the combination of fuse-on-read and fuse-on-write strategy.
///
/// # Notes
///
/// It is responsible to translate [OperationIr] into [FuseOp] and it uses the [TraceFuser]
/// to actually fuse the [FuseOp] when possible.
#[derive(Debug, Clone)]
pub(crate) struct TraceOperationFuser {
/// The underlying trace fuser.
fuser: TryTraceFuser,
/// Settings of the block currently being fused.
pub(crate) settings: FuseSettings,
/// Output shape of the current block; handed over when the block or the trace is finished.
pub(crate) current_output_shape: Shape,
/// Whether the fuser still accepts operations.
status: FuserStatus,
/// Number of operations accepted so far; also used as the length and the score of the fuser.
pub(crate) num_ops: usize,
/// Number of view operations (swap dims and shape-changing reshapes) fused so far.
pub(crate) num_views: usize,
/// Maximum number of bindings, forwarded to the underlying trace fuser.
pub(crate) max_bindings: u32,
}
impl TraceOperationFuser {
    /// Checks if the [operation](OperationIr) can be fused with the current fuser.
    ///
    /// The operation is fused into a clone of the fuser, which leaves `self` untouched; the
    /// operation is considered fusable when the clone ends up longer than `self`.
    pub(crate) fn can_fuse(&self, op: &OperationIr) -> bool {
        let mut probe = self.clone();
        probe.fuse(op);

        probe.len() > self.len()
    }
}
impl OperationFuser<FuseTrace> for TraceOperationFuser {
    /// Tries to fuse `op` into the current trace.
    ///
    /// Nothing happens when the fuser is already closed. When the operation can't be fused, the
    /// fuser is closed; otherwise it stays open and the operation counter is incremented.
    ///
    /// A drop is only accepted once at least one operation has been fused, and it counts as an
    /// operation itself.
    fn fuse(&mut self, op: &OperationIr) {
        if matches!(self.status, FuserStatus::Closed) {
            return;
        }

        let fused = match op {
            OperationIr::Drop(tensor) => {
                let has_ops = self.num_ops != 0;
                if has_ops {
                    self.fuser.fuser.fuse_dropped(tensor);
                }
                has_ops
            }
            OperationIr::BaseFloat(ops)
            | OperationIr::BaseInt(ops)
            | OperationIr::BaseBool(ops) => self.fuse_base(ops),
            OperationIr::Float(_dtype, ops) => self.fuse_float(ops),
            OperationIr::NumericFloat(_dtype, ops) => self.fuse_numeric(ops),
            OperationIr::NumericInt(_dtype, ops) => self.fuse_numeric(ops),
            _ => false,
        };

        if fused {
            self.status = FuserStatus::Open;
            self.num_ops += 1;
        } else {
            self.status = FuserStatus::Closed;
        }
    }

    /// Finishes the trace, handing over the output shape of the current block.
    fn finish(&mut self) -> FuseTrace {
        let output_shape = self.current_output_shape.clone();
        self.fuser.finish(output_shape)
    }

    /// Number of operations accepted so far.
    fn len(&self) -> usize {
        self.num_ops
    }

    /// Brings the fuser back to its initial, open state, keeping the maximum number of
    /// bindings, the bool precision and the settings.
    fn reset(&mut self) {
        let bool_precision = self.fuser.fuser.bool_precision;
        self.fuser = TryTraceFuser::new(self.max_bindings, bool_precision, self.settings);
        self.current_output_shape = Shape::new([]);
        self.status = FuserStatus::Open;
        self.num_ops = 0;
        self.num_views = 0;
    }

    /// Current status of the fuser.
    fn status(&self) -> FuserStatus {
        self.status
    }

    /// The fuser is ready as soon as one operation is fused; its score is the number of
    /// operations.
    fn properties(&self) -> FuserProperties {
        FuserProperties {
            ready: self.num_ops > 0,
            score: self.num_ops as u64,
        }
    }

    /// Boxed clone of the fuser.
    fn clone_dyn(&self) -> Box<dyn OperationFuser<FuseTrace>> {
        let cloned = self.clone();
        Box::new(cloned)
    }
}
impl TraceOperationFuser {
/// Creates a new, open fuser without any fused operation and with an empty output shape.
///
/// `max_bindings`, `bool_precision` and `settings` are forwarded to the underlying trace fuser.
pub fn new(max_bindings: u32, bool_precision: FuseType, settings: FuseSettings) -> Self {
    let fuser = TryTraceFuser::new(max_bindings, bool_precision, settings);

    Self {
        fuser,
        settings,
        current_output_shape: Shape::new([]),
        status: FuserStatus::Open,
        num_ops: 0,
        num_views: 0,
        max_bindings,
    }
}
/// Closes the fuser.
///
/// Once closed, [fuse](Self::fuse) ignores every operation until the status is set back to
/// open.
pub fn close(&mut self) {
self.status = FuserStatus::Closed;
}
/// Declares an input tensor argument where the kernel is responsible to load.
///
/// The call is forwarded as is to the underlying trace fuser.
///
/// # Returns
///
/// - The argument that maps to the tensor to be used during kernel expansion.
pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
self.fuser.fuser.input_unhandled(tensor)
}
/// Declares an input quantized tensor argument where the kernel is responsible to load.
///
/// The call is forwarded as is to the underlying trace fuser.
///
/// # Returns
///
/// None if it's not possible to fuse a quantized tensor. Otherwise:
///
/// - The argument that maps to the tensor values to be used during kernel expansion.
/// - The argument that maps to the tensor params to be used during kernel expansion.
pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> {
self.fuser.fuser.input_quantized_unhandled(tensor)
}
/// Declares an output tensor argument where the kernel is responsible to write values.
///
/// The output shape of the current block is replaced by the shape of `tensor` when no shape is
/// set yet, or when the sum of the dimensions of `tensor` is strictly greater than the sum of
/// the dimensions of the current shape.
///
/// # Notes
///
/// Normally you don't have to declare outputs explicitly before they are going to be
/// fused based on the operations [fused](Self::fuse).
///
/// # Returns
///
/// - The argument that maps to the tensor to be used during kernel expansion.
pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
    let current_extent: usize = self.current_output_shape.iter().sum();
    let candidate_extent: usize = tensor.shape.iter().sum();

    // The largest shape wins.
    if self.current_output_shape.is_empty() || current_extent < candidate_extent {
        self.current_output_shape = tensor.shape.clone();
    }

    self.fuser.fuser.output_unhandled(tensor)
}
/// Closes the previous block and declares a new one.
///
/// The output shape accumulated for the previous block is handed over to the underlying trace
/// fuser and reset to an empty shape; the fuser is (re)opened with the new settings.
///
/// # Arguments
///
/// - arguments: Tensors that are logical outputs of the current block and inputs of the following blocks.
/// - settings: [FuseSettings] to be used by the next block.
/// - global: Forwarded as is when each tensor is registered as a block local input.
///
/// # Returns
///
/// The [arguments](FuseArg) corresponding to the given tensors, in the same order.
pub fn next_block<const N: usize>(
    &mut self,
    arguments: [&TensorIr; N],
    settings: FuseSettings,
    global: bool,
) -> [FuseArg; N] {
    // Position of the block being closed; it must be captured before the new block is created.
    let closed_block_pos = self.fuser.fuser.num_previous_blocks();
    let closed_block_shape = core::mem::replace(&mut self.current_output_shape, Shape::new([]));

    self.fuser.fuser.next_block(closed_block_shape, settings);
    self.settings = settings;
    self.status = FuserStatus::Open;

    arguments.map(|tensor| {
        self.fuser
            .fuser
            .block_local_input(tensor, closed_block_pos, global)
    })
}
/// Tag the [tensor](TensorIr) as received from a previous block.
///
/// This will avoid reading the input again and instead use the local version when possible.
///
/// The call is forwarded to the underlying trace fuser and its result is discarded.
pub fn block_local_input(&mut self, tensor: &TensorIr, block_pos: usize, global: bool) {
self.fuser
.fuser
.block_local_input(tensor, block_pos, global);
}
/// Translates a [BaseOperationIr] into the matching [FuseOp] and tries to fuse it.
///
/// Returns `true` when the operation was fused and `false` when it is unsupported, when its
/// output isn't compatible with the current block, or when the underlying fuser refuses it.
fn fuse_base(&mut self, ops: &BaseOperationIr) -> bool {
match ops {
BaseOperationIr::Equal(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out })
}),
BaseOperationIr::EqualElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out })
}),
// A cast is fused as a plain assignment from the input to the output.
BaseOperationIr::Cast(desc) => {
self.fuse_unary_op(&desc.input, &desc.out, |input, out| {
FuseOp::Assign(UnaryFuseArgs { input, out })
})
}
// Swap dims is fused as a view on the input and counted in `num_views`.
BaseOperationIr::SwapDims(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
if self.fuser.fuse(|fuser| {
fuser.input_swap_dims(&desc.input, &desc.out, (desc.dim1, desc.dim2))?;
Some(())
}) {
self.num_views += 1;
true
} else {
false
}
}
BaseOperationIr::Reshape(desc) => {
// A reshape that keeps the shape is a plain assignment.
if desc.input.shape == desc.out.shape {
return self.fuse_unary_op(&desc.input, &desc.out, |input, out| {
FuseOp::Assign(UnaryFuseArgs { input, out })
});
}
if desc.input.shape.rank() > desc.out.shape.rank() {
// Not yet supported.
return false;
}
if !self.output_is_compatible(&desc.out) {
return false;
}
// Otherwise the reshape is fused as a view on the input and counted in `num_views`.
if self.fuser.fuse(|fuser| {
fuser.input_reshaped(&desc.input, &desc.out)?;
Some(())
}) {
self.num_views += 1;
true
} else {
false
}
}
// Ones is an assignment of the literal `1` in the precision of the output.
BaseOperationIr::Ones(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
let elem: ElemType = desc.out.dtype.into();
let precision = elem.into();
let input = FuseArg::Literal(1, precision);
self.fuser.fuse(|fuser| {
let out = fuser.output(&desc.out)?;
fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));
Some(())
})
}
// Zeros is an assignment of the literal `0` in the precision of the output.
BaseOperationIr::Zeros(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
let elem: ElemType = desc.out.dtype.into();
let precision = elem.into();
let input = FuseArg::Literal(0, precision);
self.fuser.fuse(|fuser| {
let out = fuser.output(&desc.out)?;
fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));
Some(())
})
}
// Gather registers both the tensor and the indices as indexed inputs.
BaseOperationIr::Gather(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let input = build.input_indexed(&desc.tensor)?;
let indices = build.input_indexed(&desc.indices)?;
let output = build.output(&desc.out)?;
build.fuse_operation(FuseOp::Gather {
input,
indices,
output,
dim: desc.dim,
});
Some(())
})
}
// Select registers both the tensor and the indices as indexed inputs.
BaseOperationIr::Select(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let input = build.input_indexed(&desc.tensor)?;
let indices = build.input_indexed(&desc.indices)?;
let output = build.output(&desc.out)?;
build.fuse_operation(FuseOp::Select {
input,
indices,
output,
dim: desc.dim,
});
Some(())
})
}
// Mask where: the mask is the condition, the value tensor is `lhs` and the original tensor
// is `rhs` of the conditional assignment.
BaseOperationIr::MaskWhere(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let cond = build.input(&desc.mask)?;
let rhs = build.input(&desc.tensor)?;
let lhs = build.input(&desc.value)?;
let out = build.output(&desc.out)?;
build.fuse_operation(FuseOp::ConditionalAssign {
cond,
lhs,
rhs,
out,
});
Some(())
})
}
// Mask fill: same as mask where, but `lhs` is a scalar in the dtype of the output.
BaseOperationIr::MaskFill(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let cond = build.input(&desc.mask)?;
let lhs = build.scalar(&desc.value, desc.out.dtype);
let rhs = build.input(&desc.tensor)?;
let out = build.output(&desc.out)?;
build.fuse_operation(FuseOp::ConditionalAssign {
cond,
lhs,
rhs,
out,
});
Some(())
})
}
// Every other base operation is not fusable.
_ => false,
}
}
fn fuse_float(&mut self, ops: &FloatOperationIr) -> bool {
match ops {
FloatOperationIr::Exp(desc) => {
self.fuse_unary_ops(desc, |input, out| FuseOp::Exp(UnaryFuseArgs { input, out }))
}
FloatOperationIr::Log(desc) => {
self.fuse_unary_ops(desc, |input, out| FuseOp::Log(UnaryFuseArgs { input, out }))
}
FloatOperationIr::Log1p(desc) => self.fuse_unary_ops(desc, |input, out| {
FuseOp::Log1p(UnaryFuseArgs { input, out })
}),
FloatOperationIr::Cos(desc) => {
self.fuse_unary_ops(desc, |input, out| FuseOp::Cos(UnaryFuseArgs { input, out }))
}
FloatOperationIr::Sin(desc) => {
self.fuse_unary_ops(desc, |input, out| FuseOp::Sin(UnaryFuseArgs { input, out }))
}
FloatOperationIr::PowfScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out })
}),
FloatOperationIr::Tanh(desc) => self.fuse_unary_ops(desc, |input, out| {
FuseOp::Tanh(UnaryFuseArgs { input, out })
}),
FloatOperationIr::Erf(desc) => {
self.fuse_unary_ops(desc, |input, out| FuseOp::Erf(UnaryFuseArgs { input, out }))
}
FloatOperationIr::Sqrt(desc) => self.fuse_unary_ops(desc, |input, out| {
FuseOp::Sqrt(UnaryFuseArgs { input, out })
}),
FloatOperationIr::Recip(desc) => self.fuse_unary_ops(desc, |input, out| {
FuseOp::Recip(UnaryFuseArgs { input, out })
}),
FloatOperationIr::Dequantize(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let qinput = build.input_quantized(&desc.input)?;
let out = build.output(&desc.out)?;
match qinput {
QuantInput::AlreadyDequantized { local } => {
build.fuse_operation(FuseOp::Assign(UnaryFuseArgs {
input: local,
out,
}));
}
QuantInput::Quantized { values, params } => {
build.fuse_operation(FuseOp::Dequantize {
values,
params,
output: out,
scheme: match desc.input.dtype {
DType::QFloat(scheme) => QuantSchemeFuse { scheme },
_ => unreachable!("Should be a quant tensor."),
},
});
}
}
Some(())
})
}
_ => false,
}
}
fn fuse_numeric(&mut self, op: &NumericOperationIr) -> bool {
match op {
NumericOperationIr::Add(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Add(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::AddScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Add(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Sub(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::SubScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Mul(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::MulScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Div(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Div(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::DivScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Div(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Abs(desc) => {
self.fuse_unary_ops(desc, |input, out| FuseOp::Abs(UnaryFuseArgs { input, out }))
}
NumericOperationIr::Lower(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::LowerElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Greater(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::GreaterElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::LowerEqual(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::LowerEqualElem(desc) => self
.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::GreaterEqual(desc) => self
.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::GreaterEqualElem(desc) => self
.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Full(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let input = build.scalar(&desc.value, desc.out.dtype);
let out = build.output(&desc.out)?;
build.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));
Some(())
})
}
NumericOperationIr::Rem(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::RemScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Powf(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out })
}),
NumericOperationIr::Clamp(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
self.fuser.fuse(|build| {
let input = build.input(&desc.tensor)?;
let min = build.scalar(&desc.min, desc.out.dtype);
let max = build.scalar(&desc.max, desc.out.dtype);
let out = build.output(&desc.out)?;
build.fuse_operation(FuseOp::Clamp {
input,
min,
max,
out,
});
Some(())
})
}
_ => false,
}
}
/// Fuses a tensor/tensor operation.
///
/// `func` receives the fused `lhs`, `rhs` and `out` arguments (in that order)
/// and returns the operation to register. Returns `false` when the output shape
/// is incompatible or when the fuser rejects the operation.
fn fuse_binary_ops<Func>(&mut self, desc: &BinaryOpIr, func: Func) -> bool
where
    Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp,
{
    self.output_is_compatible(&desc.out)
        && self.fuser.fuse(|fuser| {
            let left = fuser.input(&desc.lhs)?;
            let right = fuser.input(&desc.rhs)?;
            let result = fuser.output(&desc.out)?;
            fuser.fuse_operation(func(left, right, result));
            Some(())
        })
}
fn fuse_unary_ops<Func>(&mut self, desc: &UnaryOpIr, func: Func) -> bool
where
Func: Fn(FuseArg, FuseArg) -> FuseOp,
{
self.fuse_unary_op(&desc.input, &desc.out, func)
}
/// Fuses a unary operation reading `input` and writing `out`.
///
/// `func` receives the fused input and output arguments and returns the
/// operation to register. Returns `false` when the output shape is incompatible
/// or when the fuser rejects the operation.
fn fuse_unary_op<Func>(&mut self, input: &TensorIr, out: &TensorIr, func: Func) -> bool
where
    Func: Fn(FuseArg, FuseArg) -> FuseOp,
{
    if self.output_is_compatible(out) {
        self.fuser.fuse(|fuser| {
            let source = fuser.input(input)?;
            let destination = fuser.output(out)?;
            fuser.fuse_operation(func(source, destination));
            Some(())
        })
    } else {
        false
    }
}
/// Fuses a tensor/scalar operation.
///
/// The scalar on the right-hand side is registered with the dtype of the
/// left-hand side tensor. `func` receives `lhs`, `rhs` and `out` (in that order)
/// and returns the operation to register.
fn fuse_scalar_ops<Func>(&mut self, desc: &ScalarOpIr, func: Func) -> bool
where
    Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp,
{
    if !self.output_is_compatible(&desc.out) {
        return false;
    }

    self.fuser.fuse(|fuser| {
        let tensor = fuser.input(&desc.lhs)?;
        let scalar = fuser.scalar(&desc.rhs, desc.lhs.dtype);
        let result = fuser.output(&desc.out)?;
        fuser.fuse_operation(func(tensor, scalar, result));
        Some(())
    })
}
/// Checks whether `out` can be written by the kernel currently being built.
///
/// The first output seen becomes the reference shape. Later outputs must have
/// the same rank and, dim by dim, either match the reference or be accepted by
/// the broadcast settings. The reference shape may be replaced by `out.shape`
/// when `settings.output_shape_updates` allows it.
fn output_is_compatible(&mut self, out: &TensorIr) -> bool {
    // No reference yet: adopt this output's shape.
    if self.current_output_shape.is_empty() {
        self.current_output_shape.clone_from(&out.shape);
        return true;
    }

    // Rank should be equal.
    let rank = self.current_output_shape.len();
    if out.shape.num_dims() != rank {
        return false;
    }

    let mut candidate = self.current_output_shape.clone();
    let mut needs_update = false;

    #[allow(clippy::needless_range_loop)]
    for axis in 0..rank {
        let reference_dim = self.current_output_shape[axis];
        let incoming_dim = out.shape[axis];

        // NOTE(review): a dim equal to 0 is handled as the broadcast marker
        // here — presumably shapes are in relative form; confirm.
        match (reference_dim, incoming_dim) {
            (reference, incoming) if reference == incoming => {}
            // Broadcast not enabled.
            _ if !self.settings.broadcast => return false,
            // Broadcasted on new dim.
            (_, 0) => {}
            // Broadcasted on curr dim - update reference output shape.
            (0, _) if self.settings.output_shape_updates => {
                needs_update = true;
                candidate[axis] = incoming_dim;
            }
            _ => return false,
        }
    }

    if !needs_update {
        return true;
    }

    // For now forced to have exact shape.
    if candidate != out.shape {
        return false;
    }
    self.current_output_shape.clone_from_slice(&out.shape);
    true
}
}
#[derive(Debug, Clone)]
/// Builder wrapper to limit the number of bindings in generated kernels.
///
/// Operations are first applied to a clone of the wrapped fuser and only
/// committed when the estimated number of bindings stays within `max_bindings`.
struct TryTraceFuser {
/// The wrapped fuser holding the committed state of the trace.
fuser: TraceFuser,
/// Upper bound compared against `TraceFuser::estimate_bindings()`.
max_bindings: u32,
/// Upper bound compared against `TraceFuser::num_ops_fused()`.
max_ops: u32,
/// Set on the first call to `fuse`; that first call bypasses the bindings check.
added_ops: bool,
}
impl TryTraceFuser {
    /// Creates a wrapper around a fresh [`TraceFuser`] with the given bindings limit.
    fn new(max_bindings: u32, bool_precision: FuseType, settings: FuseSettings) -> Self {
        let fuser = TraceFuser::new(bool_precision, settings);

        Self {
            fuser,
            max_bindings,
            // A good default, avoid errors with for loops over only memory
            // bound operations.
            max_ops: 64,
            added_ops: false,
        }
    }

    /// Tries to register operations through `add_ops`.
    ///
    /// Returns `true` when the operations were accepted. Except for the very
    /// first call, `add_ops` runs on a clone of the fuser, which only replaces
    /// the real one when the estimated bindings stay within the limit.
    fn fuse(&mut self, add_ops: impl FnOnce(&mut TraceFuser) -> Option<()>) -> bool {
        if self.fuser.num_ops_fused() > self.max_ops {
            return false;
        }

        // Always allow the first operation to be added.
        if !self.added_ops {
            self.added_ops = true;
            return add_ops(&mut self.fuser).is_some();
        }

        let mut candidate = self.fuser.clone();
        let accepted = add_ops(&mut candidate).is_some()
            && candidate.estimate_bindings() <= self.max_bindings;

        if accepted {
            self.fuser = candidate;
        }
        accepted
    }

    /// Finalizes the wrapped fuser into a [`FuseTrace`] for the given shape.
    fn finish(&mut self, shape: Shape) -> FuseTrace {
        self.fuser.finish(shape)
    }
}

View File

@@ -0,0 +1,99 @@
use crate::{
CubeFusionHandle,
engine::{
launch::{
HandleInput, HandleOutput, LaunchPlan, executor::LaunchPlanExecutor,
input::InputPlanner, output::OutputPlanner, runner::TraceRunner,
vectorization::VectorizationPlanner,
},
trace::{FuseTrace, TraceError, TuneOutput},
},
};
use burn_fusion::stream::Context;
use cubecl::{CubeElement, Runtime, client::ComputeClient};
use std::marker::PhantomData;
/// The launcher is responsible to launch a fused kernel using the [TraceRunner] and a [FuseTrace].
///
/// It only borrows the trace and the runner; the launch plan is created inside `launch`.
pub struct FuseTraceLauncher<'a, R: Runtime, Runner: TraceRunner<R>> {
// The trace whose resources and blocks are planned and executed.
trace: &'a FuseTrace,
// The runner handed to the vectorization planner and to the executor.
runner: &'a Runner,
// Marker tying the launcher to the runtime `R` without storing a value of it.
_runtime: PhantomData<R>,
}
impl<'a, R: Runtime, Runner: TraceRunner<R>> FuseTraceLauncher<'a, R, Runner> {
/// Creates a new launcher.
pub fn new(trace: &'a FuseTrace, runner: &'a Runner) -> Self {
Self {
trace,
runner,
_runtime: PhantomData,
}
}
/// Launches the fuse kernel on the given device modifying the context.
pub fn launch<BT: CubeElement>(
&self,
client: &ComputeClient<R>,
device: &R::Device,
context: &mut Context<'_, CubeFusionHandle<R>>,
) -> Result<TuneOutput<R>, TraceError<Runner::Error>> {
let mut plan = LaunchPlan::new(&self.trace.blocks);
InputPlanner::new(&self.trace.resources, &self.trace.blocks).run(context, &mut plan);
OutputPlanner::new(&self.trace.resources, &self.trace.blocks)
.run::<BT>(client, device, context, &mut plan);
VectorizationPlanner::new(&self.trace.resources, &self.trace.blocks).run(
client,
self.runner,
context,
&mut plan,
);
match LaunchPlanExecutor::new(&self.trace.resources, &self.trace.blocks).execute::<_, BT>(
client,
self.runner,
context,
plan,
) {
Err(err) => {
self.rollback(context, err.handles_input, err.handles_output);
Err(err.error)
}
Ok(val) => Ok(val),
}
}
fn rollback(
&self,
context: &mut Context<'_, CubeFusionHandle<R>>,
handle_inputs: Vec<HandleInput<R>>,
handle_outputs: Vec<HandleOutput<R>>,
) {
for input in handle_inputs {
match input {
HandleInput::Normal(input) => {
context
.handles
.register_handle(input.global_ir.id, input.handle_rollback());
}
HandleInput::QuantValues(input) => {
context
.handles
.register_handle(input.global_ir.id, input.handle);
}
HandleInput::QuantParams(_) => {
// The scales are part of the quant data handle.
}
};
}
for output in handle_outputs {
if let HandleOutput::Owned {
global_id, handle, ..
} = output
{
context.handles.register_handle(global_id, handle);
}
}
}
}

View File

@@ -0,0 +1,292 @@
use super::{HandleInput, HandleOutput, LaunchPlan, ReferenceSelection};
use crate::engine::launch::runner::TraceRunner;
use crate::engine::trace::{FuseResources, TensorView, TraceError, TuneOutput, block::FuseBlock};
use crate::{
CubeFusionHandle, elem_dtype,
engine::{
codegen::ir::{
FuseBlockConfig, FuseOp, FuseType, GlobalArgsLaunch, RefLayout, VirtualLayout,
},
codegen::tensor::GlobalTensorArg,
},
};
use burn_fusion::stream::{Context, ScalarId};
use burn_ir::ScalarIr;
use burn_std::DType;
use cubecl::{
CubeElement, Runtime,
client::ComputeClient,
ir::AddressType,
prelude::{InputScalar, ScalarArg, TensorArg},
};
use std::marker::PhantomData;
/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context).
pub struct LaunchPlanExecutor<'a, R: Runtime> {
// Trace resources; the executor reads its `scalars` and `views`.
resources: &'a FuseResources,
// Fused blocks, zipped with the block plans to build one config per block.
blocks: &'a Vec<FuseBlock>,
// Marker tying the executor to the runtime `R`.
_r: PhantomData<R>,
}
/// Error returned by [`LaunchPlanExecutor::execute`].
///
/// Besides the error itself, it gives the planned input and output handles back
/// to the caller.
#[derive(new, Debug)]
pub struct ExecutionError<R: Runtime, Runner: TraceRunner<R>> {
/// The error that made the execution fail.
pub error: TraceError<Runner::Error>,
/// The input handles that were part of the failed plan.
pub handles_input: Vec<HandleInput<R>>,
/// The output handles that were part of the failed plan.
pub handles_output: Vec<HandleOutput<R>>,
}
impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
Self {
resources,
blocks,
_r: PhantomData,
}
}
pub fn execute<Runner: TraceRunner<R>, BT: CubeElement>(
self,
client: &ComputeClient<R>,
runner: &Runner,
context: &mut Context<'_, CubeFusionHandle<R>>,
plan: LaunchPlan<'a, R>,
) -> Result<TuneOutput<R>, ExecutionError<R, Runner>> {
let mut num_writes = 0;
for b in plan.blocks.iter() {
for writes in b.writes.values() {
num_writes += writes.len();
}
}
#[cfg(feature = "autotune-checks")]
let mut tune_output = TuneOutput::Checked {
handles: std::collections::HashMap::new(),
};
#[cfg(not(feature = "autotune-checks"))]
let mut tune_output = TuneOutput::UnChecked(PhantomData);
if num_writes == 0 {
// Nothing to write, can skip execution.
return Ok(tune_output);
}
let mut inputs = GlobalArgsLaunch::default();
let mut outputs = GlobalArgsLaunch::default();
register_inputs(&plan.handle_inputs, &mut inputs);
register_scalars(
self.resources.scalars.iter(),
self.resources.views.iter(),
context,
&mut inputs,
);
register_outputs::<BT, R>(&plan.handle_outputs, &mut outputs, &mut tune_output);
for layout in plan.runtime_layouts {
for s in layout.shape.iter() {
inputs.runtime_layouts.push(ScalarArg::new(*s));
}
for s in layout.strides.iter() {
inputs.runtime_layouts.push(ScalarArg::new(*s));
}
}
let mut configs = Vec::with_capacity(plan.blocks.len());
for (block_plan, block) in plan.blocks.into_iter().zip(self.blocks) {
let reference = match block_plan.reference {
ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout),
ReferenceSelection::VirtualShape { original, .. } => {
RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width))
}
ReferenceSelection::SwapDims { original, dims } => {
RefLayout::Virtual(VirtualLayout::SwapDims(original, dims))
}
ReferenceSelection::Reshaped { reshape_pos } => {
RefLayout::Virtual(VirtualLayout::Reshaped {
reshape_pos,
line_size: block_plan.width,
})
}
ReferenceSelection::Runtime { pos } => {
RefLayout::Virtual(VirtualLayout::Runtime { pos })
}
ReferenceSelection::Searching => {
return Err(ExecutionError::new(
TraceError::ReferenceNotFound,
plan.handle_inputs,
plan.handle_outputs,
));
}
};
let mut ops = Vec::<FuseOp>::new();
for read_ops in block_plan.reads.into_values() {
for op in read_ops {
ops.push(op);
}
}
for op in block.ops.iter() {
ops.push(op.clone());
}
for opsw in block_plan.writes.into_values() {
for op in opsw {
ops.push(op);
}
}
let config = FuseBlockConfig {
rank: plan.rank,
ref_layout: reference,
ops,
width: block_plan.width,
};
configs.push(config);
}
Runner::run(runner, client, inputs, outputs, &configs).map_err(|err| {
ExecutionError::new(
TraceError::RunnerError(err),
plan.handle_inputs,
plan.handle_outputs,
)
})?;
Ok(tune_output)
}
}
/// Pushes one tensor argument per planned input handle, in plan order.
///
/// Quantized values and quantization parameters are never flagged as
/// broadcasted; parameters are always registered with a line size of 1.
fn register_inputs<'h, R: Runtime>(
    handle_inputs: &'h [HandleInput<R>],
    inputs: &mut GlobalArgsLaunch<'h, R>,
) {
    for handle_input in handle_inputs {
        let tensor = match handle_input {
            HandleInput::Normal(normal) => GlobalTensorArg::new(
                normal
                    .handle
                    .as_tensor_arg(&normal.global_ir.shape, normal.line_size),
                normal.precision.into_elem(),
                normal.broadcated,
                normal.handle.required_address_type(),
            ),
            HandleInput::QuantValues(values) => GlobalTensorArg::new(
                values
                    .handle
                    .as_tensor_arg(&values.global_ir.shape, values.line_size),
                values.precision.into_elem(),
                false,
                values.handle.required_address_type(),
            ),
            HandleInput::QuantParams(params) => GlobalTensorArg::new(
                params.handle.as_tensor_arg(&params.shape, 1),
                params.precision.into_elem(),
                false,
                params.handle.required_address_type(),
            ),
        };

        inputs.tensors.push(tensor);
    }
}
fn register_outputs<'s, BT: CubeElement, R: Runtime>(
handle_outputs: &'s [HandleOutput<R>],
outputs: &mut GlobalArgsLaunch<'s, R>,
#[allow(unused_variables)] tune_output: &mut TuneOutput<R>,
) {
for item in handle_outputs.iter() {
match item {
HandleOutput::Alias {
input_pos,
precision,
#[cfg(feature = "autotune-checks")]
debug_info,
} => {
outputs.tensors.push(GlobalTensorArg::new(
TensorArg::alias(*input_pos),
precision.into_elem(),
false,
AddressType::default(),
));
#[cfg(feature = "autotune-checks")]
if let TuneOutput::Checked { handles, .. } = tune_output {
handles.insert(
debug_info.relative_id,
(debug_info.global_shape.clone(), debug_info.handle.clone()),
);
}
}
HandleOutput::Owned {
precision,
handle,
global_shape,
vectorization: line_size,
#[cfg(feature = "autotune-checks")]
relative_id,
..
} => {
let arg = handle.as_tensor_arg(global_shape, *line_size);
let elem = match precision {
FuseType::Bool => match elem_dtype::<BT>() {
DType::U32 => FuseType::U32.into_elem(),
DType::U8 => FuseType::U8.into_elem(),
_ => todo!(),
},
_ => precision.into_elem(),
};
#[cfg(feature = "autotune-checks")]
if let TuneOutput::Checked { handles, .. } = tune_output {
handles.insert(*relative_id, (global_shape.clone(), handle.clone()));
}
outputs.tensors.push(GlobalTensorArg::new(
arg,
elem,
false,
handle.required_address_type(),
));
}
}
}
}
/// Registers the scalar arguments and the reshape dimensions of the launch.
///
/// Each `(precision, id)` pair is resolved through `context.scalars`; for every
/// reshape view, the global shape of the reshaped tensor is pushed dim by dim.
///
/// # Panics
///
/// Panics when a scalar id is not present in the context, or when a reshaped
/// tensor is not present in `context.tensors`.
fn register_scalars<'h, R: Runtime>(
    scalars: impl Iterator<Item = &'h (FuseType, u64)>,
    views: impl DoubleEndedIterator<Item = &'h TensorView>,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    inputs: &mut GlobalArgsLaunch<'h, R>,
) {
    for (precision, id) in scalars {
        let dtype = precision.into_type();

        let Some(scalar) = context.scalars.get(&ScalarId { value: *id }) else {
            panic!("Scalar ID not found")
        };

        let arg = match scalar {
            ScalarIr::Float(value) => InputScalar::new(*value, dtype),
            ScalarIr::Int(value) => InputScalar::new(*value, dtype),
            ScalarIr::UInt(value) => InputScalar::new(*value, dtype),
            ScalarIr::Bool(value) => InputScalar::new(*value as u8, dtype),
        };
        inputs.scalars.push(arg);
    }

    for view in views {
        // Only reshape views contribute runtime dimensions.
        let TensorView::Reshape { reshaped, .. } = view else {
            continue;
        };

        let tensor_global = context.tensors.get(reshaped).unwrap();
        for dim in tensor_global.shape.iter() {
            inputs.reshapes.push(ScalarArg::new(*dim));
        }
    }
}

View File

@@ -0,0 +1,247 @@
use super::{BlockPlan, HandleInput, InputReference};
use super::{LaunchPlan, NormalHandleInput, PotentialInplace};
use crate::CubeFusionHandle;
use crate::engine::launch::{QuantParamsHandleInput, QuantValuesHandleInput};
use crate::engine::trace::block::FuseBlock;
use crate::engine::trace::{FuseResources, RegisterTensor, TensorView};
use burn_fusion::stream::Context;
use burn_ir::{TensorIr, TensorStatus};
use burn_std::quantization::params_shape;
use cubecl::Runtime;
use std::marker::PhantomData;
/// Fetch and register [input handles](HandleInput). Also identifies potential inputs that
/// can be used inplace and/or as the [reference layout](super::super::ir::RefLayout).
pub struct InputPlanner<'a, R: Runtime> {
// Trace resources; the planner reads `inputs`, `inputs_unhandled` and `views`.
resources: &'a FuseResources,
// Fused blocks, paired by index with the block plans of the launch plan.
blocks: &'a Vec<FuseBlock>,
// Marker tying the planner to the runtime `R`.
_r: PhantomData<R>,
}
impl<'a, R: Runtime> InputPlanner<'a, R> {
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
Self {
resources,
blocks,
_r: PhantomData,
}
}
pub fn run(self, context: &mut Context<'_, CubeFusionHandle<R>>, plan: &mut LaunchPlan<'a, R>) {
for (pos, input) in self.resources.inputs.iter().enumerate() {
match input {
RegisterTensor::Normal(tensor_relative, precision) => {
let mut tensor_global =
context.tensors.get(&tensor_relative.id).unwrap().clone();
let handle = context
.handles
.get_handle(&tensor_global.id, &TensorStatus::ReadOnly);
if let TensorStatus::ReadWrite = tensor_relative.status {
plan.cleared.push(tensor_global.id);
}
let mut new_strides = handle.strides.clone();
self.analyze(plan, pos, tensor_relative, &handle);
if tensor_global.shape.rank() < plan.rank {
let num_elem: usize = tensor_global.shape.iter().product();
for _ in 0..(plan.rank - tensor_global.shape.rank()) {
tensor_global.shape.insert(0, 1);
new_strides.insert(0, num_elem);
}
}
plan.handle_inputs
.push(HandleInput::Normal(NormalHandleInput::new(
tensor_global,
tensor_relative,
*precision,
handle,
new_strides,
)));
}
RegisterTensor::QuantValues(tensor_relative) => {
let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone();
let handle = context
.handles
.get_handle(&tensor_global.id, &TensorStatus::ReadOnly);
let scheme = match tensor_relative.dtype {
burn_std::DType::QFloat(scheme) => scheme,
_ => unreachable!("Can't have quant data without QFloat"),
};
let params = handle.params(scheme).unwrap();
let precision = tensor_relative.dtype.into();
let precision_scales = params.dtype.into();
let global_shape = tensor_global.shape.clone();
let shape_params = params_shape(&global_shape, scheme.level);
plan.handle_inputs
.push(HandleInput::QuantValues(QuantValuesHandleInput {
relative_id: tensor_relative.id,
global_ir: tensor_global,
precision,
handle,
line_size: 1,
}));
plan.handle_inputs
.push(HandleInput::QuantParams(QuantParamsHandleInput {
precision: precision_scales,
handle: params,
shape: shape_params,
}));
}
RegisterTensor::QuantParams(_) => {
// It is registered at the same time as quant data.
// The order is important and the index in the vector as well, so that's why we
// have QuantParams.
}
}
}
}
fn analyze(
&self,
plan: &mut LaunchPlan<'a, R>,
pos: usize,
tensor_relative: &'a TensorIr,
handle: &CubeFusionHandle<R>,
) {
if !self
.resources
.inputs_unhandled
.contains(&tensor_relative.id)
{
let mut is_a_view = false;
// For each view we try to see if it's not possible to set it as a reference input.
for view in self.resources.views.iter() {
for (block_plan, block) in plan.blocks.iter_mut().zip(self.blocks) {
is_a_view = is_a_view
|| Self::analyze_view(pos, tensor_relative, block, block_plan, view);
}
}
if !is_a_view {
self.analyze_normal(plan, pos, tensor_relative, handle);
}
}
}
/// Analyzes if the given tensor can be used inplace in one of the block.
fn analyze_normal(
&self,
plan: &mut LaunchPlan<'a, R>,
pos: usize,
tensor_relative: &'a TensorIr,
handle: &CubeFusionHandle<R>,
) {
enum BlockInplaceSelection {
Notinit,
/// The block reads the input, and therefore can use it for inplace.
Selected(usize),
/// The same input is used in multiple blocks.
Unavailable,
}
let mut block_inplace_selection = BlockInplaceSelection::Notinit;
for (idx, block) in plan.blocks.iter().enumerate() {
if block.reads.contains_key(&tensor_relative.id) {
match block_inplace_selection {
BlockInplaceSelection::Notinit => {
block_inplace_selection = BlockInplaceSelection::Selected(idx);
}
BlockInplaceSelection::Selected(_) => {
block_inplace_selection = BlockInplaceSelection::Unavailable;
}
BlockInplaceSelection::Unavailable => {}
}
}
}
if let BlockInplaceSelection::Selected(idx) = block_inplace_selection {
if self.blocks[idx].shape_ref != tensor_relative.shape {
return;
}
let block_plan = &mut plan.blocks[idx];
if tensor_relative.status == TensorStatus::ReadWrite {
if self.blocks[idx].settings.inplace && handle.handle.can_mut() {
block_plan.potential_inplaces.push(PotentialInplace {
input_pos: pos,
tensor_relative,
strides: handle.strides.clone(),
});
}
// Inplace tensors are normally really good as the reference layout, since
// it's normally better to be based on writes rather than on reads.
block_plan.potential_reference_input =
Some(InputReference::Normal { input_pos: pos });
} else {
block_plan.potential_reference_input =
Some(InputReference::Normal { input_pos: pos });
}
}
}
/// Analyzes if the given tensor is also the view provided, and check if it can be used as the reference layout
/// for the given block.
fn analyze_view(
pos: usize,
tensor_relative: &'a TensorIr,
block: &FuseBlock,
block_plan: &mut BlockPlan<'a>,
view: &TensorView,
) -> bool {
match view {
TensorView::Reshape {
reshaped,
original,
reshape_pos,
shape_relative,
} => {
if original == &tensor_relative.id || reshaped == &tensor_relative.id {
if block_plan.potential_reference_input.is_none()
&& shape_relative == &block.shape_ref
{
block_plan.potential_reference_input = Some(InputReference::Reshaped {
reshape_pos: *reshape_pos,
});
}
return true;
}
}
TensorView::SwapDims {
swapped,
original,
dims,
..
} => {
if swapped == &tensor_relative.id {
return true;
}
if original == &tensor_relative.id {
let shape = tensor_relative
.shape
.clone()
.swapped(dims.0, dims.1)
.unwrap();
if block_plan.potential_reference_input.is_none() && shape == block.shape_ref {
block_plan.potential_reference_input = Some(InputReference::SwapDims {
original_pos: pos,
dims: *dims,
});
}
return true;
}
}
};
false
}
}

View File

@@ -0,0 +1,11 @@
// Crate-visible submodules.
// NOTE(review): presumably one module per launch stage (`LaunchPlanExecutor`,
// `InputPlanner`, `OutputPlanner`, `TraceRunner`, `VectorizationPlanner`) — confirm.
pub(crate) mod executor;
pub(crate) mod input;
pub(crate) mod output;
pub(crate) mod runner;
pub(crate) mod vectorization;
// `plan` is crate-visible and its public items are also re-exported from this module.
pub(crate) mod plan;
pub use plan::*;
// `base` stays private; only its public items are re-exported.
mod base;
pub use base::*;

View File

@@ -0,0 +1,696 @@
use super::{
super::codegen::ir::FuseType, BlockPlan, HandleOutput, InputReference, LaunchPlan,
NormalHandleInput, ReferenceSelection,
};
use crate::{
CubeFusionHandle, elem_dtype,
engine::{
codegen::ir::{FuseArg, FuseOp, LayoutInfo},
launch::HandleInput,
settings::RefLayoutSetting,
trace::{FuseResources, RegisterTensor, RuntimeLayout, TensorView, block::FuseBlock},
},
strides_dyn_rank,
};
use burn_fusion::stream::Context;
use burn_ir::{TensorId, TensorIr};
use burn_std::{DType, Shape};
use burn_std::{
Strides,
tensor::{ReshapeAction, contiguous_strides, is_contiguous, reshape_action},
};
use cubecl::{CubeElement, Runtime, client::ComputeClient, ir::StorageType};
/// Create or reuse handles for the outputs.
///
/// It is also responsible to select the reference tensor.
pub struct OutputPlanner<'a, R: Runtime> {
// Trace resources; the planner reads `outputs`, `views` and `dropped`.
resources: &'a FuseResources,
// Normal outputs, sorted in descending order of the sum of their relative shape.
outputs_sorted: Vec<OutputSorted<'a>>,
// One slot per entry of `resources.outputs`, created as `None`; every slot is
// unwrapped at the end of `run`.
handles: Vec<Option<HandleOutput<R>>>,
// Same layout as `handles`, holding the global tensor of each output.
globals: Vec<Option<TensorIr>>,
// Fused blocks, paired by index with the block plans of the launch plan.
blocks: &'a Vec<FuseBlock>,
}
/// A normal output together with its original position in `resources.outputs`.
#[derive(Debug)]
struct OutputSorted<'a> {
// Index of the output in `resources.outputs` before sorting.
pos_original: usize,
// Precision registered alongside the output tensor.
precision: FuseType,
// The output tensor in its relative form.
tensor_relative: &'a TensorIr,
}
/// How the handle of an output is obtained.
#[derive(Debug)]
enum OutputKind {
// No matching view or inplace candidate: handled by `normal_output`.
Normal,
// The output matches one of the block's potential inplace inputs.
Inplace {
/// The position in the potential inplace vector
input_pos: usize,
},
// The output is the result of a view (reshape or swap dims).
Transform(TensorView),
}
impl<'a, R: Runtime> OutputPlanner<'a, R> {
/// Creates the planner.
///
/// Only normal outputs are kept for planning; they are ordered by descending
/// sum of their relative shape. One empty slot per registered output is
/// prepared for the handles and the global tensors.
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
    let mut outputs_sorted: Vec<_> = resources
        .outputs
        .iter()
        .enumerate()
        .filter_map(|(pos, entry)| match entry {
            RegisterTensor::Normal(tensor, precision) => Some(OutputSorted {
                pos_original: pos,
                precision: *precision,
                tensor_relative: tensor,
            }),
            RegisterTensor::QuantValues(_) | RegisterTensor::QuantParams(_) => None,
        })
        .collect();

    // Stable sort, largest sum first.
    outputs_sorted.sort_by_key(|output| {
        core::cmp::Reverse(output.tensor_relative.shape.iter().sum::<usize>())
    });

    let num_outputs = resources.outputs.len();
    let handles: Vec<_> = core::iter::repeat_with(|| None).take(num_outputs).collect();
    let globals: Vec<_> = core::iter::repeat_with(|| None).take(num_outputs).collect();

    Self {
        resources,
        outputs_sorted,
        handles,
        globals,
        blocks,
    }
}
/// Plans every output, then settles the reference layout of each block.
///
/// Steps, in order:
/// 1. each sorted output is dispatched according to its [`OutputKind`];
/// 2. the planned handles and global tensors are moved into the plan;
/// 3. blocks without a reference get one from another block, from an input,
///    or from a runtime layout built from the block's reference shape;
/// 4. handles of dropped and cleared tensors are removed from the context.
pub fn run<BT: CubeElement>(
    mut self,
    client: &ComputeClient<R>,
    device: &R::Device,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    plan: &mut LaunchPlan<'a, R>,
) {
    // So that we can borrow self during the iteration.
    let sorted_outputs = core::mem::take(&mut self.outputs_sorted);

    for output in sorted_outputs {
        let tensor_global = context
            .tensors
            .get(&output.tensor_relative.id)
            .unwrap()
            .clone();
        let strides = strides_dyn_rank(&tensor_global.shape);
        let (kind, block_idx) = self.output_kind(plan, &tensor_global, &output, &strides);

        match kind {
            OutputKind::Normal => {
                self.normal_output::<BT>(
                    client,
                    device,
                    context,
                    plan,
                    output,
                    tensor_global,
                    strides,
                    block_idx,
                );
            }
            OutputKind::Inplace { input_pos } => {
                self.inplace_output(context, plan, output, tensor_global, input_pos, block_idx);
            }
            OutputKind::Transform(TensorView::Reshape { original, .. }) => {
                self.reshaped_output::<BT>(
                    client,
                    device,
                    context,
                    plan,
                    output,
                    tensor_global,
                    strides,
                    original,
                    block_idx,
                );
            }
            OutputKind::Transform(TensorView::SwapDims { original, dims, .. }) => {
                self.swapped_dims_output::<BT>(
                    client,
                    device,
                    context,
                    plan,
                    output,
                    tensor_global,
                    original,
                    dims,
                    block_idx,
                );
            }
        }
    }

    // Every slot must have been filled by the dispatch above.
    for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) {
        plan.handle_outputs.push(handle.unwrap());
        plan.global_outputs.push(global.unwrap());
    }

    for block_idx in 0..plan.blocks.len() {
        if plan.blocks[block_idx].reference.is_found() {
            Self::add_layout_info_inputs(&mut plan.blocks[block_idx], &plan.handle_inputs);
            continue;
        }

        if let RefLayoutSetting::SameAsBlock { block_pos } =
            self.blocks[block_idx].settings.ref_layout
        {
            plan.blocks[block_idx].reference =
                plan.blocks[block_pos as usize].reference.clone();
            continue;
        }

        let Some(shape) = Self::select_reference_from_inputs(
            &self.blocks[block_idx],
            &mut plan.blocks[block_idx],
            &plan.handle_inputs,
        ) else {
            continue;
        };

        // No input could serve as the reference: build a runtime layout by
        // translating every relative dim into its global value.
        let pos = plan.runtime_layouts.len();
        let mut shape_global = shape.clone();
        for (axis, relative_dim) in shape.iter().enumerate() {
            shape_global[axis] = *context.shapes_relative2global.get(relative_dim).unwrap();
        }
        let strides = strides_dyn_rank(&shape_global);

        plan.blocks[block_idx].reference = ReferenceSelection::Runtime { pos };
        plan.runtime_layouts.push(RuntimeLayout {
            shape: shape_global,
            strides,
        });
    }

    // Make sure dropped are correctly executed.
    for id in self.resources.dropped.iter() {
        let Some(tensor_global) = context.tensors.get(id) else {
            continue;
        };
        context.handles.remove_handle(tensor_global.id);
    }

    for id in plan.cleared.drain(..) {
        context.handles.remove_handle(id);
    }
}
/// Tries to use the block's potential reference input as its reference layout.
///
/// Returns `Some(shape)` with the block's reference shape when the block has no
/// potential reference input, so that the caller can create a runtime layout.
/// Returns `None` once the potential input was consumed; with
/// `RefLayoutSetting::SameAsBlock` the block reference itself is left untouched.
fn select_reference_from_inputs(
    block: &FuseBlock,
    block_plan: &mut BlockPlan<'_>,
    handle_inputs: &[HandleInput<R>],
) -> Option<Shape> {
    let Some(input_ref) = block_plan.potential_reference_input.take() else {
        return Some(block.shape_ref.clone());
    };

    match input_ref {
        InputReference::Normal { input_pos } => {
            let reference = handle_inputs
                .get(input_pos)
                .unwrap()
                .as_normal()
                .expect("Quant can't be used as inplace");

            // `Some(true)`: concrete layout, `Some(false)`: virtual contiguous
            // layout, `None`: skip set ref.
            let use_concrete = match block.settings.ref_layout {
                RefLayoutSetting::Any => Some(true),
                RefLayoutSetting::SameAsBlock { .. } => None,
                RefLayoutSetting::OnlyContiguous => Some(is_contiguous(
                    &reference.global_ir.shape,
                    &reference.handle.strides,
                )),
            };

            match use_concrete {
                Some(true) => {
                    block_plan.reference = ReferenceSelection::Concrete {
                        layout: FuseArg::Input(
                            input_pos,
                            reference.precision,
                            LayoutInfo::IsRef,
                        ),
                        shape: reference.global_ir.shape.clone(),
                        strides: reference.handle.strides.clone(),
                    };
                }
                Some(false) => {
                    block_plan.reference = ReferenceSelection::VirtualShape {
                        original: FuseArg::Input(
                            input_pos,
                            reference.precision,
                            LayoutInfo::Unknown,
                        ),
                        shape: reference.global_ir.shape.clone(),
                        strides: contiguous_strides(&reference.global_ir.shape),
                    };
                }
                None => {}
            }

            Self::add_layout_info_inputs(block_plan, handle_inputs);
        }
        InputReference::SwapDims { original_pos, dims } => {
            let reference = handle_inputs
                .get(original_pos)
                .unwrap()
                .as_normal()
                .expect("Quant can't be used in swap dims operation");

            block_plan.reference = ReferenceSelection::SwapDims {
                original: FuseArg::Input(
                    original_pos,
                    reference.precision,
                    LayoutInfo::Unknown,
                ),
                dims,
            };
        }
        InputReference::Reshaped { reshape_pos } => {
            block_plan.reference = ReferenceSelection::Reshaped { reshape_pos };
        }
    }

    None
}
/// Tags the reads of every normal input whose shape and strides equal the block reference
/// with [`LayoutInfo::SameAsRef`].
///
/// Does nothing unless the block reference is `Concrete` or `VirtualShape`, since only
/// those variants carry a shape/strides pair to compare against.
fn add_layout_info_inputs(block: &mut BlockPlan<'_>, handle_inputs: &[HandleInput<R>]) {
    let (ref_strides, ref_shape) = match &block.reference {
        ReferenceSelection::Concrete { strides, shape, .. }
        | ReferenceSelection::VirtualShape { strides, shape, .. } => (strides, shape),
        _ => return,
    };

    for input in handle_inputs {
        // Quantized inputs are never compared with the reference layout.
        let HandleInput::Normal(input) = input else {
            continue;
        };
        if ref_strides != &input.handle.strides || ref_shape != &input.global_ir.shape {
            continue;
        }
        let Some(ops) = block.reads.get_mut(&input.relative_id) else {
            continue;
        };
        for op in ops.iter_mut() {
            if let FuseOp::Assign(assign) = op {
                assign.input.add_layout_info(LayoutInfo::SameAsRef);
            }
        }
    }
}
/// Decides how an output tensor is materialized and which block writes it.
///
/// Returns `OutputKind::Transform` when the output is the result of a registered view
/// (reshape or swap dims), `OutputKind::Inplace` when one of the block's in-place
/// candidates matches dtype, shape and strides and the block reference accepts those
/// strides, and `OutputKind::Normal` otherwise. The second tuple element is the index of
/// the first block with a write registered for the output.
///
/// # Panics
///
/// Panics if no block has a write registered for the output's relative id.
fn output_kind(
    &self,
    plan: &mut LaunchPlan<'a, R>,
    tensor_global: &TensorIr,
    output: &OutputSorted,
    strides: &[usize],
) -> (OutputKind, usize) {
    let relative_id = &output.tensor_relative.id;

    let block_idx = plan
        .blocks
        .iter()
        .position(|block| block.writes.contains_key(relative_id))
        .unwrap();

    // Views take precedence over any in-place candidate.
    let view = self.resources.views.iter().find(|view| match view {
        TensorView::Reshape { reshaped, .. } => reshaped == relative_id,
        TensorView::SwapDims { swapped, .. } => swapped == relative_id,
    });
    if let Some(view) = view {
        return (OutputKind::Transform(view.clone()), block_idx);
    }

    let block = &plan.blocks[block_idx];
    let inplace_pos = block.potential_inplaces.iter().position(|candidate| {
        candidate.tensor_relative.dtype == tensor_global.dtype
            && candidate.tensor_relative.shape == output.tensor_relative.shape
            && &*candidate.strides == strides
            && block.reference.compatible_strides_for_inplace(strides)
    });

    let kind = match inplace_pos {
        Some(input_pos) => OutputKind::Inplace { input_pos },
        None => OutputKind::Normal,
    };
    (kind, block_idx)
}
/// Registers an output that reuses the buffer of one of the block's in-place candidates.
///
/// `input_index` is the position inside `block.potential_inplaces`; the candidate is removed
/// from that list. The candidate's handle is registered in `context.handles` under the
/// output's global id, and the output slot is filled with a `HandleOutput::Alias`.
///
/// # Panics
///
/// Panics if the candidate's input is missing from `plan.handle_inputs`, is not a
/// `HandleInput::Normal`, or (when it becomes the reference) is not registered in
/// `self.resources.inputs`.
fn inplace_output(
    &mut self,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    plan: &mut LaunchPlan<'a, R>,
    output: OutputSorted,
    tensor_global: TensorIr,
    input_index: usize,
    block_idx: usize,
) {
    let block = &mut plan.blocks[block_idx];
    // The candidate is consumed: it can't back a second output.
    let potential_inplace = block.potential_inplaces.remove(input_index);
    let handle_input = match plan.handle_inputs.get(potential_inplace.input_pos).unwrap() {
        HandleInput::Normal(handle) => handle,
        _ => {
            unreachable!("Quant tensor handle can't be used inplace yet.")
        }
    };
    // No reference chosen yet and the block is free to pick one: the aliased input becomes
    // the concrete reference.
    if !block.reference.is_found()
        && !matches!(
            self.blocks[block_idx].settings.ref_layout,
            RefLayoutSetting::SameAsBlock { .. }
        )
    {
        let index_input = self
            .resources
            .inputs
            .get_index(potential_inplace.tensor_relative.id)
            .unwrap();
        block.reference = ReferenceSelection::Concrete {
            layout: FuseArg::Input(index_input, output.precision, LayoutInfo::IsRef),
            shape: tensor_global.shape.clone(),
            strides: handle_input.handle.strides.clone(),
        };
        // Only the first `Assign` read of the input is tagged as the reference.
        if let Some(ops) = block.reads.get_mut(&handle_input.relative_id) {
            for op in ops.iter_mut() {
                if let FuseOp::Assign(op) = op {
                    op.input.add_layout_info(LayoutInfo::IsRef);
                    break;
                };
            }
        }
        // Likewise, only the first `Assign` write of the output is tagged.
        if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {
            for op in ops {
                if let FuseOp::Assign(op) = op {
                    op.out.add_layout_info(LayoutInfo::IsRef);
                    break;
                }
            }
        };
    } else {
        // Already validated, necessary for correctness.
        if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {
            for op in ops {
                if let FuseOp::Assign(op) = op {
                    op.out.add_layout_info(LayoutInfo::SameAsRef);
                    break;
                }
            }
        };
    }
    // The output's global id now resolves to the input's handle.
    context
        .handles
        .register_handle(tensor_global.id, handle_input.handle.clone());
    self.handles[output.pos_original] = Some(HandleOutput::Alias {
        input_pos: potential_inplace.input_pos,
        precision: output.precision,
        #[cfg(feature = "autotune-checks")]
        debug_info: super::HandleOutputAliasDebugInfo {
            relative_id: output.tensor_relative.id,
            handle: handle_input.handle.clone(),
            global_shape: tensor_global.shape.dims.clone(),
        },
    });
    self.globals[output.pos_original] = Some(tensor_global);
}
/// Registers an output backed by a freshly allocated buffer.
///
/// The new handle is registered in `context.handles` under the output's global id and the
/// output slot is filled with a `HandleOutput::Owned` whose vectorization starts at 1.
/// `plan.rank` is raised to the output's rank when that is larger.
///
/// The output also becomes the block's concrete reference when the block has none yet, its
/// shape equals the block's reference shape and the block's layout setting is not
/// `SameAsBlock`.
#[allow(clippy::too_many_arguments)]
fn normal_output<BT: CubeElement>(
    &mut self,
    client: &ComputeClient<R>,
    device: &R::Device,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    plan: &mut LaunchPlan<'a, R>,
    output: OutputSorted,
    tensor_global: TensorIr,
    strides: Strides,
    block_idx: usize,
) {
    let block = &mut plan.blocks[block_idx];
    if !block.reference.is_found()
        && self.blocks[block_idx].shape_ref == output.tensor_relative.shape
        && !matches!(
            self.blocks[block_idx].settings.ref_layout,
            RefLayoutSetting::SameAsBlock { .. }
        )
    {
        block.reference = ReferenceSelection::Concrete {
            layout: FuseArg::Output(output.pos_original, output.precision, LayoutInfo::IsRef),
            shape: tensor_global.shape.clone(),
            strides: strides.clone(),
        };
        // Sometimes outputs that are manually handled don't have any write registered.
        if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {
            for op in ops {
                if let FuseOp::Assign(op) = op {
                    op.out.add_layout_info(LayoutInfo::IsRef);
                    break;
                }
            }
        };
    // A concrete reference already exists: tag the first `Assign` write as sharing its
    // layout when both shape and strides are equal.
    } else if let ReferenceSelection::Concrete {
        shape: ref_shape,
        strides: ref_strides,
        ..
    } = &block.reference
        && ref_strides == &strides
        && ref_shape == &tensor_global.shape
        && let Some(ops) = block.writes.get_mut(&output.tensor_relative.id)
    {
        for op in ops {
            if let FuseOp::Assign(op) = op {
                op.out.add_layout_info(LayoutInfo::SameAsRef);
                break;
            }
        }
    };
    // We encode bool tensors as `B`.
    let dtype = match tensor_global.dtype {
        DType::Bool => elem_dtype::<BT>(),
        _ => tensor_global.dtype,
    };
    // Buffer size: number of elements times the storage size of the (possibly remapped) dtype.
    let size = tensor_global.shape.iter().product::<usize>() * StorageType::from(dtype).size();
    let handle = CubeFusionHandle {
        client: client.clone(),
        handle: client.empty(size),
        device: device.clone(),
        strides,
        dtype,
        qparams: None,
    };
    plan.rank = usize::max(tensor_global.shape.rank(), plan.rank);
    context
        .handles
        .register_handle(tensor_global.id, handle.clone());
    self.handles[output.pos_original] = Some(HandleOutput::Owned {
        precision: output.precision,
        handle,
        global_shape: tensor_global.shape.clone(),
        global_id: tensor_global.id,
        relative_id: output.tensor_relative.id,
        vectorization: 1,
    });
    self.globals[output.pos_original] = Some(tensor_global);
}
/// Registers an output that is a reshape of the input `original`.
///
/// When `reshape_action` reports that the reshape can be expressed by new strides (or by
/// no change at all), the output shares the input's buffer: the write to this output is
/// removed from the block and a handle with the new strides is registered. When it reports
/// `Recompute`, the output falls back to [`Self::normal_output`].
///
/// # Panics
///
/// Panics if `original` is not one of the normal inputs of the plan.
#[allow(clippy::too_many_arguments)]
fn reshaped_output<BT: CubeElement>(
    &mut self,
    client: &ComputeClient<R>,
    device: &R::Device,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    plan: &mut LaunchPlan<'a, R>,
    output: OutputSorted,
    tensor_global: TensorIr,
    strides: Strides,
    original: TensorId,
    block_idx: usize,
) {
    let block = &mut plan.blocks[block_idx];
    let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original);
    // We encode bool tensors as `B`.
    let dtype = match tensor_global.dtype {
        DType::Bool => elem_dtype::<BT>(),
        _ => tensor_global.dtype,
    };
    let action = reshape_action(
        &original_handle.global_ir.shape,
        &original_handle.handle.strides,
        &tensor_global.shape,
    );
    // `Some(strides)` means the reshape is metadata-only; `None` means data must be written.
    let update = match action {
        ReshapeAction::UpdateStrides { strides } => Some(strides),
        ReshapeAction::NoChange => Some(original_handle.handle.strides.clone()),
        ReshapeAction::Recompute => None,
    };
    match update {
        Some(strides) => {
            // We modify the metadata instead.
            remove_concrete_write(block, output.tensor_relative.id, output.pos_original);
            let handle = CubeFusionHandle {
                client: client.clone(),
                handle: original_handle.handle.handle.clone(),
                device: device.clone(),
                strides,
                dtype,
                qparams: original_handle.handle.qparams.clone(),
            };
            context
                .handles
                .register_handle(tensor_global.id, handle.clone());
            // It will never be accessed; it only keeps the original output position valid.
            self.handles[output.pos_original] = Some(HandleOutput::Alias {
                input_pos: pos_input,
                precision: output.precision,
                #[cfg(feature = "autotune-checks")]
                debug_info: super::HandleOutputAliasDebugInfo {
                    relative_id: output.tensor_relative.id,
                    handle: handle.clone(),
                    global_shape: tensor_global.shape.dims.clone(),
                },
            });
            self.globals[output.pos_original] = Some(tensor_global);
        }
        None => {
            self.normal_output::<BT>(
                client,
                device,
                context,
                plan,
                output,
                tensor_global,
                strides,
                block_idx,
            );
        }
    }
}
/// Registers an output that is the input `original` with two dimensions swapped.
///
/// The output shares the input's buffer: the write to this output is removed from the
/// block and a handle whose strides are the input's strides with `dims.0` and `dims.1`
/// exchanged is registered under the output's global id.
///
/// # Panics
///
/// Panics if `original` is not one of the normal inputs of the plan.
#[allow(clippy::too_many_arguments)]
fn swapped_dims_output<BT: CubeElement>(
    &mut self,
    client: &ComputeClient<R>,
    device: &R::Device,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    plan: &mut LaunchPlan<'a, R>,
    output: OutputSorted,
    tensor_global: TensorIr,
    original: TensorId,
    dims: (usize, usize),
    block_idx: usize,
) {
    let block = &mut plan.blocks[block_idx];
    let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original);
    // We encode bool tensors as `B`.
    let dtype = match tensor_global.dtype {
        DType::Bool => elem_dtype::<BT>(),
        _ => tensor_global.dtype,
    };
    // TODO: Check if we can also remove the read, if we have a dead partial graph.
    //
    // We modify the metadata instead.
    remove_concrete_write(block, output.tensor_relative.id, output.pos_original);
    let strides = original_handle.handle.strides.clone();
    let mut handle = CubeFusionHandle {
        client: client.clone(),
        handle: original_handle.handle.handle.clone(),
        device: device.clone(),
        strides,
        dtype,
        qparams: original_handle.handle.qparams.clone(),
    };
    // Swapping the two strides is what expresses the dimension swap on the shared buffer.
    handle.strides.swap(dims.0, dims.1);
    context
        .handles
        .register_handle(tensor_global.id, handle.clone());
    // It will never be accessed; it only keeps the original output position valid.
    self.handles[output.pos_original] = Some(HandleOutput::Alias {
        input_pos: pos_input,
        precision: output.precision,
        #[cfg(feature = "autotune-checks")]
        debug_info: super::HandleOutputAliasDebugInfo {
            relative_id: output.tensor_relative.id,
            handle: handle.clone(),
            global_shape: tensor_global.shape.dims.clone(),
        },
    });
    self.globals[output.pos_original] = Some(tensor_global);
}
fn find_child_input(
handle_inputs: &[HandleInput<R>],
original: TensorId,
) -> (usize, &NormalHandleInput<R>) {
handle_inputs
.iter()
.enumerate()
.find_map(|(pi, handle)| match handle {
HandleInput::Normal(handle) => match handle.relative_id == original {
true => Some((pi, handle)),
false => None,
},
_ => None, // Quant tensor can't be reshaped.
})
.unwrap()
}
}
/// Drops the write of tensor `id` that targets output position `output_pos`.
///
/// Among the write ops registered for `id`, an `Assign` is kept unless its destination is
/// `FuseArg::Output` at `output_pos`. Nothing happens when `id` has no registered writes.
fn remove_concrete_write(block: &mut BlockPlan, id: TensorId, output_pos: usize) {
    let Some(ops) = block.writes.get_mut(&id) else {
        return;
    };

    ops.retain(|op| match op {
        FuseOp::Assign(args) => {
            !matches!(args.out, FuseArg::Output(pos, ..) if pos == output_pos)
        }
        // NOTE(review): ops that are not `Assign` are discarded too — presumably writes only
        // ever hold `Assign` ops; confirm this is intended.
        _ => false,
    });
}

View File

@@ -0,0 +1,273 @@
use crate::{
CubeFusionHandle,
engine::{
codegen::ir::{FuseArg, FuseOp, FuseType},
launch::vectorization::Vect,
trace::{RuntimeLayout, block::FuseBlock},
},
};
use burn_ir::{TensorId, TensorIr};
use burn_std::{Shape, Strides};
use cubecl::{Runtime, ir::LineSize};
use std::collections::BTreeMap;
/// The `LaunchPlan` is responsible for aggregating all runtime information required
/// to dispatch a fused kernel.
///
/// It maps abstract IR tensors to memory handles, manages vectorization
/// strategies, and tracks layout transformations.
#[derive(Debug)]
pub struct LaunchPlan<'a, R: Runtime> {
    /// The IR representation of tensors that are results of the fusion.
    pub global_outputs: Vec<TensorIr>,
    /// Memory handles and metadata for all input tensors.
    pub handle_inputs: Vec<HandleInput<R>>,
    /// Memory handles and metadata for all output tensors, including aliased inputs.
    pub handle_outputs: Vec<HandleOutput<R>>,
    /// The maximum rank across all tensors in the plan.
    ///
    /// Smaller tensors are unsqueezed during launch.
    pub rank: usize,
    /// Detailed planning for each individual computation block within the fusion.
    pub blocks: Vec<BlockPlan<'a>>,
    /// Mapping of tensor IDs to their specific vectorization factors.
    pub vectorizations: BTreeMap<TensorId, Vect>,
    /// Tensors that can be cleared or deallocated after this plan executes.
    pub cleared: Vec<TensorId>,
    /// Metadata for shapes and strides passed from the host when they cannot be
    /// inferred from input tensors (e.g., complex deep fusions).
    pub runtime_layouts: Vec<RuntimeLayout>,
}
/// Information regarding the execution of a specific block of operations within a fusion.
#[derive(Debug)]
pub struct BlockPlan<'a> {
    /// List of inputs that are candidates for in-place memory reuse within this block.
    pub potential_inplaces: Vec<PotentialInplace<'a>>,
    /// The input tensor chosen to define the iteration space, if any.
    pub potential_reference_input: Option<InputReference>,
    /// How the master layout is determined for this block.
    ///
    /// Starts as [`ReferenceSelection::Searching`] until a reference is selected.
    pub reference: ReferenceSelection,
    /// Mapping of tensor IDs to the read operations performed on them.
    pub reads: BTreeMap<TensorId, Vec<FuseOp>>,
    /// Mapping of tensor IDs to the write operations performed on them.
    pub writes: BTreeMap<TensorId, Vec<FuseOp>>,
    /// The width for the operations in this block; initialized to `0` by `LaunchPlan::new`.
    pub width: LineSize,
}
/// Metadata for an input tensor being used as a reference for a block's layout.
#[derive(Debug)]
pub enum InputReference {
/// Standard input at the specified position.
Normal { input_pos: usize },
/// Input that has an axis swapped.
SwapDims {
original_pos: usize,
dims: (usize, usize),
},
/// Input that has been reshaped.
Reshaped { reshape_pos: usize },
}
/// Strategies for selecting the reference layout of a fused block.
///
/// The reference layout determines how global indices are mapped to tensor coordinates.
#[derive(Clone, Debug)]
pub enum ReferenceSelection {
    /// The engine is still calculating the optimal reference.
    Searching,
    /// Layout from a normal tensor.
    Concrete {
        /// The input or output argument acting as the reference.
        layout: FuseArg,
        /// Shape of the reference tensor.
        shape: Shape,
        /// Strides of the reference tensor.
        strides: Strides,
    },
    /// Layout from a swapped dim tensor.
    SwapDims {
        /// The argument the swapped view originates from.
        original: FuseArg,
        /// The pair of dimensions that are swapped.
        dims: (usize, usize),
    },
    /// Layout from a reshaped tensor.
    Reshaped {
        /// Position identifying the reshape (presumably among the registered reshapes).
        reshape_pos: usize,
    },
    /// Layout that has the shape of an input, but not its strides.
    VirtualShape {
        /// The input argument that provides the shape.
        original: FuseArg,
        /// Shape taken from the input.
        shape: Shape,
        /// Strides used in place of the input's own strides.
        strides: Strides,
    },
    /// The layout is provided dynamically by the host at runtime.
    Runtime {
        /// Position of the layout (presumably in `LaunchPlan::runtime_layouts`).
        pos: usize,
    },
}
impl<R: Runtime> LaunchPlan<'_, R> {
    /// Builds an empty plan for the given fusion blocks.
    ///
    /// Each block starts with a `Searching` reference, a copy of the fuse block's reads and
    /// writes, a width of `0` and no in-place or reference candidates. The initial rank is
    /// the largest `shape_ref` length among the blocks (`0` when there are none).
    pub fn new(fuse_blocks: &[FuseBlock]) -> Self {
        let rank = fuse_blocks
            .iter()
            .map(|fuse_block| fuse_block.shape_ref.len())
            .max()
            .unwrap_or(0);

        let blocks: Vec<_> = fuse_blocks
            .iter()
            .map(|fuse_block| BlockPlan {
                potential_inplaces: Vec::new(),
                potential_reference_input: None,
                reference: ReferenceSelection::Searching,
                reads: fuse_block.reads.clone(),
                writes: fuse_block.writes.clone(),
                width: 0,
            })
            .collect();

        Self {
            global_outputs: Vec::new(),
            handle_inputs: Vec::new(),
            handle_outputs: Vec::new(),
            rank,
            blocks,
            vectorizations: Default::default(),
            cleared: Default::default(),
            runtime_layouts: Default::default(),
        }
    }
}
/// Debugging information for aliased handles when `autotune-checks` is enabled.
#[cfg(feature = "autotune-checks")]
#[derive(Debug)]
pub struct HandleOutputAliasDebugInfo<R: Runtime> {
    /// The handle registered for the aliased output.
    pub handle: CubeFusionHandle<R>,
    /// Relative id of the aliased output tensor.
    pub relative_id: TensorId,
    /// Shape of the aliased output's global tensor.
    pub global_shape: Shape,
}
/// Represents the output of a fused kernel execution.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum HandleOutput<R: Runtime> {
    /// An output that reuses the memory of an input tensor (In-place).
    Alias {
        /// Index of the input handle being aliased.
        input_pos: usize,
        /// Data type precision.
        precision: FuseType,
        /// Extra information kept only for autotune checks.
        #[cfg(feature = "autotune-checks")]
        debug_info: HandleOutputAliasDebugInfo<R>,
    },
    /// An output that requires a newly allocated memory buffer.
    Owned {
        /// Id of the output's global tensor.
        global_id: TensorId,
        /// Id of the output's relative tensor.
        relative_id: TensorId,
        /// Data type precision.
        precision: FuseType,
        /// The handle of the newly allocated buffer.
        handle: CubeFusionHandle<R>,
        /// Shape of the output's global tensor.
        global_shape: Shape,
        /// Vectorization factor (line size) selected for this output.
        vectorization: LineSize,
    },
}
/// A standard input handle with associated layout and vectorization metadata.
#[derive(Debug)]
pub struct NormalHandleInput<R: Runtime> {
    /// Id of the input's relative tensor.
    pub relative_id: TensorId,
    /// IR of the input's global tensor.
    pub global_ir: TensorIr,
    /// Data type precision.
    pub precision: FuseType,
    /// The memory handle of the input.
    pub handle: CubeFusionHandle<R>,
    /// Vectorization factor for this input; starts at `1`.
    pub line_size: LineSize,
    /// Whether the input is broadcasted; starts as `false`.
    pub broadcated: bool,
    /// Stores the original strides of the handle for restoration during plan rollback.
    pub orig_strides: Strides,
}
/// An input handle containing values for a quantized tensor.
#[derive(Debug)]
pub struct QuantValuesHandleInput<R: Runtime> {
    /// Id of the input's relative tensor.
    pub relative_id: TensorId,
    /// IR of the input's global tensor.
    pub global_ir: TensorIr,
    /// Data type precision.
    pub precision: FuseType,
    /// The memory handle holding the quantized values.
    pub handle: CubeFusionHandle<R>,
    /// Vectorization factor for this input.
    pub line_size: LineSize,
}
/// An input handle containing parameters (scales/offsets) for quantization.
#[derive(Debug)]
pub struct QuantParamsHandleInput<R: Runtime> {
    /// Data type precision.
    pub precision: FuseType,
    /// The memory handle holding the parameters.
    pub handle: CubeFusionHandle<R>,
    /// Shape of the parameters tensor.
    pub shape: Shape,
}
/// The kinds of input that a fused kernel can receive.
#[derive(Debug)]
pub enum HandleInput<R: Runtime> {
    /// A regular, non-quantized tensor.
    Normal(NormalHandleInput<R>),
    /// The values of a quantized tensor.
    QuantValues(QuantValuesHandleInput<R>),
    /// The quantization parameters of a quantized tensor.
    QuantParams(QuantParamsHandleInput<R>),
}

impl<R: Runtime> HandleInput<R> {
    /// Borrows the inner [`NormalHandleInput`], or returns `None` for quantized inputs.
    pub fn as_normal(&self) -> Option<&NormalHandleInput<R>> {
        if let Self::Normal(input) = self {
            Some(input)
        } else {
            None
        }
    }
}
impl<R: Runtime> NormalHandleInput<R> {
    /// Wraps `handle` as a normal input that uses `strides`.
    ///
    /// The handle's previous strides are saved in `orig_strides` so that
    /// [`Self::handle_rollback`] can restore them. The line size starts at `1` and the input
    /// is not marked as broadcasted.
    pub fn new(
        tensor_global: TensorIr,
        tensor_relative: &TensorIr,
        precision: FuseType,
        mut handle: CubeFusionHandle<R>,
        strides: Strides,
    ) -> Self {
        // Install the provided strides and keep what the handle had before.
        let orig_strides = core::mem::replace(&mut handle.strides, strides);

        Self {
            relative_id: tensor_relative.id,
            global_ir: tensor_global,
            precision,
            handle,
            line_size: 1,
            broadcated: false,
            orig_strides,
        }
    }

    /// Consumes the input and returns its handle with the original strides put back.
    ///
    /// Used when a plan is invalidated or needs to be rolled back.
    pub fn handle_rollback(self) -> CubeFusionHandle<R> {
        let Self {
            mut handle,
            orig_strides,
            ..
        } = self;
        handle.strides = orig_strides;
        handle
    }
}
/// A candidate for in-place optimization: an input whose buffer may be reused by an output.
#[derive(Debug)]
pub struct PotentialInplace<'a> {
    /// Position of the input handle in the `handle_inputs` vector.
    pub input_pos: usize,
    /// Reference to the IR of the relative tensor.
    pub tensor_relative: &'a TensorIr,
    /// Current strides of the potential in-place candidate.
    pub strides: Strides,
}
impl ReferenceSelection {
pub fn is_found(&self) -> bool {
!matches!(self, Self::Searching)
}
pub fn compatible_strides_for_inplace(&self, strides_inplace: &[usize]) -> bool {
match self {
ReferenceSelection::Concrete { strides, .. } => &**strides == strides_inplace,
_ => false,
}
}
}

View File

@@ -0,0 +1,96 @@
use super::super::codegen::ir::{FuseBlockConfig, GlobalArgsLaunch};
use crate::{
CubeFusionHandle,
engine::launch::{
LaunchPlan,
vectorization::{Vect, vectorization_default},
},
};
use burn_fusion::stream::Context;
use burn_ir::{TensorId, TensorIr};
use cubecl::prelude::*;
use std::collections::{BTreeMap, HashMap};
/// A trace runner is responsible for determining the vectorization factor as well as launching
/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch)
/// with provided [fuse block configs](FuseBlockConfig).
///
/// The vectorization side of the contract comes from the [`Vectorization`] supertrait.
pub trait TraceRunner<R: Runtime>: Vectorization<R> {
    /// The error that might happen while running the trace.
    type Error;
    /// Run the trace with the given inputs and outputs.
    ///
    /// There is one [fuse config](FuseBlockConfig) for each [block](super::block::FuseBlock) registered
    /// in the [optimization builder](burn_fusion::OptimizationBuilder).
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] when the implementation fails to run the trace.
    fn run<'a>(
        &'a self,
        client: &'a ComputeClient<R>,
        inputs: GlobalArgsLaunch<'a, R>,
        outputs: GlobalArgsLaunch<'a, R>,
        configs: &'a [FuseBlockConfig],
    ) -> Result<(), Self::Error>;
}
/// An input as seen by the vectorization pass.
pub enum VectorizationHandle<'a, R: Runtime> {
    /// A regular input: its handle and the IR of its tensor.
    NormalInput(&'a CubeFusionHandle<R>, &'a TensorIr),
    /// The values of a quantized input: its handle and the IR of its tensor.
    QuantValues(&'a CubeFusionHandle<R>, &'a TensorIr),
    /// The parameters of a quantized input; carries no tensor information.
    QuantParams,
}

impl<'a, R: Runtime> VectorizationHandle<'a, R> {
    /// Returns whether this handle belongs to the tensor with the given id.
    ///
    /// Always `false` for [`VectorizationHandle::QuantParams`].
    pub fn is_from_tensor(&self, id: TensorId) -> bool {
        match self {
            Self::NormalInput(_, tensor_ir) | Self::QuantValues(_, tensor_ir) => {
                tensor_ir.id == id
            }
            Self::QuantParams => false,
        }
    }
}
/// Per-tensor choice of the axis along which vectorization is applied.
#[derive(Default)]
pub struct VectorizationAxis {
    axis: HashMap<TensorId, usize>,
}

impl VectorizationAxis {
    /// Returns the axis registered for `id`, or the result of `default` when none is
    /// registered. `default` is only called in the latter case.
    pub fn get<F: FnOnce() -> usize>(&self, id: TensorId, default: F) -> usize {
        match self.axis.get(&id) {
            Some(&axis) => axis,
            None => default(),
        }
    }

    /// Registers `axis` as the vectorization axis of `id`, replacing any previous value.
    pub fn insert(&mut self, id: TensorId, axis: usize) {
        self.axis.insert(id, axis);
    }
}
/// Strategy used to pick a vectorization factor for every tensor of a launch plan.
///
/// Both methods have default implementations, so an implementor only overrides what it
/// needs to customize.
pub trait Vectorization<R: Runtime> {
    /// Returns the vectorization options.
    ///
    /// The default registers no axis, so every tensor uses the fallback axis chosen by the
    /// vectorization functions.
    fn axis(&self, _plan: &LaunchPlan<'_, R>) -> VectorizationAxis {
        VectorizationAxis::default()
    }
    /// The vectorization factor for all inputs and outputs.
    ///
    /// Results are written into `vectorizations`. The default implementation forwards to
    /// `vectorization_default` with no line-size overrides and ignores `_context`.
    #[allow(clippy::too_many_arguments)]
    fn vectorization<'a>(
        &self,
        _context: &Context<'_, CubeFusionHandle<R>>,
        vectorizations: &mut BTreeMap<TensorId, Vect>,
        inputs: impl Iterator<Item = VectorizationHandle<'a, R>>,
        outputs: impl Iterator<Item = &'a TensorIr>,
        reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,
        swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,
        line_sizes: &[LineSize],
        max: LineSize,
        axis: VectorizationAxis,
    ) {
        vectorization_default(
            vectorizations,
            inputs,
            outputs,
            reshaped,
            swapped,
            line_sizes,
            // Default overrides: no per-tensor or global line-size override.
            &Default::default(),
            max,
            &axis,
        )
    }
}

View File

@@ -0,0 +1,439 @@
use crate::{
CubeFusionHandle,
engine::launch::runner::{VectorizationAxis, VectorizationHandle},
};
use burn_fusion::stream::Context;
use burn_ir::{TensorId, TensorIr};
use cubecl::{Runtime, ir::LineSize};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// The vectorization selected for a tensor.
#[derive(Debug, Clone, Copy)]
pub enum Vect {
    /// The tensor is broadcasted along the vectorization axis.
    Broadcasted,
    /// The tensor is accessed with the given line size.
    Aligned(LineSize),
}

impl Vect {
    /// Returns the line size to use: the aligned value, or `1` when broadcasted.
    pub fn line_size(&self) -> LineSize {
        if let Vect::Aligned(size) = self {
            *size
        } else {
            1
        }
    }

    /// Returns `true` for [`Vect::Broadcasted`].
    pub fn is_broadcast(&self) -> bool {
        match self {
            Vect::Broadcasted => true,
            Vect::Aligned(_) => false,
        }
    }
}
/// Candidate line sizes that replace the default list, per tensor and/or globally.
#[derive(Default, Clone, Serialize, Deserialize, Debug)]
pub struct LineSizeOverrides {
    /// Per-tensor candidate line sizes; `None` until the first override is registered.
    state: Option<BTreeMap<TensorId, Vec<LineSize>>>,
    /// Candidate line sizes used for tensors without a dedicated entry.
    default: Option<Vec<LineSize>>,
}

#[allow(unused)]
impl LineSizeOverrides {
    /// Registers the candidate line sizes of a single tensor, replacing any previous entry.
    pub fn overrides(&mut self, tensor_id: &TensorId, line_sizes: Vec<LineSize>) {
        self.state
            .get_or_insert_with(BTreeMap::new)
            .insert(*tensor_id, line_sizes);
    }

    /// Registers the candidate line sizes used by tensors without their own entry.
    pub fn overrides_default(&mut self, line_sizes: Vec<LineSize>) {
        self.default = Some(line_sizes);
    }

    /// Returns a copy whose per-tensor entries are re-keyed by the id of the tensor that
    /// `context.tensors` holds for each key (presumably relative id to global id).
    ///
    /// # Panics
    ///
    /// Panics if a registered tensor id is not present in `context.tensors`.
    pub fn mapping<R: Runtime>(&self, context: &Context<'_, CubeFusionHandle<R>>) -> Self {
        let state = self.state.as_ref().map(|registered| {
            let mut remapped = BTreeMap::new();
            for (tensor_id, line_sizes) in registered {
                let global = context.tensors.get(tensor_id).unwrap();
                remapped.insert(global.id, line_sizes.clone());
            }
            remapped
        });

        Self {
            state,
            default: self.default.clone(),
        }
    }

    /// Returns the candidate line sizes for `tensor_id`: its own entry when there is one,
    /// otherwise the default override, otherwise `None`.
    pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec<LineSize>> {
        self.state
            .as_ref()
            .and_then(|registered| registered.get(tensor_id))
            .or(self.default.as_ref())
    }
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn vectorization_default<'a, R: Runtime>(
vectorizations: &mut BTreeMap<TensorId, Vect>,
inputs: impl Iterator<Item = VectorizationHandle<'a, R>>,
outputs: impl Iterator<Item = &'a TensorIr>,
reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,
swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,
line_sizes: &[LineSize],
overrides: &LineSizeOverrides,
max: LineSize,
axis: &VectorizationAxis,
) {
let swapped: Vec<_> = swapped.collect();
for input in inputs {
if let Some((s, o, mr, dims)) = swapped
.iter()
.find(|(_s, o, _mr, _dims)| input.is_from_tensor(o.id))
{
let (handle, id) = match input {
VectorizationHandle::NormalInput(handle, tensor_ir) => (handle, &tensor_ir.id),
VectorizationHandle::QuantValues(..) => panic!("Can't be swapped"),
VectorizationHandle::QuantParams => panic!("Can't be swapped"),
};
let val = vectorization_swapped(
handle,
s,
o,
*mr,
dims,
max,
axis,
line_sizes,
overrides.tensor(id),
);
multi_reads_vectorization_update(vectorizations, o.id, val);
} else {
match input {
VectorizationHandle::NormalInput(handle, tensor_ir) => {
let val = vectorization_input(
handle,
tensor_ir,
axis,
line_sizes,
overrides.tensor(&tensor_ir.id),
);
vectorizations.insert(tensor_ir.id, val);
}
VectorizationHandle::QuantValues(handle, tensor_ir) => {
let val = vectorization_input(
handle,
tensor_ir,
axis,
line_sizes,
overrides.tensor(&tensor_ir.id),
);
let num_quants = match tensor_ir.dtype {
burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants(),
_ => panic!(""),
};
let val = match val {
Vect::Broadcasted => Vect::Aligned(1),
Vect::Aligned(val) => Vect::Aligned(val.div_ceil(num_quants)),
};
vectorizations.insert(tensor_ir.id, val);
}
VectorizationHandle::QuantParams => {
// Doesn't have vectorization for now.
}
};
}
}
for (reshaped, original, multi_reads) in reshaped {
let val = vectorization_reshape(
reshaped,
original,
multi_reads,
axis,
line_sizes,
max,
overrides.tensor(&original.id),
);
multi_reads_vectorization_update(vectorizations, original.id, val);
}
for tensor in outputs {
let val = vectorization_output(tensor, axis, line_sizes, max, overrides.tensor(&tensor.id));
vectorizations.insert(tensor.id, val);
}
}
/// Merges a newly computed vectorization into the entry of `original`.
///
/// - No existing entry: `vect` is stored as is.
/// - Existing entry is broadcasted: it is left untouched.
/// - Existing entry is aligned: it becomes `Aligned(1)` unless `vect` is aligned with the
///   same line size, in which case that line size is kept.
fn multi_reads_vectorization_update(
    vectorizations: &mut BTreeMap<TensorId, Vect>,
    original: TensorId,
    vect: Vect,
) {
    let merged = match (vectorizations.get(&original).copied(), vect) {
        (None, incoming) => incoming,
        // keep the original as is.
        (Some(Vect::Broadcasted), _) => return,
        (Some(Vect::Aligned(_)), Vect::Broadcasted) => Vect::Aligned(1),
        (Some(Vect::Aligned(current)), Vect::Aligned(incoming)) => {
            Vect::Aligned(if incoming == current { incoming } else { 1 })
        }
    };
    vectorizations.insert(original, merged);
}
/// Selects the vectorization of an input tensor.
///
/// The axis defaults to the last dimension of the handle's strides. The result is
/// `Broadcasted` when the shape along that axis is `1`, and `Aligned(1)` when the stride
/// along that axis is not `1` (the line would not be contiguous). Otherwise it is the first
/// candidate line size that divides the shape along the axis, or `Aligned(1)` when none
/// does. Candidates come from `overrides` when provided, from `line_sizes` otherwise.
fn vectorization_input<R: Runtime>(
    handle: &CubeFusionHandle<R>,
    desc: &TensorIr,
    axis: &VectorizationAxis,
    line_sizes: &[LineSize],
    overrides: Option<&Vec<LineSize>>,
) -> Vect {
    let axis = axis.get(desc.id, || handle.strides.len() - 1);
    let extent = desc.shape[axis];

    if extent == 1 {
        return Vect::Broadcasted;
    }
    // The stride along the axis should be 1, otherwise vecX won't be contiguous.
    if handle.strides[axis] != 1 {
        return Vect::Aligned(1);
    }

    let candidates: &[LineSize] = match overrides {
        Some(sizes) => sizes.as_slice(),
        None => line_sizes,
    };

    // The extent along the axis should be a multiple of the vector size.
    candidates
        .iter()
        .copied()
        .find(|&size| extent.is_multiple_of(size))
        .map_or(Vect::Aligned(1), Vect::Aligned)
}
/// Selects the vectorization of an output tensor.
///
/// The axis defaults to the last dimension of the shape. The result is the first candidate
/// line size that divides the shape along the axis and does not exceed `max`, or
/// `Aligned(1)` when none qualifies. Candidates come from `overrides` when provided, from
/// `line_sizes` otherwise.
fn vectorization_output(
    desc: &TensorIr,
    axis: &VectorizationAxis,
    line_sizes: &[LineSize],
    max: LineSize,
    overrides: Option<&Vec<LineSize>>,
) -> Vect {
    let axis = axis.get(desc.id, || desc.shape.rank() - 1);

    let candidates: &[LineSize] = match overrides {
        Some(sizes) => sizes.as_slice(),
        None => line_sizes,
    };

    candidates
        .iter()
        .copied()
        // The dimension should be a multiple of the vector size.
        .find(|&size| desc.shape[axis].is_multiple_of(size) && size <= max)
        .map_or(Vect::Aligned(1), Vect::Aligned)
}
/// Selects the vectorization of the original tensor behind a reshape.
///
/// The axis defaults to the last dimension of the reshaped tensor. The result is:
/// - `Broadcasted` when the tensor is read once and the reshaped extent along the axis is `1`;
/// - `Aligned(1)` when the axis is not the last dimension, or when the last dimension of
///   the original differs from the reshaped extent;
/// - otherwise the first candidate line size that divides the extent(s) and does not exceed
///   `max`, or `Aligned(1)` when none qualifies.
///
/// Candidates come from `overrides` when provided, from `line_sizes` otherwise.
fn vectorization_reshape(
    reshaped: &TensorIr,
    original: &TensorIr,
    multi_reads: bool,
    axis: &VectorizationAxis,
    line_sizes: &[LineSize],
    max: LineSize,
    overrides: Option<&Vec<LineSize>>,
) -> Vect {
    let axis = axis.get(reshaped.id, || reshaped.shape.rank() - 1);
    let reshaped_extent = reshaped.shape[axis];

    if !multi_reads && reshaped_extent == 1 {
        return Vect::Broadcasted;
    }
    // Only the last dimension is supported; anything else falls back to Aligned(1) to be safe.
    if axis != reshaped.shape.rank() - 1 {
        return Vect::Aligned(1);
    }

    let original_extent = original.shape[original.shape.rank() - 1];
    if original_extent != reshaped_extent {
        return Vect::Aligned(1);
    }

    let fits = |size: LineSize| {
        let divides = if multi_reads {
            // Since the original tensor must share the same vectorization factor as the
            // reshaped tensor, they must have compatible shapes when both are accessed
            // independently.
            reshaped_extent.is_multiple_of(size) && original_extent.is_multiple_of(size)
        } else {
            // The last dimension should be a multiple of the vector size or broadcasted.
            reshaped_extent.is_multiple_of(size)
        };
        divides && size <= max
    };

    let candidates: &[LineSize] = match overrides {
        Some(sizes) => sizes.as_slice(),
        None => line_sizes,
    };

    candidates
        .iter()
        .copied()
        .find(|&size| fits(size))
        .map_or(Vect::Aligned(1), Vect::Aligned)
}
/// Selects the vectorization of the original tensor behind a swap-dims view.
///
/// The axis defaults to the last dimension of the swapped tensor. `dim_index` is the
/// dimension of the original that the axis maps to through the swap (the axis itself when
/// it is not one of `dims`). The result is:
/// - `Aligned(1)` when the stride of the original along `dim_index` is not `1`, or, for a
///   tensor read several times, when its stride along the axis is not `1` either;
/// - `Broadcasted` when the tensor is read once and the swapped extent along the axis is `1`;
/// - otherwise the first candidate line size that passes the divisibility test and does
///   not exceed `max`, or `Aligned(1)` when none qualifies.
///
/// Candidates come from `overrides` when provided, from `line_sizes` otherwise.
#[allow(clippy::too_many_arguments)]
fn vectorization_swapped<R: Runtime>(
    handle: &CubeFusionHandle<R>,
    swapped: &TensorIr,
    original: &TensorIr,
    multi_reads: bool,
    dims: &(usize, usize),
    max: LineSize,
    axis: &VectorizationAxis,
    line_sizes: &[LineSize],
    overrides: Option<&Vec<LineSize>>,
) -> Vect {
    let axis = axis.get(swapped.id, || swapped.shape.rank() - 1);
    let swapped_extent = swapped.shape[axis];
    let original_extent = original.shape[axis];

    let dim_index = match *dims {
        (first, second) if first == axis => second,
        (first, second) if second == axis => first,
        _ => axis,
    };

    // Strides along the vectorized dimension should be 1, otherwise vecX won't be contiguous.
    if (multi_reads && handle.strides[axis] != 1) || handle.strides[dim_index] != 1 {
        return Vect::Aligned(1);
    }
    if !multi_reads && swapped_extent == 1 {
        return Vect::Broadcasted;
    }

    // NOTE(review): the multi-read case checks only the swapped extent while the single-read
    // case also checks the original extent, which is the opposite of
    // `vectorization_reshape` — confirm this is intended.
    let fits = |size: LineSize| {
        let divides = if multi_reads {
            swapped_extent.is_multiple_of(size)
        } else {
            swapped_extent.is_multiple_of(size) && original_extent.is_multiple_of(size)
        };
        divides && size <= max
    };

    let candidates: &[LineSize] = match overrides {
        Some(sizes) => sizes.as_slice(),
        None => line_sizes,
    };

    candidates
        .iter()
        .copied()
        .find(|&size| fits(size))
        .map_or(Vect::Aligned(1), Vect::Aligned)
}

View File

@@ -0,0 +1,5 @@
mod base;
mod planner;
pub use base::*;
pub use planner::*;

View File

@@ -0,0 +1,438 @@
use super::{
super::{BlockPlan, HandleOutput, LaunchPlan},
Vect,
};
use crate::{
CubeFusionHandle,
engine::{
launch::{
HandleInput,
runner::{Vectorization, VectorizationHandle},
},
settings::VectorizationSetting,
trace::{FuseResources, TensorView, block::FuseBlock},
},
};
use burn_fusion::stream::Context;
use burn_ir::TensorId;
use cubecl::{
Runtime,
client::ComputeClient,
ir::{ElemType, StorageType, UIntKind},
};
use cubecl::{
ir::LineSize,
quant::scheme::{QuantScheme, QuantStore, QuantValue},
};
use std::marker::PhantomData;
/// Select the best vectorization factor for each tensor handle.
pub struct VectorizationPlanner<'a, R: Runtime> {
    /// Fusion resources; the planner reads their views and indexed tensors.
    resources: &'a FuseResources,
    /// The fusion blocks being planned.
    blocks: &'a Vec<FuseBlock>,
    /// Ties the planner to a runtime without storing a value of it.
    _r: PhantomData<R>,
}
impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
/// Creates a planner that borrows the fusion resources and blocks for its whole lifetime.
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
    Self {
        resources,
        blocks,
        _r: PhantomData,
    }
}
pub fn run<Runner: Vectorization<R>>(
self,
client: &ComputeClient<R>,
runner: &Runner,
context: &Context<'_, CubeFusionHandle<R>>,
plan: &mut LaunchPlan<'a, R>,
) {
let has_multiple_read = |tensor: &TensorId| {
let mut read_count = 0;
for block in plan.blocks.iter() {
read_count += block.reads.get(tensor).map(|a| a.len()).unwrap_or(0);
}
read_count > 1
};
let tensors_reshaped = self.resources.views.iter().filter_map(|view| match view {
TensorView::Reshape {
reshaped, original, ..
} => Some((
context.tensors.get(reshaped).unwrap(),
context.tensors.get(original).unwrap(),
has_multiple_read(original),
)),
TensorView::SwapDims { .. } => None,
});
let tensors_swapped = self.resources.views.iter().filter_map(|view| match view {
TensorView::SwapDims {
swapped,
original,
dims,
..
} => Some((
context.tensors.get(swapped).unwrap(),
context.tensors.get(original).unwrap(),
has_multiple_read(original),
dims,
)),
TensorView::Reshape { .. } => None,
});
let mut ref_elem = (ElemType::UInt(UIntKind::U64).into(), 8);
let mut quants_line_sizes: Option<Vec<LineSize>> = None;
for input in plan.handle_inputs.iter() {
let elem: StorageType = match input {
HandleInput::Normal(h) => h.global_ir.dtype.into(),
HandleInput::QuantValues(handle) => match handle.global_ir.dtype {
burn_std::DType::QFloat(scheme) => {
line_sizes_quants(client, &mut quants_line_sizes, scheme);
continue;
}
_ => panic!("Unable to retrieve the scheme for quantized values."),
},
HandleInput::QuantParams(..) => continue,
};
let elem_size = elem.size();
if ref_elem.1 >= elem_size {
ref_elem = (elem, elem_size);
}
}
for r in plan.global_outputs.iter() {
let elem: StorageType = r.dtype.into();
let elem_size = elem.size();
if ref_elem.1 >= elem_size {
ref_elem = (elem, elem_size);
}
}
let filtered = plan
.handle_inputs
.iter()
.map(|item| {
item.as_normal()
// Filter out indexed resources.
.map(|item| !self.resources.indexed.contains_key(&item.relative_id))
.unwrap_or(true)
})
.collect::<Vec<_>>();
let line_sizes = match quants_line_sizes {
// Quantization normally triggers higher vectorization than anything else, no need to
// compare to ref elem.
Some(line_sizes) => line_sizes,
None => client
.io_optimized_line_sizes(ref_elem.0.size())
.collect::<Vec<_>>(),
};
let vectorization_axis = runner.axis(plan);
runner.vectorization(
context,
&mut plan.vectorizations,
plan.handle_inputs
.iter()
.enumerate()
.filter_map(|(i, item)| {
if filtered[i] {
Some(match item {
HandleInput::Normal(h) => {
VectorizationHandle::NormalInput(&h.handle, &h.global_ir)
}
HandleInput::QuantValues(h) => {
VectorizationHandle::QuantValues(&h.handle, &h.global_ir)
}
HandleInput::QuantParams(_) => VectorizationHandle::QuantParams,
})
} else {
None
}
}),
plan.global_outputs.iter(),
tensors_reshaped,
tensors_swapped,
&line_sizes,
u8::MAX as usize,
vectorization_axis,
);
for tensor in self.resources.indexed.keys() {
let global = context.tensors.get(tensor).unwrap();
plan.vectorizations.insert(global.id, Vect::Aligned(1));
}
let mut block_vectorization = Vec::with_capacity(self.blocks.len());
for _ in 0..self.blocks.len() {
block_vectorization.push(Vec::new());
}
for (input_pos, handle) in plan.handle_inputs.iter_mut().enumerate() {
let (global_ir, relative_id) = match handle {
HandleInput::Normal(h) => (&h.global_ir, &h.relative_id),
HandleInput::QuantValues(h) => (&h.global_ir, &h.relative_id),
HandleInput::QuantParams(_) => continue,
};
let (vect, br) = match plan.vectorizations.get(&global_ir.id) {
Some(v) => (v.line_size(), v.is_broadcast()),
None => panic!("No vectorization factor found for {:?}", global_ir.id),
};
for (block_pos, block_plan) in plan.blocks.iter().enumerate() {
if block_plan.reads.contains_key(relative_id) {
block_vectorization[block_pos].push(BlockVectorization {
action: VectorizationAction::Input(input_pos),
potential: vect,
broadcasted: br,
});
}
}
}
for (output_pos, handle) in plan.handle_outputs.iter().enumerate() {
if let HandleOutput::Owned {
global_id,
relative_id,
..
} = handle
{
for (block_pos, block_plan) in plan.blocks.iter().enumerate() {
if block_plan.writes.contains_key(relative_id) {
let vectorization = plan.vectorizations.get(global_id).unwrap().line_size();
block_vectorization[block_pos].push(BlockVectorization {
action: VectorizationAction::Output(output_pos),
potential: vectorization,
broadcasted: false,
});
}
}
}
}
let mut previous_widths = Vec::with_capacity(block_vectorization.len());
// Unhandled inputs might not get included in any fused blocks for now.
//
// So we ensure they are vectorized by setting their vectorization before we set the
// vectorizations in blocks.
//
// Unhandled Outputs are correctly vectorized, so this is only necessary for inputs.
for input in self.resources.inputs_unhandled.iter() {
let pos = self
.resources
.inputs
.get_index(*input)
.unwrap_or_else(|| self.resources.inputs.get_index_quant(*input).unwrap());
let input_global = context.tensors.get(input).unwrap();
match plan.vectorizations.get(&input_global.id).unwrap() {
Vect::Aligned(vect) => {
let handle = &mut plan.handle_inputs[pos];
match handle {
HandleInput::Normal(handle) => {
handle.line_size = *vect;
}
HandleInput::QuantValues(handle) => {
handle.line_size = *vect;
}
HandleInput::QuantParams(_) => {}
}
}
Vect::Broadcasted => {}
}
}
for ((tmp, block_plan), block) in block_vectorization
.into_iter()
.zip(plan.blocks.iter_mut())
.zip(self.blocks)
{
match block.settings.vectorization {
VectorizationSetting::Activated => {
apply_vectorization_block(
tmp,
&mut plan.handle_inputs,
&mut plan.handle_outputs,
block_plan,
u8::MAX as usize,
);
}
VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos } => {
apply_vectorization_block(
tmp,
&mut plan.handle_inputs,
&mut plan.handle_outputs,
block_plan,
previous_widths[block_pos],
);
if block_plan.width == 0 {
block_plan.width = previous_widths[block_pos];
}
}
VectorizationSetting::EqualThanPreviousBlock { block_pos } => {
apply_vectorization_block(
tmp,
&mut plan.handle_inputs,
&mut plan.handle_outputs,
block_plan,
previous_widths[block_pos],
);
// Enforces the width.
block_plan.width = previous_widths[block_pos];
}
VectorizationSetting::Deactivated => {
apply_vectorization_block(
tmp,
&mut plan.handle_inputs,
&mut plan.handle_outputs,
block_plan,
1,
);
block_plan.width = 1;
}
}
// When only virtual inputs/outputs are present for a block, we need to set a width.
if block_plan.width == 0 {
if let Some(w) = previous_widths.last() {
block_plan.width = *w;
} else {
block_plan.width = 1;
}
}
previous_widths.push(block_plan.width);
}
}
}
#[derive(Debug)]
/// The handle targeted by a [`BlockVectorization`] candidate.
enum VectorizationAction {
/// Position of the handle in the input handles of the plan.
Input(usize),
/// Position of the handle in the output handles of the plan.
Output(usize),
}
#[derive(Debug)]
/// A vectorization candidate registered for one block.
struct BlockVectorization {
/// Which handle receives the line size.
action: VectorizationAction,
/// The line size found for the tensor, applied only when it doesn't exceed the block maximum.
potential: LineSize,
/// Whether the vectorization found for the tensor is a broadcast (always `false` for outputs).
broadcasted: bool,
}
fn apply_vectorization_block<R: Runtime>(
block_vectorization: Vec<BlockVectorization>,
inputs: &mut [HandleInput<R>],
outputs: &mut [HandleOutput<R>],
block_plan: &mut BlockPlan,
max: LineSize,
) {
for item in block_vectorization {
match item.action {
VectorizationAction::Input(pos) => {
let (vect, br) = if item.potential <= max {
(item.potential, item.broadcasted)
} else {
(1, false)
};
match &mut inputs[pos] {
HandleInput::Normal(input) => {
input.line_size = vect;
input.broadcated = br;
}
HandleInput::QuantValues(input) => {
input.line_size = vect;
}
HandleInput::QuantParams(_) => {
// Not vectorized
}
}
if block_plan.width < vect {
block_plan.width = vect;
}
}
VectorizationAction::Output(pos) => {
if let HandleOutput::Owned { vectorization, .. } = &mut outputs[pos] {
let vect = if item.potential <= max {
item.potential
} else {
1
};
*vectorization = vect;
if block_plan.width < vect {
block_plan.width = vect;
}
}
}
}
}
}
/// Updates the candidate line sizes used when quantized inputs are present.
///
/// The line sizes supported for the storage type of `scheme` are computed, and they replace
/// the content of `quants_line_sizes` when nothing was registered yet or when their first
/// entry is bigger than the first entry currently registered.
///
/// For values packed in `u32`, every line size is multiplied by the number of quantized values
/// of the scheme. When such a list replaces a previous one, it is first extended by repeatedly
/// halving its last entry down to one.
///
/// # Panics
///
/// Panics for sub-byte values with native storage and for the `PackedNative` storage.
fn line_sizes_quants<R: Runtime>(
    client: &ComputeClient<R>,
    quants_line_sizes: &mut Option<Vec<LineSize>>,
    scheme: QuantScheme,
) {
    match scheme.store {
        QuantStore::Native => match scheme.value {
            // Type sizes are the same so just treat fp8/fp4x2 as i8
            QuantValue::Q8F
            | QuantValue::Q8S
            | QuantValue::E4M3
            | QuantValue::E5M2
            | QuantValue::E2M1 => {
                let candidate: Vec<_> = client
                    .io_optimized_line_sizes(size_of::<i8>())
                    .collect();
                let replace = quants_line_sizes
                    .as_ref()
                    .map_or(true, |current| current[0] < candidate[0]);

                if replace {
                    *quants_line_sizes = Some(candidate);
                }
            }
            QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                unreachable!("Can't store native sub-byte values")
            }
        },
        QuantStore::PackedU32(_) => {
            let mut candidate: Vec<_> = client
                .io_optimized_line_sizes(size_of::<u32>())
                .collect();
            candidate
                .iter_mut()
                .for_each(|size| *size *= scheme.num_quants());

            let had_previous = quants_line_sizes.is_some();
            let replace = quants_line_sizes
                .as_ref()
                .map_or(true, |current| current[0] < candidate[0]);

            if replace {
                // Smaller line sizes are only appended when a previous list gets replaced.
                if had_previous {
                    let mut smallest = *candidate.last().unwrap();
                    while smallest > 1 {
                        smallest /= 2;
                        candidate.push(smallest);
                    }
                }
                *quants_line_sizes = Some(candidate);
            }
        }
        QuantStore::PackedNative(_) => {
            panic!("Not yet supported")
        }
    }
}

View File

@@ -0,0 +1,6 @@
pub(crate) mod codegen;
pub(crate) mod fuser;
pub(crate) mod launch;
pub(crate) mod settings;
pub mod trace;

View File

@@ -0,0 +1,59 @@
use serde::{Deserialize, Serialize};
/// Controls which operations can be fused.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct FuseSettings {
/// Enables broadcasting of shapes.
pub broadcast: bool,
/// Enables output shape updates.
///
/// When broadcast is enabled, the output shape can become bigger after a fusion,
/// therefore an update is needed.
pub output_shape_updates: bool,
/// Enables the reuse of input buffers.
pub inplace: bool,
/// How vectorization is selected for the block, see [`VectorizationSetting`].
pub vectorization: VectorizationSetting,
/// How [reference layout](super::ir::RefLayout) selection is done.
pub ref_layout: RefLayoutSetting,
}
/// By default every fusion feature is enabled: broadcasting, output shape updates, inplace
/// reuse, [activated](VectorizationSetting::Activated) vectorization and
/// [any](RefLayoutSetting::Any) reference layout.
impl Default for FuseSettings {
fn default() -> Self {
Self {
broadcast: true,
output_shape_updates: true,
inplace: true,
vectorization: VectorizationSetting::Activated,
ref_layout: RefLayoutSetting::Any,
}
}
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
/// How vectorization is handled during fusion.
pub enum VectorizationSetting {
/// The biggest line_size possible will be used.
Activated,
/// Equivalent to using line_size of one.
Deactivated,
/// The line sizes of the block can't exceed the width selected for the block at `block_pos`.
///
/// This is a good setting when a block processes values calculated from a previous block.
SmallerOrEqualThanPreviousBlock { block_pos: usize },
/// The width of the block is forced to the width selected for the block at `block_pos`.
///
/// This is a good setting when a block processes values calculated from a previous block.
EqualThanPreviousBlock { block_pos: usize },
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
/// Influence how the [reference layout](super::ir::RefLayout) selection is done.
pub enum RefLayoutSetting {
/// Any reference layout is allowed.
Any,
/// Only contiguous reference layout is allowed.
///
/// Note that forcing a contiguous reference layout might reduce the opportunity of inplace
/// fusion.
OnlyContiguous,
// NOTE(review): presumably reuses the reference layout selected for the block at `block_pos`;
// the code consuming this variant is not part of this file, confirm before relying on it.
SameAsBlock {
block_pos: u32,
},
}

View File

@@ -0,0 +1,377 @@
use crate::engine::{
codegen::ir::{FuseArg, FuseType},
trace::block::FuseBlock,
};
use burn_ir::{TensorId, TensorIr};
use burn_std::{Shape, Strides};
use cubecl::prelude::*;
use serde::{Deserialize, Serialize};
use std::{
collections::{BTreeMap, HashSet},
marker::PhantomData,
};
#[cfg(feature = "autotune-checks")]
use crate::CubeFusionHandle;
#[cfg(feature = "autotune-checks")]
use burn_backend::TensorData;
#[cfg(feature = "autotune-checks")]
use std::collections::HashMap;
#[derive(Clone, Serialize, Deserialize, Debug)]
/// A trace contains all [blocks](FuseBlock) and the [resources](FuseResources) used by the
/// kernel.
pub struct FuseTrace {
/// The fused blocks of the trace.
pub blocks: Vec<FuseBlock>,
/// The resources shared between all blocks.
pub resources: FuseResources,
}
/// Human readable listing of a trace: one header per block followed by its read operations,
/// its operations and its write operations, one per line.
impl core::fmt::Display for FuseTrace {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "FuseTrace")?;

        for block in &self.blocks {
            writeln!(f, " - Block shape={:?}", block.shape_ref)?;

            // Every read operation paired with the tensor it reads.
            let reads = block
                .reads
                .iter()
                .flat_map(|(tensor, ops)| ops.iter().map(move |op| (tensor, op)));
            for (tensor, op) in reads {
                writeln!(f, " - {op} <== {tensor}")?;
            }

            for op in &block.ops {
                writeln!(f, " - {op}")?;
            }

            // Every write operation paired with the tensor it writes.
            let writes = block
                .writes
                .iter()
                .flat_map(|(tensor, ops)| ops.iter().map(move |op| (tensor, op)));
            for (tensor, op) in writes {
                writeln!(f, " - {op} <== {tensor}")?;
            }
        }

        Ok(())
    }
}
/// Output of a fused operation executed during autotune.
pub enum TuneOutput<R: Runtime> {
/// No handle is kept, so nothing can be compared.
UnChecked(PhantomData<R>),
/// Handles kept per tensor id, each with the shape used when the outputs are compared.
///
/// Only available with the `autotune-checks` feature.
#[cfg(feature = "autotune-checks")]
Checked {
handles: HashMap<TensorId, (Vec<usize>, CubeFusionHandle<R>)>,
},
}
impl<R: Runtime> TuneOutput<R> {
    /// Merges `other` into `self` and returns the result.
    ///
    /// Handles are only merged when both outputs are checked, in which case the handles of
    /// `other` replace the ones registered under the same tensor id. In every other case `self`
    /// is returned untouched.
    #[allow(unused_variables)]
    pub fn merge(mut self, other: Self) -> Self {
        match &mut self {
            TuneOutput::UnChecked(..) => {}
            #[cfg(feature = "autotune-checks")]
            TuneOutput::Checked { handles } => {
                if let TuneOutput::Checked { handles: incoming } = other {
                    handles.extend(incoming);
                }
            }
        }

        self
    }
}
/// The only method of this implementation is compiled with the `autotune-checks` feature.
impl<R: Runtime> cubecl::tune::AutotuneOutput for TuneOutput<R> {
/// Asserts that the handles of `self` and `other` registered under the same tensor id contain
/// the same data.
///
/// Float tensors are compared with a permissive tolerance, other dtypes must be equal.
/// Nothing is compared unless both outputs are [checked](TuneOutput::Checked).
///
/// # Panics
///
/// Panics when the data differ, or when `self` has handles but none of them could be compared.
#[cfg(feature = "autotune-checks")]
fn check_equivalence(&self, other: Self) {
use burn_backend::Tolerance;
use burn_std::DType;
if let (
TuneOutput::Checked {
handles: handles_ref,
},
TuneOutput::Checked { handles },
) = (self, &other)
{
let mut num_checked = 0;
let mut num_handles = 0;
for (id, (shape, handle)) in handles_ref.iter() {
num_handles += 1;
if let Some((shape_other, other)) = handles.get(id) {
use burn_std::is_contiguous;
use cubecl::std::tensor::into_contiguous_ref;
// Both handles are made contiguous before their data is read.
let current_handle = if !is_contiguous(&shape, &handle.strides) {
into_contiguous_ref::<R>(
&handle.client,
&handle.as_handle_ref(&shape),
handle.dtype.into(),
)
.unwrap()
.handle
} else {
handle.handle.clone()
};
// NOTE(review): the reference `shape` is also used for the other handle here,
// `shape_other` is only used when building `data_other` below.
let other_handle = if !is_contiguous(&shape, &other.strides) {
into_contiguous_ref::<R>(
&other.client,
&other.as_handle_ref(&shape),
other.dtype.into(),
)
.unwrap()
.handle
} else {
other.handle.clone()
};
let data_ref = handle.client.read_one(current_handle);
let data_other = other.client.read_one(other_handle);
let data_ref = TensorData::from_bytes(data_ref, shape.clone(), handle.dtype);
// The dtype of the reference handle is used for both tensors.
let data_other =
TensorData::from_bytes(data_other, shape_other.clone(), handle.dtype);
match handle.dtype {
DType::F64 => {
data_ref.assert_approx_eq::<f64>(&data_other, Tolerance::permissive())
}
DType::F32 => {
data_ref.assert_approx_eq::<f32>(&data_other, Tolerance::permissive())
}
DType::F16 => data_ref
.assert_approx_eq::<half::f16>(&data_other, Tolerance::permissive()),
DType::BF16 => data_ref
.assert_approx_eq::<half::bf16>(&data_other, Tolerance::permissive()),
_ => data_ref.assert_eq(&data_other, true),
}
num_checked += 1;
} else {
// Debug info for the tests.
println!("No tensor found for {id:?}=>{shape:?}");
}
}
// At least one check is needed per output when there is an output.
//
// Some optimizations might write more outputs than needed, so it might be fine if
// the number of handles is different, but at least one is required.
//
// An optimization might not create outputs if its dead code detection is triggered,
// therefore avoiding useless computation.
if num_handles > 0 {
assert!(num_checked >= 1);
}
}
}
}
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
/// Declare all resources used by the kernel, and potentially multiple [blocks](FuseBlock).
///
/// # Notes
///
/// Each block can't contain their own resources, since they are shared between blocks. The
/// vectorization factor of one input tensor must be the same for all blocks.
pub struct FuseResources {
/// The output tensors registered by all blocks.
pub outputs: RegisteredTensors,
/// The input tensors registered by all blocks, quantized tensors included.
pub inputs: RegisteredTensors,
// NOTE(review): presumably one `(type, id)` entry per scalar used by the kernel; the code
// filling this field is not part of this file, confirm against the fuser.
pub scalars: Vec<(FuseType, u64)>,
// TODO: Making put a map of global registers.
/// The reshape and swap dims views registered on input tensors.
pub views: Vec<TensorView>,
/// Tensors accessed by index; block builders refuse to register them as inputs or outputs.
pub indexed: BTreeMap<TensorId, FuseArg>,
/// Inputs that might not be included in any fused block.
pub inputs_unhandled: Vec<TensorId>,
// NOTE(review): presumably the counterpart of `inputs_unhandled` for outputs — confirm.
pub outputs_unhandled: Vec<FuseArg>,
/// Number of reshape views registered so far, used as the position of the next one.
pub num_reshaped: usize,
/// Necessary to remove some entries from the context.
pub dropped: HashSet<TensorId>,
/// We know during fusion that we have to have those buffers as global.
/// The pos here can be interpreted as GLOBAL pos where the output pos are locals.
pub buffers: RegisteredTensors,
/// Global registers available everywhere.
///
/// TODO: Not all registers should be globals.
pub registers: BTreeMap<TensorId, FuseArg>,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
/// A shape together with its strides.
pub struct RuntimeLayout {
/// The shape of the layout.
pub shape: Shape,
/// The strides of the layout.
pub strides: Strides,
}
/// The default layout has an empty shape and empty strides.
impl Default for RuntimeLayout {
fn default() -> Self {
Self {
shape: Shape::new([]),
strides: Strides::new(&[]),
}
}
}
#[derive(Debug)]
/// Errors reported for a trace, generic over the error type of the runner.
pub enum TraceError<Err> {
// NOTE(review): presumably returned when no reference layout can be selected for a block;
// the code producing it is not part of this file — confirm.
ReferenceNotFound,
/// Wraps an error coming from the runner.
RunnerError(Err),
}
#[derive(Clone, Serialize, Deserialize, Debug)]
/// A view registered on an input tensor that is read through a reshape or a swap dims.
pub enum TensorView {
/// The input is read as if it was reshaped.
Reshape {
/// The tensor resulting from the reshape.
reshaped: TensorId,
/// The input tensor that is reshaped.
original: TensorId,
/// The position of this reshape among all registered reshapes.
reshape_pos: usize,
/// The shape of the reshaped tensor when the view was registered.
shape_relative: Shape,
},
/// The input is read with two of its dims swapped.
SwapDims {
/// The tensor resulting from the swap.
swapped: TensorId,
/// The input tensor whose dims are swapped.
original: TensorId,
/// The two dims that are swapped.
dims: (usize, usize),
},
}
#[derive(Default, Clone, Serialize, Deserialize, Debug)]
/// An ordered list of registered tensors where the position of an entry identifies it.
pub struct RegisteredTensors {
/// Entries in registration order; a quantized tensor takes two consecutive entries.
tensors: Vec<RegisterTensor>,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
/// An entry of [`RegisteredTensors`].
pub enum RegisterTensor {
/// A non-quantized tensor with the precision it is registered with.
Normal(TensorIr, FuseType),
/// The values of a quantized tensor.
QuantValues(TensorIr),
/// The quantization parameters of the quantized tensor with the given id.
QuantParams(TensorId),
}
impl RegisterTensor {
pub fn as_normal_tensor(&self) -> Option<(&TensorIr, &FuseType)> {
match self {
RegisterTensor::Normal(tensor_ir, precision) => Some((tensor_ir, precision)),
RegisterTensor::QuantValues(_) => None,
RegisterTensor::QuantParams(_) => None,
}
}
}
impl RegisteredTensors {
/// Iterate over all the registered tensors.
pub fn iter(&self) -> impl Iterator<Item = &RegisterTensor> {
self.tensors.iter()
}
/// Consumes and iterate over all the registered tensors.
pub fn into_iter(self) -> impl Iterator<Item = RegisterTensor> {
self.tensors.into_iter()
}
/// Returns the number of tensors registered.
pub fn len(&self) -> usize {
self.tensors.len()
}
/// Retrieve the [tensor id](TensorId) at the given index.
pub fn get_id(&self, index: usize) -> Option<TensorId> {
self.tensors.get(index).map(|entry| match entry {
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id,
RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id,
RegisterTensor::QuantParams(tensor_id) => *tensor_id,
})
}
/// Doesn't return quantized tensor.
pub fn get_index(&self, tensor_id: TensorId) -> Option<usize> {
self.tensors
.iter()
.enumerate()
.find(|(_pos, entry)| match entry {
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,
RegisterTensor::QuantValues(_) => false,
RegisterTensor::QuantParams(_) => false,
})
.map(|(pos, _)| pos)
}
/// Get the index of a quantized tensor.
pub fn get_index_quant(&self, tensor_id: TensorId) -> Option<usize> {
self.tensors
.iter()
.enumerate()
.find(|(_pos, entry)| match entry {
RegisterTensor::Normal(..) => false,
RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id == tensor_id,
RegisterTensor::QuantParams(_) => false,
})
.map(|(pos, _)| pos)
}
/// Doesn't return quantized tensor.
pub fn get(&self, tensor_id: TensorId) -> Option<(&TensorIr, &FuseType)> {
self.tensors
.iter()
.find(|entry| match entry {
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,
RegisterTensor::QuantValues(_) => false,
RegisterTensor::QuantParams(_) => false,
})
.and_then(|entry| match entry {
RegisterTensor::Normal(tensor_ir, fuse_precision) => {
Some((tensor_ir, fuse_precision))
}
RegisterTensor::QuantValues(_) => None,
RegisterTensor::QuantParams(_) => None,
})
}
/// Insert a quantized tensor.
///
/// It will return the positions for both the value tensor and param tensor.
pub fn insert_quant(&mut self, tensor: TensorIr) -> (usize, usize) {
if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {
RegisterTensor::QuantValues(tensor_ir) => tensor_ir == &tensor,
_ => false,
}) {
let values = old.0;
let params = values + 1;
return (values, params);
}
let params = RegisterTensor::QuantParams(tensor.id);
let values = RegisterTensor::QuantValues(tensor);
let pos_values = self.len();
self.tensors.push(values);
let pos_params = self.len();
self.tensors.push(params);
(pos_values, pos_params)
}
/// Insert a normal tensor with the given [precision](FusePrecision) in the current block.
pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> usize {
if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id,
_ => false,
}) {
return old.0;
}
let value = RegisterTensor::Normal(tensor, precision);
let pos = self.len();
self.tensors.push(value);
pos
}
/// Update the already registered tensor with the given [tensor ir](TensorIr).
///
/// # Notes
///
/// This function only works with normal tensors, not quantized tensors.
pub fn update(&mut self, tensor: &TensorIr) {
if let Some(entry) = self.tensors.iter_mut().find(|entry| match entry {
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id,
_ => false,
}) && let RegisterTensor::Normal(tensor_ir, _) = entry
{
tensor_ir.status = tensor.status
}
}
}

View File

@@ -0,0 +1,555 @@
use super::{FuseResources, RegisteredTensors, TensorView};
use crate::engine::{
codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo, MultiBlockPos, UnaryFuseArgs},
settings::FuseSettings,
};
use burn_ir::{TensorId, TensorIr, TensorStatus};
use burn_std::{DType, Shape, quantization::QuantParam};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, btree_map::Entry};
#[derive(Clone, Serialize, Deserialize, Debug)]
/// A block containing all [operations](FuseOp) as well as reads and writes for each tensor along
/// with the [fusion settings](FuseSettings).
///
/// Blocks are created with [`FuseBlockBuilder::build`].
pub struct FuseBlock {
/// Contains the [fusion settings](FuseSettings) associated to the current block.
pub settings: FuseSettings,
/// Contains all the [operations](FuseOp) registered in the current block.
pub ops: Vec<FuseOp>,
/// The reference shape of the current block.
pub shape_ref: Shape,
/// Contains all tensor inputs of the current block except for manually handled tensors.
///
/// # Notes
///
/// Some reads might not have read operations registered, such as dequantization, but it's
/// important to be registered here for vectorization. Input tensors that are not
/// registered here must be vectorized manually.
pub reads: BTreeMap<TensorId, Vec<FuseOp>>,
/// Contains all tensor outputs of the current block except for manually handled tensors.
/// We can have multiple writes when the same variable is reused afterwards in another block.
pub writes: BTreeMap<TensorId, Vec<FuseOp>>,
}
#[derive(Clone, Debug)]
/// It is responsible to build a [block](FuseBlock).
pub struct FuseBlockBuilder {
/// The settings copied into the built block.
pub settings: FuseSettings,
/// The local variables created for the tensors of this block.
locals: LocalVariablePool,
/// The operations registered in this block.
pub ops: Vec<FuseOp>,
/// The read operations registered for each input tensor.
reads: BTreeMap<TensorId, Vec<FuseOp>>,
// Only for global registers.
writes: BTreeMap<TensorId, Vec<FuseOp>>,
/// The precision used to encode bool tensors when they are read or written.
bool_precision: FuseType,
// Output declared in this block alone.
outputs: RegisteredTensors,
// NOTE(review): not used by the builder methods of this file, presumably filled by the
// fuser for outputs that are handled manually — confirm.
pub outputs_unhandled: Vec<FuseArg>,
/// Arguments returned as is when the matching tensor is requested as an input.
pub local_inputs: BTreeMap<TensorId, FuseArg>,
/// The reference shape used by this block.
pub shape_ref: Shape,
}
#[derive(Debug)]
/// How a quantized input can be read.
pub enum QuantInput {
/// If already dequantized, the dequantization is cached and we return the local variable
/// corresponding to the float value.
AlreadyDequantized { local: FuseArg },
/// Otherwise we return the information necessary to dequantize the tensor: the argument
/// for the quantized values and the argument for the quantization parameters.
Quantized { values: FuseArg, params: FuseArg },
}
impl FuseBlockBuilder {
    /// Creates an empty builder with the given bool precision and [settings](FuseSettings).
    pub fn new(bool_precision: FuseType, settings: FuseSettings) -> Self {
        Self {
            bool_precision,
            settings,
            locals: LocalVariablePool::default(),
            ops: Vec::new(),
            reads: BTreeMap::new(),
            writes: BTreeMap::new(),
            outputs: RegisteredTensors::default(),
            outputs_unhandled: Vec::new(),
            local_inputs: BTreeMap::new(),
            shape_ref: Shape::new([]),
        }
    }

    /// Register an output tensor and return the local variable that holds its value.
    ///
    /// The tensor is only registered in the block and in the resources the first time it is
    /// seen. Returns `None` for indexed and quantized tensors.
    pub fn output(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option<FuseArg> {
        let unsupported = resources.indexed.contains_key(&tensor.id)
            || matches!(tensor.dtype, DType::QFloat(..));
        if unsupported {
            return None;
        }

        let precision: FuseType = tensor.dtype.into();
        if let Some(local) = self.locals.get(precision, tensor.id) {
            return Some(local);
        }

        let precision_output = self.encoded_precision(precision);
        let local = self.locals.create(precision, tensor.id);
        self.outputs.insert(precision_output, tensor.clone());
        resources.outputs.insert(precision_output, tensor.clone());

        Some(local)
    }

    /// Register a variable of this block that is shared with other blocks.
    ///
    /// Returns the argument cached in the local inputs when there is one. Otherwise the local
    /// variable of the tensor is assigned to a new multi-block variable, global or local
    /// depending on `global`, and that variable is returned. Returns `None` when the tensor
    /// has no local variable in this block.
    pub fn multi_block_variable(
        &mut self,
        block_pos: usize,
        tensor: &TensorIr,
        global: bool,
    ) -> Option<FuseArg> {
        let precision: FuseType = tensor.dtype.into();

        if let Some(cached) = self.local_inputs.get(&tensor.id) {
            return Some(cached.clone());
        }
        let local = self.locals.get(precision, tensor.id)?;

        // The position is taken before the write is registered.
        let pos = MultiBlockPos {
            block_pos,
            block_local_pos: self.writes.len(),
        };
        let arg = if global {
            FuseArg::MultiBlockGlobal(pos, local.precision())
        } else {
            FuseArg::MultiBlockLocal(pos, local.precision())
        };

        self.writes
            .entry(tensor.id)
            .or_default()
            .push(FuseOp::Assign(UnaryFuseArgs {
                input: local,
                out: arg.clone(),
            }));

        Some(arg)
    }

    /// Register an input tensor and return the local variable that holds its value.
    ///
    /// A tensor that is an output of a previous block is read from its global register when
    /// there is one, otherwise from a buffer. Returns `None` for indexed and quantized tensors.
    pub fn input(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option<FuseArg> {
        let unsupported = resources.indexed.contains_key(&tensor.id)
            || matches!(tensor.dtype, DType::QFloat(..));
        if unsupported {
            return None;
        }

        let precision: FuseType = tensor.dtype.into();
        let precision_input = self.encoded_precision(precision);

        if let Some(cached) = self.local_inputs.get(&tensor.id) {
            return Some(cached.clone());
        }
        if let Some(local) = self.locals.get(precision, tensor.id) {
            resources.inputs.update(tensor);
            return Some(local);
        }

        let source = if resources.outputs.get_index(tensor.id).is_some() {
            if let Some(register) = resources.registers.get(&tensor.id) {
                return Some(register.clone());
            }
            let pos = resources.buffers.insert(precision, tensor.clone());
            FuseArg::Output(pos, precision_input, LayoutInfo::Unknown)
        } else {
            let pos = resources.inputs.insert(precision_input, tensor.clone());
            FuseArg::Input(pos, precision_input, LayoutInfo::Unknown)
        };

        let local = self.locals.create(precision, tensor.id);
        self.record_read(tensor.id, source, local.clone());

        Some(local)
    }

    /// Register an input quantized tensor.
    ///
    /// Returns `None` for indexed tensors and for tensors that aren't quantized.
    ///
    /// # Panics
    ///
    /// Panics when the quantization parameters are stored as `UE8M0` or `UE4M3`.
    pub fn input_quant(
        &mut self,
        tensor: &TensorIr,
        resources: &mut FuseResources,
    ) -> Option<QuantInput> {
        if resources.indexed.contains_key(&tensor.id) {
            return None;
        }

        let precision: FuseType = tensor.dtype.into();
        let DType::QFloat(scheme) = tensor.dtype else {
            return None;
        };
        let precision_scales = match scheme.param {
            QuantParam::F32 => FuseType::F32,
            QuantParam::F16 => FuseType::F16,
            QuantParam::BF16 => FuseType::BF16,
            QuantParam::UE8M0 | QuantParam::UE4M3 => {
                unimplemented!("Unsupported fuse precision");
            }
        };

        if let Some(local) = self.locals.get(precision, tensor.id) {
            resources.inputs.update(tensor);
            return Some(QuantInput::AlreadyDequantized { local });
        }

        let (values_pos, params_pos) = resources.inputs.insert_quant(tensor.clone());

        // Important to flag that there is a read, even if no operation is registered.
        if let Entry::Vacant(slot) = self.reads.entry(tensor.id) {
            slot.insert(Vec::new());
        }

        Some(QuantInput::Quantized {
            values: FuseArg::Input(values_pos, precision, LayoutInfo::Unknown),
            params: FuseArg::Input(params_pos, precision_scales, LayoutInfo::Unknown),
        })
    }

    /// Register an input with swapped dims.
    ///
    /// `tensor` is the input and `output` the tensor resulting from the swap. Returns `None`
    /// when the input is quantized, when it can't be read as a plain input anymore, or when
    /// `output` can't be registered.
    pub fn input_swap_dims(
        &mut self,
        tensor: &TensorIr,
        output: &TensorIr,
        dims: (usize, usize),
        resources: &mut FuseResources,
    ) -> Option<FuseArg> {
        if matches!(tensor.dtype, DType::QFloat(..)) {
            return None;
        }

        let precision: FuseType = tensor.dtype.into();
        let precision_input = self.encoded_precision(precision);
        let input_index = self.view_source_index(precision, precision_input, tensor, resources)?;
        let out = self.output(output, resources)?;

        let broadcasted = output.shape[output.shape.rank() - 1] == 0;
        resources.views.push(TensorView::SwapDims {
            swapped: output.id,
            original: tensor.id,
            dims,
        });

        let original = FuseArg::Input(input_index, precision_input, LayoutInfo::Unknown);
        let view = FuseArg::InputSwapDims {
            original: Box::new(original),
            dims,
            broadcasted,
        };
        self.record_read(tensor.id, view, out.clone());

        Some(out)
    }

    /// Register an input that is reshaped.
    ///
    /// `tensor` is the input and `output` the tensor resulting from the reshape. Returns `None`
    /// when the input is quantized, when it can't be read as a plain input anymore, or when
    /// `output` can't be registered.
    pub fn input_reshaped(
        &mut self,
        tensor: &TensorIr,
        output: &TensorIr,
        resources: &mut FuseResources,
    ) -> Option<FuseArg> {
        if matches!(tensor.dtype, DType::QFloat(..)) {
            return None;
        }

        let precision: FuseType = tensor.dtype.into();
        let precision_input = self.encoded_precision(precision);
        let input_index = self.view_source_index(precision, precision_input, tensor, resources)?;
        let out = self.output(output, resources)?;

        let reshape_pos = resources.num_reshaped;
        resources.num_reshaped += 1;

        // Every dim of the reshaped tensor is read from its own scalar.
        let rank = output.shape.rank();
        let shape = (0..rank)
            .map(|dim| FuseArg::ScalarShape(reshape_pos * rank + dim))
            .collect::<Vec<_>>();

        resources.views.push(TensorView::Reshape {
            reshaped: output.id,
            original: tensor.id,
            reshape_pos,
            shape_relative: output.shape.clone(),
        });

        let original = FuseArg::Input(input_index, precision_input, LayoutInfo::Unknown);
        let view = FuseArg::InputReshaped {
            original: Box::new(original),
            shape,
            broadcasted: output.shape[rank - 1] == 0,
        };
        self.record_read(tensor.id, view, out.clone());

        Some(out)
    }

    /// Build into a fuse block.
    ///
    /// Every tensor returned by [`tensor_writes`](Self::tensor_writes) that has a local variable
    /// in this block is registered in `outputs` and gets a write operation in the built block.
    pub fn build(
        &self,
        resources: &FuseResources,
        outputs: &mut RegisteredTensors,
        buffers: &mut Vec<TensorId>,
    ) -> FuseBlock {
        let tensor_writes = self.tensor_writes(resources, buffers);
        let mut writes = self.writes.clone();

        let written = tensor_writes
            .iter()
            .filter_map(|entry| entry.as_normal_tensor());
        for (tensor, precision) in written {
            let Some(local) = self.locals.get_any_precision(tensor.id) else {
                continue;
            };
            let out_index = outputs.insert(*precision, tensor.clone());

            writes
                .entry(tensor.id)
                .or_default()
                .push(FuseOp::Assign(UnaryFuseArgs {
                    input: local,
                    out: FuseArg::Output(out_index, *precision, LayoutInfo::Unknown),
                }));
        }

        FuseBlock {
            settings: self.settings,
            ops: self.ops.clone(),
            shape_ref: self.shape_ref.clone(),
            reads: self.reads.clone(),
            writes,
        }
    }

    /// Return the tensors that need to be written to.
    ///
    /// # Notes
    ///
    /// The buffers vector passed as input is only to track the intermediary buffer writes needed
    /// during execution.
    pub fn tensor_writes(
        &self,
        resources: &FuseResources,
        buffers: &mut Vec<TensorId>,
    ) -> RegisteredTensors {
        let mut result = RegisteredTensors::default();

        // We get the latest representation from the resources, not just this block.
        let latest = self
            .outputs
            .iter()
            .filter_map(|output| output.as_normal_tensor())
            .filter_map(|(tensor, _precision)| resources.outputs.get(tensor.id));

        for (tensor, precision) in latest {
            if matches!(tensor.status, TensorStatus::ReadWrite) {
                // A read write tensor is only written when it is a buffer, and we make sure we
                // don't write multiple times in the same buffer, only the earliest possible.
                let first_buffer_write = resources.buffers.get(tensor.id).is_some()
                    && !buffers.contains(&tensor.id);
                if !first_buffer_write {
                    continue;
                }
                buffers.push(tensor.id);
            }

            // Tensors whose latest representation isn't read write are going to be used after
            // the fused kernel by other operations, so they are always written.
            result.insert(*precision, tensor.clone());
        }

        result
    }

    /// Bool tensors are encoded as `bool_precision`, every other precision is kept.
    fn encoded_precision(&self, precision: FuseType) -> FuseType {
        match precision {
            FuseType::Bool => self.bool_precision,
            _ => precision,
        }
    }

    /// Position in the inputs of a tensor that is read through a view.
    ///
    /// A tensor without local variable in this block is registered as a new input. Otherwise
    /// it must already be an input and not an output of the resources.
    fn view_source_index(
        &self,
        precision: FuseType,
        precision_input: FuseType,
        tensor: &TensorIr,
        resources: &mut FuseResources,
    ) -> Option<usize> {
        if self.locals.get(precision, tensor.id).is_none() {
            return Some(resources.inputs.insert(precision_input, tensor.clone()));
        }

        // Can't fuse an already fused input.
        if resources.outputs.get(tensor.id).is_some() {
            return None;
        }

        let index = resources.inputs.get_index(tensor.id)?;
        resources.inputs.update(tensor);
        Some(index)
    }

    /// Registers the read operation assigning `input` to `out` for the given tensor.
    fn record_read(&mut self, tensor_id: TensorId, input: FuseArg, out: FuseArg) {
        self.reads
            .entry(tensor_id)
            .or_insert_with(|| Vec::with_capacity(1))
            .push(FuseOp::Assign(UnaryFuseArgs { input, out }));
    }
}
#[derive(Default, Clone, Debug)]
/// Assigns positions to the local variables of a block.
///
/// Positions are tracked per precision: each precision has its own sequence starting at zero.
pub struct LocalVariablePool {
    /// For each precision, the position assigned to each tensor.
    values: BTreeMap<FuseType, BTreeMap<TensorId, usize>>,
}

impl LocalVariablePool {
    /// Returns the local variable of the tensor for the given precision, if one was created.
    fn get(&self, precision: FuseType, tensor_id: TensorId) -> Option<FuseArg> {
        let pos = *self.values.get(&precision)?.get(&tensor_id)?;

        Some(FuseArg::BlockLocal { pos, ty: precision })
    }

    /// Returns the local variable of the tensor for the first precision, in precision order,
    /// where one was created.
    fn get_any_precision(&self, tensor_id: TensorId) -> Option<FuseArg> {
        self.values.iter().find_map(|(precision, indexes)| {
            indexes.get(&tensor_id).map(|pos| FuseArg::BlockLocal {
                pos: *pos,
                ty: *precision,
            })
        })
    }

    /// Returns the local variable of the tensor for the given precision, creating it at the
    /// next free position when it doesn't exist yet.
    ///
    /// Calling this function again for the same tensor and precision returns the same
    /// variable, so a position is never handed out to two different tensors.
    fn create(&mut self, precision: FuseType, tensor_id: TensorId) -> FuseArg {
        let indexes = self.values.entry(precision).or_default();
        // Positions are dense, so the number of variables is the next free position.
        let next = indexes.len();
        let pos = *indexes.entry(tensor_id).or_insert(next);

        FuseArg::BlockLocal { pos, ty: precision }
    }
}

View File

@@ -0,0 +1,330 @@
use super::{
super::{
codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo},
settings::FuseSettings,
},
FuseResources,
block::FuseBlockBuilder,
};
use super::{FuseTrace, RegisteredTensors};
use crate::engine::trace::block::QuantInput;
use burn_fusion::stream::ScalarId;
use burn_ir::{ScalarIr, TensorIr};
use burn_std::{DType, Shape};
use cubecl::quant::scheme::QuantParam;
/// Builds a [trace](FuseTrace) composed of one or more [blocks](super::block::FuseBlock).
///
/// The fuser mostly keeps track of the [resources](FuseResources) needed by the generated fused
/// kernel and delegates most of the work to the [block builder](FuseBlockBuilder).
#[derive(Clone, Debug)]
pub struct TraceFuser {
    settings: FuseSettings,
    pub bool_precision: FuseType,
    // Builder of the block that is currently being fused.
    block_current: FuseBlockBuilder,
    // Builders of the blocks that are already closed, in the order they were closed.
    blocks_previous: Vec<FuseBlockBuilder>,
    resources: FuseResources,
}

impl TraceFuser {
    /// Create a new trace builder with the given bool precision and [fuse settings](FuseSettings).
    pub fn new(bool_precision: FuseType, settings: FuseSettings) -> Self {
        let block_current = FuseBlockBuilder::new(bool_precision, settings);

        Self {
            settings,
            bool_precision,
            block_current,
            blocks_previous: Vec::new(),
            resources: FuseResources::default(),
        }
    }

    /// Number of blocks that are already closed.
    pub fn num_previous_blocks(&self) -> usize {
        self.blocks_previous.len()
    }

    /// Tag a tensor as dropped.
    pub fn fuse_dropped(&mut self, tensor: &TensorIr) {
        let resources = &mut self.resources;
        resources.outputs.update(tensor);
        resources.inputs.update(tensor);
        resources.dropped.insert(tensor.id);
    }

    /// Register an operation in the current block.
    pub fn fuse_operation(&mut self, op: FuseOp) {
        self.block_current.ops.push(op);
    }

    /// Total number of operations fused, closed blocks and current block included.
    pub fn num_ops_fused(&self) -> u32 {
        let closed: usize = self
            .blocks_previous
            .iter()
            .map(|block| block.ops.len())
            .sum();

        (closed + self.block_current.ops.len()) as u32
    }

    /// Close the current block with the given reference shape and start a new one with new
    /// [fusion settings](FuseSettings).
    pub fn next_block(&mut self, shape_ref: Shape, settings: FuseSettings) {
        let fresh = FuseBlockBuilder::new(self.bool_precision, settings);
        let mut closed = core::mem::replace(&mut self.block_current, fresh);
        closed.shape_ref = shape_ref;

        self.blocks_previous.push(closed);
        self.settings = settings;
    }

    /// Estimate how many bindings are in use right now.
    ///
    /// This can return more than the actual number but should never return less.
    pub fn estimate_bindings(&self) -> u32 {
        let mut buffers = Vec::new();

        // We assume we are not going to write multiple times in the same output buffer.
        let writes: u32 = self
            .blocks_previous
            .iter()
            .chain(core::iter::once(&self.block_current))
            .map(|block| block.tensor_writes(&self.resources, &mut buffers).len() as u32)
            .sum();
        let inputs = self.resources.inputs.len() as u32;
        // One buffer per scalar type for now.
        let scalars = self.resources.scalars.len() as u32;

        // Metadata takes one.
        1 + writes + inputs + scalars
    }

    /// Tag the [tensor](TensorIr) as received from a previous block.
    ///
    /// This avoids reading the input again: the local version is used instead when possible.
    pub fn block_local_input(
        &mut self,
        tensor: &TensorIr,
        block_pos: usize,
        global: bool,
    ) -> FuseArg {
        let block = &mut self.blocks_previous[block_pos];

        let mut lookup = block.multi_block_variable(block_pos, tensor, global);
        if lookup.is_none() {
            // We try to read the input if not present.
            block.input(tensor, &mut self.resources);
            lookup = block.multi_block_variable(block_pos, tensor, global);
        }
        let src_arg = lookup.unwrap();

        self.resources.outputs.update(tensor);
        if global {
            self.resources.registers.insert(tensor.id, src_arg.clone());
        }
        self.block_current
            .local_inputs
            .insert(tensor.id, src_arg.clone());

        src_arg
    }

    /// Register an output tensor that won't be automatically synced into global memory.
    ///
    /// It is therefore the responsibility of the operation to write the result to the given
    /// tensor.
    ///
    /// # Panics
    ///
    /// Panics when [`Self::output`] refuses to register the tensor.
    pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
        let arg = self
            .output(tensor)
            .expect("Can't add a new output that is already used in an index operation");

        self.resources.outputs_unhandled.push(arg.clone());
        self.block_current.outputs_unhandled.push(arg.clone());

        arg
    }

    /// Register an input tensor that won't be automatically read into a local variable.
    ///
    /// It is therefore the responsibility of the operation to read the given tensor.
    ///
    /// # Panics
    ///
    /// Panics when the tensor is already used in an index operation.
    pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
        assert!(
            !self.resources.indexed.contains_key(&tensor.id),
            "Can't add a new input that is already used in an index operation"
        );
        self.resources.outputs.update(tensor);

        let precision: FuseType = tensor.dtype.into();
        // Bool tensors are encoded as bool_precision.
        let precision_input = if matches!(precision, FuseType::Bool) {
            self.bool_precision
        } else {
            precision
        };

        let position = self
            .resources
            .inputs
            .insert(precision_input, tensor.clone());
        self.resources.inputs_unhandled.push(tensor.id);

        FuseArg::Input(position, precision_input, LayoutInfo::Unknown)
    }

    /// Register a quantized input tensor that won't be automatically read into a local variable.
    ///
    /// Returns the `(data, scales)` arguments, or `None` when the tensor is not quantized.
    ///
    /// # Panics
    ///
    /// Panics when the tensor is already used in an index operation.
    pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> {
        assert!(
            !self.resources.indexed.contains_key(&tensor.id),
            "Can't add a new input that is already used in an index operation"
        );
        self.resources.outputs.update(tensor);

        let precision = tensor.dtype.into();
        let DType::QFloat(scheme) = tensor.dtype else {
            return None;
        };
        let precision_scales = match scheme.param {
            QuantParam::F32 => FuseType::F32,
            QuantParam::F16 => FuseType::F16,
            QuantParam::BF16 => FuseType::BF16,
            QuantParam::UE8M0 | QuantParam::UE4M3 => {
                unimplemented!("Unsupported fuse precision");
            }
        };

        let (data_index, scales_index) = self.resources.inputs.insert_quant(tensor.clone());
        self.resources.inputs_unhandled.push(tensor.id);

        let data = FuseArg::Input(data_index, precision, LayoutInfo::Unknown);
        let scales = FuseArg::Input(scales_index, precision_scales, LayoutInfo::Unknown);
        Some((data, scales))
    }

    /// Register an input tensor in the current block.
    ///
    /// Quantized tensors are rejected with `None`.
    pub fn input(&mut self, tensor: &TensorIr) -> Option<FuseArg> {
        if let DType::QFloat(_) = tensor.dtype {
            return None;
        }

        self.resources.outputs.update(tensor);
        self.block_current.input(tensor, &mut self.resources)
    }

    /// Register a quantized input tensor in the current block.
    pub fn input_quantized(&mut self, tensor: &TensorIr) -> Option<QuantInput> {
        self.resources.outputs.update(tensor);
        self.block_current.input_quant(tensor, &mut self.resources)
    }

    /// Register an output tensor in the current block.
    ///
    /// Quantized tensors are rejected with `None`.
    pub fn output(&mut self, tensor: &TensorIr) -> Option<FuseArg> {
        if let DType::QFloat(_) = tensor.dtype {
            return None;
        }

        self.block_current.output(tensor, &mut self.resources)
    }

    /// Register an input that will be accessed using custom indexing with no vectorization.
    ///
    /// Returns `None` for quantized tensors and for tensors already registered as a regular
    /// input or output.
    pub fn input_indexed(&mut self, tensor: &TensorIr) -> Option<FuseArg> {
        if let DType::QFloat(_) = tensor.dtype {
            return None;
        }

        // Already registered as an indexed input: reuse the same argument.
        let known = self.resources.indexed.get(&tensor.id).cloned();
        if known.is_some() {
            self.resources.outputs.update(tensor);
            return known;
        }

        let already_registered = self.resources.inputs.get(tensor.id).is_some()
            || self.resources.outputs.get(tensor.id).is_some();
        if already_registered {
            return None;
        }

        let input = self.input_unhandled(tensor);
        self.resources.indexed.insert(tensor.id, input.clone());
        Some(input)
    }

    /// Register an input with swapped dims.
    ///
    /// Quantized tensors are rejected with `None`.
    pub fn input_swap_dims(
        &mut self,
        tensor: &TensorIr,
        output: &TensorIr,
        dims: (usize, usize),
    ) -> Option<FuseArg> {
        if let DType::QFloat(_) = tensor.dtype {
            return None;
        }

        self.resources.outputs.update(tensor);
        self.block_current
            .input_swap_dims(tensor, output, dims, &mut self.resources)
    }

    /// Register an input that is reshaped.
    ///
    /// Quantized tensors are rejected with `None`.
    pub fn input_reshaped(&mut self, tensor: &TensorIr, output: &TensorIr) -> Option<FuseArg> {
        if let DType::QFloat(_) = tensor.dtype {
            return None;
        }

        self.resources.outputs.update(tensor);
        self.block_current
            .input_reshaped(tensor, output, &mut self.resources)
    }

    /// Register a scalar value.
    pub fn scalar(&mut self, elem: &ScalarIr, dtype: DType) -> FuseArg {
        let precision: FuseType = dtype.into();

        // Should always be u64.
        let ScalarIr::UInt(value) = elem else {
            unreachable!()
        };
        let id = ScalarId { value: *value };

        // Bool scalars are encoded as bool_precision.
        let precision = if matches!(precision, FuseType::Bool) {
            self.bool_precision
        } else {
            precision
        };

        let position = self.resources.scalars.len();
        self.resources.scalars.push((precision, id.value));

        FuseArg::Scalar(position, precision)
    }

    /// Finish fusing and return the created trace.
    ///
    /// `shape_ref` becomes the reference shape of the current (last) block.
    pub fn finish(&mut self, shape_ref: Shape) -> FuseTrace {
        let mut resources = self.resources.clone();
        let mut outputs = RegisteredTensors::default();
        let mut buffers = Vec::new();

        for entry in resources.buffers.iter() {
            let (tensor, ty) = entry.as_normal_tensor().unwrap();
            outputs.insert(*ty, tensor.clone());
        }

        self.block_current.shape_ref = shape_ref;

        // Closed blocks are built first, in order, then the current one.
        let blocks = self
            .blocks_previous
            .iter()
            .chain(core::iter::once(&self.block_current))
            .map(|block| block.build(&self.resources, &mut outputs, &mut buffers))
            .collect::<Vec<_>>();

        // We update the output tensors registered to be the ones that are written to in global
        // memory.
        resources.outputs = outputs;

        FuseTrace { blocks, resources }
    }
}

View File

@@ -0,0 +1,7 @@
pub(crate) mod block;
mod base;
mod fuser;
pub use base::*;
pub use fuser::*;

View File

@@ -0,0 +1,11 @@
#[macro_use]
extern crate derive_new;
pub mod optim;
mod base;
pub(crate) mod engine;
pub(crate) mod tune;
pub use base::*;

View File

@@ -0,0 +1,63 @@
use crate::optim::{
elemwise::{ElemwiseOptimization, ElemwiseOptimizationState},
matmul::{MatmulOptimization, MatmulOptimizationState},
reduce::{ReduceOptimization, ReduceOptimizationState},
reduce_broadcasted::{ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationState},
};
use cubecl::Runtime;
use serde::{Deserialize, Serialize};
/// Fusion optimization type for cubecl.
///
/// More optimization variants should be added here.
#[allow(clippy::large_enum_variant)]
pub enum CubeOptimization<R: Runtime> {
ElementWise(ElemwiseOptimization<R>),
Matmul(MatmulOptimization<R>),
Reduce(ReduceOptimization<R>),
ReduceBroadcasted(ReduceBroadcastedOptimization<R>),
}
impl<R: Runtime> core::fmt::Debug for CubeOptimization<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = self.to_opt_state();
f.write_fmt(format_args!("{value:?}"))
}
}
impl<R: Runtime> CubeOptimization<R> {
/// Serializes the current optimization to its state.
pub fn to_opt_state(&self) -> CubeOptimizationState {
match self {
Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()),
Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()),
Self::Reduce(value) => CubeOptimizationState::Reduce(value.to_state()),
Self::ReduceBroadcasted(value) => {
CubeOptimizationState::ReduceBroadcasted(value.to_state())
}
}
}
}
impl<R: Runtime> burn_fusion::NumOperations for CubeOptimization<R> {
fn len(&self) -> usize {
match self {
Self::ElementWise(op) => op.num_ops_fused(),
Self::Matmul(op) => op.num_ops_fused(),
Self::Reduce(op) => op.num_ops_fused(),
Self::ReduceBroadcasted(op) => op.num_ops_fused(),
}
}
}
/// Fusion optimization state type for cubecl.
///
/// Serializable counterpart of `CubeOptimization`, with one variant per optimization kind; it is
/// what `CubeOptimization::to_opt_state` returns.
///
/// More optimization variants should be added here.
#[allow(clippy::large_enum_variant)]
#[derive(Serialize, Deserialize, Debug)]
pub enum CubeOptimizationState {
/// State of an element wise optimization.
ElementWise(ElemwiseOptimizationState),
/// State of a matmul optimization.
Matmul(MatmulOptimizationState),
/// State of a reduce optimization.
Reduce(ReduceOptimizationState),
/// State of a broadcasted reduce optimization.
ReduceBroadcasted(ReduceBroadcastedOptimizationState),
}

View File

@@ -0,0 +1,87 @@
use super::optimization::ElemwiseOptimization;
use crate::{
engine::{
codegen::ir::FuseType,
fuser::TraceOperationFuser,
settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},
},
optim::CubeOptimization,
};
use burn_fusion::OperationFuser;
use burn_std::Shape;
use cubecl::Runtime;
/// Fuses element wise operations.
pub struct ElementWiseFuser<R: Runtime> {
    // Trace fuser doing the actual work.
    fuser: TraceOperationFuser,
    // Device the resulting optimization is created for.
    device: R::Device,
}

impl<R: Runtime> Clone for ElementWiseFuser<R> {
    fn clone(&self) -> Self {
        let Self { fuser, device } = self;

        Self {
            fuser: fuser.clone(),
            device: device.clone(),
        }
    }
}

impl<R: Runtime> ElementWiseFuser<R> {
    /// Current output shape tracked by the underlying fuser.
    pub fn shape_id(&self) -> Shape {
        self.fuser.current_output_shape.clone()
    }

    /// Create an element wise fuser for `device`.
    ///
    /// The maximum number of bindings is read from the hardware properties of the device client.
    pub fn new(device: R::Device, bool_precision: FuseType) -> Self {
        let client = R::client(&device);
        let max_bindings = client.properties().hardware.max_bindings;

        let settings = FuseSettings {
            broadcast: true,
            output_shape_updates: true,
            inplace: true,
            vectorization: VectorizationSetting::Activated,
            ref_layout: RefLayoutSetting::Any,
        };
        let fuser = TraceOperationFuser::new(max_bindings, bool_precision, settings);

        Self { fuser, device }
    }
}
/// Every method forwards to the wrapped trace fuser; only `finish` adds work by wrapping the
/// resulting trace into an element wise optimization.
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for ElementWiseFuser<R> {
    fn fuse(&mut self, operation: &burn_ir::OperationIr) {
        self.fuser.fuse(operation);
    }

    fn finish(&mut self) -> CubeOptimization<R> {
        let client = R::client(&self.device);
        let trace = self.fuser.finish();
        let device = self.device.clone();
        // The length is queried after the trace is finished, as the last constructor argument.
        let len = self.len();

        CubeOptimization::ElementWise(ElemwiseOptimization::new(trace, client, device, len))
    }

    fn reset(&mut self) {
        self.fuser.reset()
    }

    fn status(&self) -> burn_fusion::FuserStatus {
        self.fuser.status()
    }

    fn properties(&self) -> burn_fusion::FuserProperties {
        self.fuser.properties()
    }

    fn len(&self) -> usize {
        self.fuser.len()
    }

    fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
        Box::new(self.clone())
    }
}

View File

@@ -0,0 +1,5 @@
mod fuser;
mod optimization;
pub use fuser::*;
pub use optimization::*;

View File

@@ -0,0 +1,140 @@
use crate::{
CubeFusionHandle,
engine::{
codegen::{
io::ref_len,
ir::{
FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsLaunch, RefLayout,
multi_block_variables_init,
},
kernel::{fuse_on_write, init_locals},
},
launch::{
FuseTraceLauncher,
runner::{TraceRunner, Vectorization},
},
trace::FuseTrace,
},
};
use burn_fusion::stream::Context;
use cubecl::{CubeDim, calculate_cube_count_elemwise, client::ComputeClient, prelude::*};
use serde::{Deserialize, Serialize};
#[derive(new)]
/// Fuse element wise operations into a single kernel.
pub struct ElemwiseOptimization<R: Runtime> {
    pub(crate) trace: FuseTrace,
    client: ComputeClient<R>,
    device: R::Device,
    // Number of operations fused.
    len: usize,
}

#[derive(Serialize, Deserialize, Debug)]
/// State for the [elemwise optimization](ElemwiseOptimization).
///
/// It holds the serializable parts of the optimization: the client and device are not stored
/// and are provided again by [`ElemwiseOptimization::from_state`].
pub struct ElemwiseOptimizationState {
    trace: FuseTrace,
    len: usize,
}

impl<R: Runtime> ElemwiseOptimization<R> {
    /// Execute the optimization.
    ///
    /// # Panics
    ///
    /// Panics with the launch error and the trace when the launch fails.
    pub fn execute<BT: CubeElement>(&self, context: &mut Context<'_, CubeFusionHandle<R>>) {
        let launcher = FuseTraceLauncher::new(&self.trace, &ElemwiseRunner);

        if let Err(err) = launcher.launch::<BT>(&self.client, &self.device, context) {
            panic!("{err:?} - {:?}", self.trace);
        }
    }

    /// Number of element wise operations fused.
    pub fn num_ops_fused(&self) -> usize {
        self.len
    }

    /// Create an optimization from its [state](ElemwiseOptimizationState).
    pub fn from_state(device: &R::Device, state: ElemwiseOptimizationState) -> Self {
        let ElemwiseOptimizationState { trace, len } = state;

        Self {
            trace,
            len,
            client: R::client(device),
            device: device.clone(),
        }
    }

    /// Convert the optimization to its [state](ElemwiseOptimizationState).
    pub fn to_state(&self) -> ElemwiseOptimizationState {
        let trace = self.trace.clone();

        ElemwiseOptimizationState {
            trace,
            len: self.len,
        }
    }
}
pub struct ElemwiseRunner;
impl<R: Runtime> Vectorization<R> for ElemwiseRunner {}
impl<R: Runtime> TraceRunner<R> for ElemwiseRunner {
type Error = LaunchError; // No error possible
fn run<'a>(
&'a self,
client: &'a ComputeClient<R>,
inputs: GlobalArgsLaunch<'a, R>,
outputs: GlobalArgsLaunch<'a, R>,
configs: &[FuseBlockConfig],
) -> Result<(), Self::Error> {
let config = &configs[0];
let shape = match &config.ref_layout {
RefLayout::Concrete(arg) => match arg {
FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank),
FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank),
_ => panic!("Invalid concreate ref layout"),
},
RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank),
};
let working_units = shape.iter().product::<usize>() / config.width;
let cube_dim = CubeDim::new(client, working_units);
let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
let address_type = inputs
.required_address_type()
.max(outputs.required_address_type());
unsafe {
elemwise_fuse::launch_unchecked(
client,
cube_count,
cube_dim,
address_type,
inputs,
outputs,
config.clone(),
)?;
};
Ok(())
}
}
/// Kernel executing a fused element wise block.
///
/// Every unit whose absolute position is below the reference length runs the fused operations
/// of `config` through `fuse_on_write`; the other units do nothing.
#[cube(launch_unchecked, address_type = "dynamic")]
fn elemwise_fuse(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
#[comptime] config: &FuseBlockConfig,
) {
// We write no values for this fusion.
let values = Registry::<FuseArg, Line<f32>>::new();
let args = comptime![Vec::<FuseArg>::new()];
let pos = ABSOLUTE_POS;
// Multi-block variables and locals have to be initialized before the reference length is read.
multi_block_variables_init(config, &mut outputs.variables);
let mut locals = init_locals(inputs, outputs, config);
let length = ref_len(inputs, outputs, &locals, config);
// Guard against units launched past the end of the reference layout.
if pos < length {
fuse_on_write::<f32>(inputs, outputs, &mut locals, pos, values, args, config)
}
}

View File

@@ -0,0 +1,548 @@
use crate::engine::codegen::{
io::ref_line_size,
ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, LocalArgs, multi_block_variables_init},
kernel::init_locals,
view::{FusedOutput, GlobalInput, GlobalInputExpand},
};
use cubecl::{
intrinsic,
prelude::*,
quant::scheme::{QuantLevel, QuantScheme},
std::{
FastDivmod,
quant::{
RunWithQuantType,
view::{QuantizedView, run_with_quant_type},
},
tensor::{
View, ViewExpand,
layout::{Coords1d, Coords2d, VirtualLayout},
},
},
};
use cubek::matmul::{
components::global::memory::{
BatchLayout, BlockScaledLayout, GlobalLayout, GlobalLayoutConfig, GlobalLayoutExpand,
GlobalScaleLayout, GlobalScaleLayoutExpand, NoopLayout,
},
definition::MatrixLayout,
launch::{BatchedCoords, MatmulArgs},
};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
/// Marker type implementing `MatmulArgs` for matmuls whose inputs and output go through the
/// fusion engine.
#[derive(Clone)]
pub struct FusedMatmulArgs;
/// Input of a fused matmul: the global arguments plus the compile-time description of the
/// matmul operands.
#[derive(CubeLaunch, CubeType)]
pub struct FusedMatmulInput {
// Global arguments of the fused kernel.
global: GlobalArgs,
// Configuration of the fused block.
#[cube(comptime)]
config: FuseBlockConfig,
// First matmul operand; viewed through the lhs layout config.
#[cube(comptime)]
a: MatmulArg,
// Second matmul operand; viewed through the rhs layout config.
#[cube(comptime)]
b: MatmulArg,
// Optional accumulator operand.
#[cube(comptime)]
c: Option<MatmulArg>,
// Argument the matmul result is written to.
#[cube(comptime)]
out: FuseArg,
}
// Implementation of the matmul argument interface on top of the fusion engine: inputs are read
// through `global_view` and the output is written through `FusedOutput`.
#[cube]
impl MatmulArgs for FusedMatmulArgs {
type Output<EO: Numeric> = GlobalArgs;
type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = FusedMatmulInput;
type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = FusedMatmulState;
type Config = ();
// Builds the state shared by all the views: locals, batch shape and one batch layout per
// operand.
fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
inputs: &Self::Input<Lhs, Rhs, EO>,
outputs: &mut Self::Output<EO>,
_config: (),
#[comptime] lhs_layout_config: GlobalLayoutConfig,
#[comptime] rhs_layout_config: GlobalLayoutConfig,
#[comptime] out_layout_config: GlobalLayoutConfig,
) -> Self::State<Lhs, Rhs, EO> {
multi_block_variables_init(&inputs.config, &mut outputs.variables);
let mut locals = init_locals(&inputs.global, outputs, &inputs.config);
let rank = comptime![inputs.config.rank];
let mut batch_shape = Sequence::new();
let mut batch_strides_out = Sequence::new();
// Every dimension but the last two is a batch dimension; shape and strides come from the
// reference layout.
#[unroll]
for i in 0..rank - 2 {
batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i] as u32));
batch_strides_out.push(locals.ref_strides[i]);
}
let batch_lhs = input_batch_layout(
&inputs.global,
&batch_shape,
comptime![inputs.a.clone()],
comptime![inputs.config.clone()],
);
let batch_rhs = input_batch_layout(
&inputs.global,
&batch_shape,
comptime![inputs.b.clone()],
comptime![inputs.config.clone()],
);
// The accumulator is optional, and so is its batch layout.
let batch_acc = match comptime![inputs.c.clone()] {
Some(c) => Option::Some(input_batch_layout(
&inputs.global,
&batch_shape,
comptime![c],
comptime![inputs.config.clone()],
)),
None => Option::new_None(),
};
let batch_out = BatchLayout::new(batch_strides_out, batch_shape.clone());
FusedMatmulState::new(
inputs,
outputs,
&mut locals,
batch_lhs,
batch_rhs,
batch_acc,
VirtualLayout::new::<BatchLayout>(batch_out),
batch_shape,
&inputs.config,
lhs_layout_config,
rhs_layout_config,
out_layout_config,
)
}
// View over the `a` operand, using the lhs layout config.
fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
) -> View<Line<Lhs>, BatchedCoords> {
global_view(
&state.inputs,
&state.locals,
&state.batch_shape,
comptime![state.a.clone()],
comptime![state.config.clone()],
state.lhs_layout_config,
)
}
// Maps a batch index to a source position through the batch layout of `a`.
fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
batch: usize,
) -> usize {
state.a_batch.to_source_pos(batch)
}
// View over the `b` operand, using the rhs layout config.
fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
) -> View<Line<Rhs>, BatchedCoords> {
global_view(
&state.inputs,
&state.locals,
&state.batch_shape,
comptime![state.b.clone()],
comptime![state.config.clone()],
comptime![state.rhs_layout_config],
)
}
// Maps a batch index to a source position through the batch layout of `b`.
fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
batch: usize,
) -> usize {
state.b_batch.to_source_pos(batch)
}
// View over the optional accumulator `c`, using the out layout config; `None` when the matmul
// has no accumulator.
fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
) -> Option<View<Line<EO>, BatchedCoords>> {
match comptime![state.c.clone()] {
Some(c) => {
let view = global_view(
&state.inputs,
&state.locals,
&state.batch_shape,
c,
comptime![state.config.clone()],
comptime![state.out_layout_config],
);
Option::Some(view)
}
None => Option::new_None(),
}
}
// Maps a batch index through the accumulator batch layout; without an accumulator the index is
// returned unchanged.
fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
batch: usize,
) -> usize {
match state.c_batch {
Some(c_batch) => c_batch.to_source_pos(batch),
None => batch,
}
}
// Writable view over the output; rows and columns are the last two dimensions of the reference
// layout.
fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &mut Self::State<Lhs, Rhs, EO>,
) -> View<Line<EO>, BatchedCoords, ReadWrite> {
let rank = comptime![state.config.rank];
let shape_row = state.locals.ref_shape[rank - 2] as u32;
let shape_col = state.locals.ref_shape[rank - 1] as u32;
let stride_row = state.locals.ref_strides[rank - 2];
let stride_col = state.locals.ref_strides[rank - 1];
// The batch part of the layout is a no-op here; `batch_out` exposes the output batch layout
// separately.
let layout = GlobalLayout::new(
VirtualLayout::new::<NoopLayout>(NoopLayout::new()),
shape_row,
shape_col,
stride_row,
stride_col,
ref_line_size(&state.locals),
1u32,
state.out_layout_config,
);
let mut buffer = FusedOutput::new(
&state.inputs,
&mut state.outputs,
&mut state.locals,
comptime![state.out.clone()],
comptime![state.config.clone()],
);
View::new_mut::<FusedOutput, Coords1d>(&mut buffer, layout)
}
// Maps a batch index to a source position through the output batch layout.
fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
state: &Self::State<Lhs, Rhs, EO>,
batch: usize,
) -> usize {
state.out_batch.to_source_pos(batch)
}
// `Config` is the unit type, so there is nothing to build.
fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(_state: &Self::State<Lhs, Rhs, EO>) {
}
}
/// Create the view used to read a matmul operand from the global inputs.
///
/// Normal operands are read directly through a `GlobalInput`; quantized operands additionally
/// get a scales view and are wrapped by `create_quant_view_dynamic`.
///
/// Panics at expansion time when the operand data is not a `FuseArg::Input`.
#[cube]
fn global_view<E: Numeric>(
inputs: &GlobalArgs,
locals: &LocalArgs,
batch_shape: &Sequence<FastDivmod<u32>>,
#[comptime] arg: MatmulArg,
#[comptime] config: FuseBlockConfig,
#[comptime] layout_config: GlobalLayoutConfig,
) -> View<Line<E>, BatchedCoords> {
let rank = comptime![config.rank];
let data = comptime![arg.data().clone()];
let data_tensor = match comptime![data.clone()] {
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
_ => panic!("Input must be concrete"),
};
// Rows and columns are the last two dimensions of the data tensor.
let mut shape_row = data_tensor.tensor.shape(rank - 2) as u32;
let mut shape_col = data_tensor.tensor.shape(rank - 1) as u32;
let mut packing = comptime![1];
// For quantized operands the dimension selected by the matrix layout is multiplied by the
// number of quants of the scheme, which is also used as the packing factor.
if arg.scheme().is_some() {
let scheme = arg.scheme().unwrap();
let num_quants = scheme.num_quants() as u32;
comptime![packing = num_quants];
match comptime![layout_config.matrix_layout] {
MatrixLayout::RowMajor => shape_col *= num_quants,
MatrixLayout::ColMajor => shape_row *= num_quants,
};
}
let shape = (shape_row, shape_col);
// Noop for normal inputs because batch offset is cached, quantized uses logical batches
let batch_layout = match comptime![arg.clone()] {
MatmulArg::Normal(_) => VirtualLayout::new::<NoopLayout>(NoopLayout::new()),
MatmulArg::Quantized { data, .. } => {
let data_arg = comptime![MatmulArg::Normal(data)];
input_batch_layout(inputs, batch_shape, data_arg, comptime![config.clone()])
}
};
let data_layout = global_layout(
inputs,
shape,
batch_layout,
arg.data().clone(),
config.clone(),
data_tensor.tensor.line_size(),
layout_config,
packing,
);
let data_buf = GlobalInput::new(inputs, locals, data, comptime![config.clone()], None);
match comptime![arg.clone()] {
MatmulArg::Normal(_) => View::new::<GlobalInput, Coords1d>(&data_buf, data_layout),
MatmulArg::Quantized { scales, scheme, .. } => {
// The scales layout depends on the quantization level: a single per-tensor layout, or a
// block-scaled layout built from the scales tensor.
let scales_layout = match comptime![scheme.level] {
QuantLevel::Tensor => GlobalScaleLayout::new_PerTensor(shape),
QuantLevel::Block(block_size) => {
let block_size = comptime![block_size.as_dim::<2>()];
let scales_arg = comptime![MatmulArg::Normal(scales.clone())];
let batch_layout = input_batch_layout(
inputs,
batch_shape,
scales_arg,
comptime![config.clone()],
);
// Scales are read with a line size of 1 and no packing.
let scales_layout = global_layout(
inputs,
shape,
batch_layout,
comptime![scales.clone()],
comptime![config.clone()],
1usize,
layout_config,
1u32,
);
GlobalScaleLayout::new_BlockScaled(BlockScaledLayout::new(
shape,
scales_layout,
comptime![(block_size[0] as u32, block_size[1] as u32)],
))
}
};
let scales_buf = GlobalInput::new(inputs, locals, scales, config, None);
create_quant_view_dynamic(data_buf, data_layout, scales_buf, scales_layout, scheme)
}
}
}
/// Create the batch layout of a matmul operand.
///
/// Normal operands get a `BatchLayout` built from the strides of their batch dimensions (all the
/// dimensions but the last two); quantized operands get a no-op layout.
///
/// Panics at expansion time when a normal operand is not a `FuseArg::Input`.
#[cube]
fn input_batch_layout(
inputs: &GlobalArgs,
batch_shape: &Sequence<FastDivmod<u32>>,
#[comptime] arg: MatmulArg,
#[comptime] config: FuseBlockConfig,
) -> VirtualLayout<usize, usize> {
let rank = comptime![config.rank];
match comptime![arg.clone()] {
MatmulArg::Normal(arg) => {
let data_tensor = match comptime![arg.clone()] {
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
_ => panic!("Input must be concrete"),
};
let mut batch_strides = Sequence::new();
#[unroll]
for i in 0..rank - 2 {
let shape = data_tensor.tensor.shape(i);
// A dimension of size 1 gets a stride of 0 (presumably so it can be broadcast over the
// batch shape; confirm with callers).
let stride = select(shape == 1, 0, data_tensor.tensor.stride(i));
batch_strides.push(stride);
}
VirtualLayout::new::<BatchLayout>(BatchLayout::new(batch_strides, batch_shape.clone()))
}
MatmulArg::Quantized { .. } => VirtualLayout::new::<NoopLayout>(NoopLayout::new()),
}
}
/// Create the `GlobalLayout` of an input tensor.
///
/// The row and column strides are read from the last two dimensions of the tensor referenced by
/// `arg`; the shape, batch layout, line size, packing and layout config are forwarded as given.
///
/// Panics at expansion time when `arg` is not a `FuseArg::Input`.
#[cube]
fn global_layout(
inputs: &GlobalArgs,
shape: Coords2d,
batch_layout: VirtualLayout<usize, usize>,
#[comptime] arg: FuseArg,
#[comptime] config: FuseBlockConfig,
#[comptime] line_size: LineSize,
#[comptime] layout_config: GlobalLayoutConfig,
#[comptime] packing: u32,
) -> GlobalLayout {
let rank = comptime![config.rank];
let data_tensor = match comptime![arg.clone()] {
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
_ => panic!("Input must be concrete"),
};
let (shape_row, shape_col) = shape;
let stride_row = data_tensor.tensor.stride(rank - 2);
let stride_col = data_tensor.tensor.stride(rank - 1);
GlobalLayout::new(
batch_layout,
shape_row,
shape_col,
stride_row,
stride_col,
line_size,
packing,
layout_config,
)
}
/// Arguments of `create_quant_view`, captured so that the quantized and scale element types can
/// be selected later through [`RunWithQuantType`].
struct CreateQuantView<'a, E: Numeric> {
    scope: &'a mut Scope,
    data_buf: GlobalInputExpand,
    data_layout: GlobalLayoutExpand,
    scales_buf: GlobalInputExpand,
    scales_layout: GlobalScaleLayoutExpand,
    scheme: QuantScheme,
    _ty: PhantomData<E>,
}

impl<E: Numeric> RunWithQuantType for CreateQuantView<'_, E> {
    type Output = ViewExpand<Line<E>, BatchedCoords>;

    /// Expand `create_quant_view` with the selected quantized type `Q` and scale type `S`.
    fn execute<Q: CubePrimitive, S: CubePrimitive>(self) -> Self::Output {
        let Self {
            scope,
            data_buf,
            data_layout,
            scales_buf,
            scales_layout,
            scheme,
            ..
        } = self;

        create_quant_view::expand::<E, Q, S>(
            scope,
            data_buf,
            data_layout,
            scales_buf,
            scales_layout,
            scheme,
        )
    }
}
/// Create a quantized view whose quantized and scale element types are selected from `scheme`.
///
/// The arguments are packed into a `CreateQuantView` and handed to `run_with_quant_type`, which
/// ends up expanding `create_quant_view` with the selected types.
#[cube]
// NOTE(review): presumably the arguments are only used inside the `intrinsic!` expansion, hence
// the `unused` allowance; confirm.
#[allow(unused)]
fn create_quant_view_dynamic<E: Numeric>(
data_buf: GlobalInput,
data_layout: GlobalLayout,
scales_buf: GlobalInput,
scales_layout: GlobalScaleLayout,
#[comptime] scheme: QuantScheme,
) -> View<Line<E>, BatchedCoords> {
intrinsic!(|scope| {
let func = CreateQuantView {
scope,
data_buf,
data_layout,
scales_buf,
scales_layout,
scheme,
_ty: PhantomData,
};
run_with_quant_type(func, scheme)
})
}
/// Create a view over quantized data.
///
/// `Q` is the element type of the data view and `S` the element type of the scales view; both
/// views are combined by `QuantizedView` into a view yielding lines of `E`.
#[cube]
fn create_quant_view<E: Numeric, Q: CubePrimitive, S: CubePrimitive>(
data_buf: GlobalInput,
data_layout: GlobalLayout,
scales_buf: GlobalInput,
scales_layout: GlobalScaleLayout,
#[comptime] scheme: QuantScheme,
) -> View<Line<E>, BatchedCoords> {
let data_view: View<Line<Q>, BatchedCoords> =
View::new::<GlobalInput, Coords1d>(&data_buf, data_layout);
let scales_view: View<S, BatchedCoords> =
View::new::<GlobalInput, Coords1d>(&scales_buf, scales_layout);
QuantizedView::new(data_view, scales_view, scheme).view()
}
/// State shared by the views and batch mappings of a fused matmul.
///
/// It is built once by `FusedMatmulState::new` from the matmul input, the outputs and the
/// initialized locals.
#[derive(CubeType)]
pub struct FusedMatmulState {
// Global inputs of the fused kernel.
inputs: GlobalArgs,
// Global outputs of the fused kernel.
outputs: GlobalArgs,
// Locals of the fused block.
locals: LocalArgs,
// Batch layouts of the `a`, `b` and optional `c` operands, and of the output.
a_batch: VirtualLayout<Coords1d, Coords1d>,
b_batch: VirtualLayout<Coords1d, Coords1d>,
c_batch: Option<VirtualLayout<Coords1d, Coords1d>>,
out_batch: VirtualLayout<Coords1d, Coords1d>,
// Configuration of the fused block.
#[cube(comptime)]
config: FuseBlockConfig,
// Compile-time description of the matmul operands and of the output argument.
#[cube(comptime)]
a: MatmulArg,
#[cube(comptime)]
b: MatmulArg,
#[cube(comptime)]
c: Option<MatmulArg>,
#[cube(comptime)]
out: FuseArg,
// Layout configs of the lhs, rhs and output matrices.
#[cube(comptime)]
lhs_layout_config: GlobalLayoutConfig,
#[cube(comptime)]
rhs_layout_config: GlobalLayoutConfig,
#[cube(comptime)]
out_layout_config: GlobalLayoutConfig,
// Shape of the batch dimensions.
batch_shape: Sequence<FastDivmod<u32>>,
}
#[cube]
impl FusedMatmulState {
// Builds the state; the operand descriptions `a`, `b`, `c` and `out` are taken from `inputs`,
// while the global inputs, outputs and locals are cloned into the state.
#[allow(clippy::too_many_arguments)]
pub fn new(
inputs: &FusedMatmulInput,
outputs: &mut GlobalArgs,
locals: &mut LocalArgs,
a_batch: VirtualLayout<usize, usize>,
b_batch: VirtualLayout<usize, usize>,
c_batch: Option<VirtualLayout<usize, usize>>,
out_batch: VirtualLayout<usize, usize>,
batch_shape: Sequence<FastDivmod<u32>>,
#[comptime] config: &FuseBlockConfig,
#[comptime] lhs_layout_config: GlobalLayoutConfig,
#[comptime] rhs_layout_config: GlobalLayoutConfig,
#[comptime] out_layout_config: GlobalLayoutConfig,
) -> FusedMatmulState {
FusedMatmulState {
inputs: inputs.global.clone(),
outputs: outputs.clone(),
config: comptime![config.clone()],
locals: locals.clone(),
a_batch,
b_batch,
c_batch,
out_batch,
a: comptime![inputs.a.clone()],
b: comptime![inputs.b.clone()],
c: comptime![inputs.c.clone()],
out: comptime![inputs.out.clone()],
lhs_layout_config,
rhs_layout_config,
out_layout_config,
batch_shape,
}
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
/// Argument to a matmul operation.
pub enum MatmulArg {
/// A regular (non-quantized) argument.
Normal(FuseArg),
/// A quantized argument, made of its data and its scales.
Quantized {
/// Argument holding the quantized data.
data: FuseArg,
/// Argument holding the quantization scales.
scales: FuseArg,
/// Precision reported by `MatmulArg::precision` for this argument.
precision: FuseType,
/// Quantization scheme of the data.
scheme: QuantScheme,
},
}
impl MatmulArg {
pub fn data(&self) -> &FuseArg {
match self {
MatmulArg::Normal(arg) => arg,
MatmulArg::Quantized { data, .. } => data,
}
}
pub fn scheme(&self) -> Option<&QuantScheme> {
match self {
MatmulArg::Normal(_) => None,
MatmulArg::Quantized { scheme, .. } => Some(scheme),
}
}
pub fn precision(&self) -> FuseType {
match self {
MatmulArg::Normal(arg) => arg.precision(),
MatmulArg::Quantized { precision, .. } => *precision,
}
}
}

View File

@@ -0,0 +1,160 @@
use super::optimization::{FusedMatmul, MatmulOptimization};
use crate::{
engine::{codegen::ir::FuseType, fuser::TraceOperationFuser, settings::FuseSettings},
optim::CubeOptimization,
optim::matmul::args::MatmulArg,
};
use burn_fusion::{FuserStatus, OperationFuser};
use burn_ir::{FloatOperationIr, OperationIr};
use burn_std::DType;
use cubecl::Runtime;
/// Fuses a matmul operation with the operations that follow it.
///
/// Two traces are built side by side: one for the fused matmul kernel and one used when the
/// matmul has to be executed through the fallback path.
pub struct MatmulFuser<R: Runtime> {
/// Fuser for the fused matmul trace (built with `output_shape_updates` disabled).
fuser: TraceOperationFuser,
/// Fuser for the fallback trace (built with the default settings).
fuser_fallback: TraceOperationFuser,
/// Device used to obtain the compute client.
device: R::Device,
/// The registered matmul; `None` until a matmul operation has been fused.
matmul: Option<FusedMatmul>,
}
impl<R: Runtime> Clone for MatmulFuser<R> {
    /// Clones every field; written by hand so that no `R: Clone` bound is required.
    fn clone(&self) -> Self {
        let Self {
            fuser,
            fuser_fallback,
            device,
            matmul,
        } = self;

        Self {
            fuser: fuser.clone(),
            fuser_fallback: fuser_fallback.clone(),
            device: device.clone(),
            matmul: matmul.clone(),
        }
    }
}
impl<R: Runtime> MatmulFuser<R> {
    /// Creates a matmul fuser for `device`.
    ///
    /// The binding limit is read from the hardware properties of the device's client. The
    /// main fuser disables output shape updates, while the fallback fuser uses the default
    /// settings.
    pub fn new(device: R::Device, bool_precision: FuseType) -> Self {
        let client = R::client(&device);
        let max_bindings = client.properties().hardware.max_bindings;

        let fuser = TraceOperationFuser::new(
            max_bindings,
            bool_precision,
            FuseSettings {
                output_shape_updates: false,
                ..Default::default()
            },
        );
        let fuser_fallback =
            TraceOperationFuser::new(max_bindings, bool_precision, FuseSettings::default());

        Self {
            fuser,
            fuser_fallback,
            device,
            matmul: None,
        }
    }
}
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for MatmulFuser<R> {
    /// Registers `operation`.
    ///
    /// The first accepted operation must be a float matmul; anything else closes both
    /// fusers. Once a matmul is registered, later operations are fused only if both the
    /// main and the fallback fuser accept them, otherwise both fusers are closed.
    fn fuse(&mut self, operation: &OperationIr) {
        if matches!(self.fuser.status(), FuserStatus::Closed) {
            return;
        }

        // A matmul is already registered: try to append the operation to both traces.
        if self.matmul.is_some() {
            let accepted =
                self.fuser.can_fuse(operation) && self.fuser_fallback.can_fuse(operation);
            if accepted {
                self.fuser.fuse(operation);
                self.fuser_fallback.fuse(operation);
            } else {
                self.fuser.close();
                self.fuser_fallback.close();
            }
            return;
        }

        // No matmul yet: only a float matmul can start this optimization.
        let OperationIr::Float(_, FloatOperationIr::Matmul(op)) = operation else {
            self.fuser.close();
            self.fuser_fallback.close();
            return;
        };

        // For quantized operands, the precision is taken from the output dtype.
        let lhs = if let DType::QFloat(scheme) = op.lhs.dtype {
            let (data, scales) = self.fuser.input_quantized_unhandled(&op.lhs).unwrap();
            MatmulArg::Quantized {
                data,
                scales,
                precision: op.out.dtype.into(),
                scheme,
            }
        } else {
            MatmulArg::Normal(self.fuser.input_unhandled(&op.lhs))
        };

        let rhs = if let DType::QFloat(scheme) = op.rhs.dtype {
            let (data, scales) = self.fuser.input_quantized_unhandled(&op.rhs).unwrap();
            MatmulArg::Quantized {
                data,
                scales,
                precision: op.out.dtype.into(),
                scheme,
            }
        } else {
            MatmulArg::Normal(self.fuser.input_unhandled(&op.rhs))
        };

        let out = self.fuser.output_unhandled(&op.out);

        let fused = FusedMatmul::new(lhs, rhs, out, op.clone().into(), Default::default());
        self.matmul = Some(fused);
    }

    /// Finishes both traces and wraps them into a matmul optimization.
    ///
    /// Panics if no matmul operation was registered.
    fn finish(&mut self) -> CubeOptimization<R> {
        let client = R::client(&self.device);
        let trace = self.fuser.finish();
        let trace_fallback = self.fuser_fallback.finish();

        // The length is read after both traces are finished, as before.
        let device = self.device.clone();
        let len = self.len();
        let matmul = self.matmul.clone().unwrap();

        CubeOptimization::Matmul(MatmulOptimization::new(
            trace,
            trace_fallback,
            client,
            device,
            len,
            matmul,
        ))
    }

    /// Resets both fusers and forgets the registered matmul.
    fn reset(&mut self) {
        self.matmul = None;
        self.fuser.reset();
        self.fuser_fallback.reset();
    }

    /// Status of the main fuser.
    fn status(&self) -> burn_fusion::FuserStatus {
        self.fuser.status()
    }

    /// Properties of the main fuser with the score increased by one.
    fn properties(&self) -> burn_fusion::FuserProperties {
        let mut props = self.fuser.properties();
        props.score += 1;
        props
    }

    /// Number of fused operations, counting the matmul itself.
    fn len(&self) -> usize {
        // The matmul is not registered in the fuser, so it is added here.
        1 + self.fuser.len()
    }

    /// Boxed clone of this fuser.
    fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
        Box::new(Clone::clone(self))
    }
}

View File

@@ -0,0 +1,8 @@
mod fuser;
mod optimization;
pub(crate) mod args;
pub(crate) mod tune;
pub use fuser::*;
pub use optimization::*;

View File

@@ -0,0 +1,649 @@
use super::args::FusedMatmulInputLaunch;
#[cfg(feature = "autotune")]
use super::tune::fused_matmul_autotune;
use crate::{
CubeFusionHandle, FallbackOperation,
engine::{
codegen::ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout},
launch::{
FuseTraceLauncher, HandleInput, LaunchPlan,
runner::{TraceRunner, Vectorization, VectorizationAxis},
},
trace::{FuseTrace, TraceError, TuneOutput},
},
optim::{
elemwise::ElemwiseRunner,
matmul::args::{FusedMatmulArgs, MatmulArg},
},
};
use burn_fusion::stream::Context;
use burn_ir::BinaryOpIr;
use cubecl::{
client::ComputeClient,
prelude::*,
std::tensor::{MatrixBatchLayout, matrix_batch_layout},
};
use cubek::matmul::{
components::tile::{cmma::CmmaMatmul, mma::MmaMatmul},
definition::{
MatmulElems, MatmulGlobalElems, MatmulLineSizes, MatmulProblem, MatmulSetupError,
MatrixLayout,
},
launch::launch_kernel_virtual,
routines::{
BlueprintStrategy, Routine,
double_buffering::{CyclicDoubleBufferingAlgorithm, DoubleBufferingArgs},
double_unit::DoubleUnitAlgorithm,
ordered_double_buffering::{OrderedDoubleBufferingAlgorithm, OrderedSelectionArgs},
simple::{SimpleAlgorithm, SimpleArgs},
simple_unit::SimpleUnitAlgorithm,
vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Fuse matmul operation followed by elemwise operations into a single kernel.
pub struct MatmulOptimization<R: Runtime> {
/// Shared, immutable description of the optimization (traces, client, device, matmul).
pub(crate) info: Arc<MatmulOptimizationInfo<R>>,
}
/// Argument handed to the fused and fallback execution paths of a matmul optimization.
pub struct MatmulOptimizationTuneArg<R: Runtime> {
/// Shared description of the optimization.
pub(crate) info: Arc<MatmulOptimizationInfo<R>>,
/// Operation run by `execute_fallback` before the fallback trace is launched.
pub(crate) fallback: Box<dyn FallbackOperation<R>>,
}
/// Everything needed to execute a matmul optimization.
pub(crate) struct MatmulOptimizationInfo<R: Runtime> {
/// Trace launched together with the fused matmul.
trace: FuseTrace,
/// Trace launched with the element-wise runner after the fallback operation.
trace_fallback: FuseTrace,
/// Compute client used for launching.
pub(crate) client: ComputeClient<R>,
/// Device used for launching.
pub(crate) device: R::Device,
/// Number of operations fused.
pub(crate) len: usize,
/// The fused matmul description.
pub(crate) matmul: FusedMatmul,
}
#[derive(Serialize, Deserialize, Debug)]
/// Serializable state for the [matmul optimization](MatmulOptimization).
///
/// Holds the fields of the optimization that do not depend on a client or device.
pub struct MatmulOptimizationState {
/// Trace of the fused matmul path.
trace: FuseTrace,
/// Trace of the fallback path.
trace_fallback: FuseTrace,
/// The fused matmul description.
matmul: FusedMatmul,
/// Number of operations fused.
len: usize,
}
impl<R: Runtime> MatmulOptimizationInfo<R> {
    /// Number of output buffers, read from the outputs of the fallback trace.
    pub fn num_output_buffers(&self) -> usize {
        let outputs = &self.trace_fallback.resources.outputs;
        outputs.len()
    }

    /// Number of operations fused into this optimization.
    pub fn num_ops_fused(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }
}
impl<R: Runtime> MatmulOptimizationTuneArg<R> {
/// Launches the fused trace with the matmul routine chosen by `selector`.
///
/// Returns the launcher's result; errors are reported as `TraceError<FusedMatmulError>`.
pub(crate) fn execute_fused<BT: CubeElement>(
&self,
context: &mut Context<'_, CubeFusionHandle<R>>,
selector: FusedMatmulSelector,
) -> Result<TuneOutput<R>, TraceError<FusedMatmulError>> {
let launch = FusedMatmulLaunch::new(&self.info.matmul, selector);
let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
launcher.launch::<BT>(&self.info.client, &self.info.device, context)
}
/// Runs the fallback operation, then launches the fallback trace with the element-wise
/// runner.
///
/// Panics if launching the fallback trace fails.
pub fn execute_fallback<BT: CubeElement>(
&self,
context: &mut Context<'_, CubeFusionHandle<R>>,
) -> TuneOutput<R> {
self.fallback.run(context);
// With `autotune-checks`, the output starts as a checked set of handles; otherwise it is
// an unchecked marker.
#[cfg(feature = "autotune-checks")]
let mut output = TuneOutput::Checked {
handles: Default::default(),
};
#[cfg(not(feature = "autotune-checks"))]
let output = TuneOutput::UnChecked(core::marker::PhantomData);
// With `autotune-checks`, record the matmul output handle written by the fallback.
#[cfg(feature = "autotune-checks")]
if let TuneOutput::Checked { handles } = &mut output {
let out_desc = context.tensors.get(&self.info.matmul.op.out.id).unwrap();
let handle_out = context
.handles
.get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);
handles.insert(
self.info.matmul.op.out.id,
(out_desc.shape.dims.clone(), handle_out.clone()),
);
}
let launcher = FuseTraceLauncher::new(&self.info.trace_fallback, &ElemwiseRunner);
let output_write = launcher
.launch::<BT>(&self.info.client, &self.info.device, context)
.unwrap();
// Combine the matmul output with the outputs of the fallback trace.
output.merge(output_write)
}
}
impl<R: Runtime> MatmulOptimization<R> {
    /// Creates a matmul optimization from its traces, launch resources and matmul
    /// description.
    pub fn new(
        trace: FuseTrace,
        trace_fallback: FuseTrace,
        client: ComputeClient<R>,
        device: R::Device,
        len: usize,
        matmul: FusedMatmul,
    ) -> Self {
        Self {
            info: Arc::new(MatmulOptimizationInfo {
                trace,
                trace_fallback,
                client,
                device,
                len,
                matmul,
            }),
        }
    }

    /// Execute the optimization.
    ///
    /// `fallback` builds the fallback operation for the operation at a given index. With
    /// the `autotune` feature the work is delegated to the autotuner; otherwise the default
    /// fused selector is tried first and the fallback path runs if it fails.
    pub fn execute<BT: CubeElement>(
        &mut self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
    ) {
        // The index of the fallback matmul is always 0.
        let fallback = fallback(0);
        let arg = MatmulOptimizationTuneArg {
            info: Arc::clone(&self.info),
            fallback,
        };

        #[cfg(feature = "autotune")]
        fused_matmul_autotune::<R, BT>(arg, context);

        #[cfg(not(feature = "autotune"))]
        if arg
            .execute_fused::<BT>(context, FusedMatmulSelector::default())
            .is_err()
        {
            arg.execute_fallback::<BT>(context);
        }
    }

    /// Number of operations fused.
    pub fn num_ops_fused(&self) -> usize {
        self.info.num_ops_fused()
    }

    /// Create an optimization from its [state](MatmulOptimizationState).
    ///
    /// The compute client is obtained from `device`.
    pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self {
        // `state` is owned, so its fields are moved rather than cloned.
        let info = MatmulOptimizationInfo {
            trace: state.trace,
            trace_fallback: state.trace_fallback,
            len: state.len,
            client: R::client(device),
            device: device.clone(),
            matmul: state.matmul,
        };

        Self {
            info: Arc::new(info),
        }
    }

    /// Convert the optimization to its [state](MatmulOptimizationState).
    pub fn to_state(&self) -> MatmulOptimizationState {
        let info = &self.info;

        MatmulOptimizationState {
            trace: info.trace.clone(),
            trace_fallback: info.trace_fallback.clone(),
            matmul: info.matmul.clone(),
            len: info.len,
        }
    }
}
/// Selects which matmul routine is launched for a fused matmul.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum FusedMatmulSelector {
/// Launches `SimpleAlgorithm` with the given tile matmul.
Simple {
/// Forwarded to `SimpleArgs::multi_rows`.
multi_rows: bool,
/// Tile matmul implementation to use.
tile_matmul: AcceleratedTileKind,
},
/// Launches `CyclicDoubleBufferingAlgorithm` with the given tile matmul.
DoubleBuffering {
/// Forwarded to `DoubleBufferingArgs::specialized`.
specialized: bool,
/// Tile matmul implementation to use.
tile_matmul: AcceleratedTileKind,
},
/// Launches `OrderedDoubleBufferingAlgorithm` with the given tile matmul.
OrderedDoubleBuffering {
/// Tile matmul implementation to use.
tile_matmul: AcceleratedTileKind,
},
/// Launches `SimpleVecMatAlgorithm`.
SimpleVecMat,
/// Launches `DoubleVecMatAlgorithm`.
DoubleVecMat,
/// Launches `SimpleUnitAlgorithm`.
SimpleUnit,
/// Launches `DoubleUnitAlgorithm`.
DoubleUnit,
}
impl FusedMatmulSelector {
/// Not efficient, but only called once when initializing the tunables.
pub fn name(&self) -> String {
let name = match self {
FusedMatmulSelector::Simple {
multi_rows,
tile_matmul,
} => match multi_rows {
false => format!("simple_{tile_matmul:?}"),
true => format!("simple_multirows_{tile_matmul:?}"),
},
FusedMatmulSelector::DoubleBuffering {
specialized,
tile_matmul,
} => match specialized {
false => format!("double_buffering_{tile_matmul:?}"),
true => format!("double_buffering_specialized_{tile_matmul:?}"),
},
FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => {
format!("double_buffering_ordered_{tile_matmul:?}").to_lowercase()
}
FusedMatmulSelector::SimpleVecMat => "simple_vec_mat".into(),
FusedMatmulSelector::DoubleVecMat => "double_buffering_vec_mat".into(),
FusedMatmulSelector::SimpleUnit => "simple_unit".into(),
FusedMatmulSelector::DoubleUnit => "double_buffering_unit".into(),
};
format!("fused_{name}")
}
}
impl Default for FusedMatmulSelector {
fn default() -> Self {
FusedMatmulSelector::Simple {
multi_rows: false,
tile_matmul: AcceleratedTileKind::Cmma,
}
}
}
/// Description of a matmul registered for fusion.
#[derive(new, Clone, Serialize, Deserialize, Debug)]
pub struct FusedMatmul {
/// Left-hand side operand.
pub(crate) lhs: MatmulArg,
/// Right-hand side operand.
pub(crate) rhs: MatmulArg,
/// Output argument of the matmul.
out: FuseArg,
/// The original matmul operation (lhs, rhs and out tensors).
pub(crate) op: BinaryOpIr,
/// Selector stored with the matmul.
// NOTE(review): launching reads the selector from `FusedMatmulLaunch`, not from this
// field — confirm whether this field is still used elsewhere.
pub(crate) selector: FusedMatmulSelector,
}
/// Runner that launches a fused matmul with a specific selector.
#[derive(new)]
pub struct FusedMatmulLaunch<'a> {
/// The matmul to launch.
pub(crate) matmul: &'a FusedMatmul,
/// Routine to launch the matmul with.
pub(crate) selector: FusedMatmulSelector,
}
/// Errors returned when launching a fused matmul.
#[derive(Debug)]
pub enum FusedMatmulError {
/// The matmul routine could not be set up or launched.
LaunchError(MatmulSetupError),
/// The inputs cannot be handled by the fused matmul; the message explains why.
InvalidInput(&'static str),
}
impl From<MatmulSetupError> for FusedMatmulError {
fn from(value: MatmulSetupError) -> Self {
Self::LaunchError(value)
}
}
impl<'a, R: Runtime> Vectorization<R> for FusedMatmulLaunch<'a> {
    /// Picks the vectorization axis of the matmul operands.
    ///
    /// For each operand whose batch layout is mildly permuted and transposed, the
    /// second-to-last axis is registered. Other operands keep the default.
    ///
    /// Panics if either operand is not found among the plan's handle inputs.
    fn axis(&self, plan: &LaunchPlan<'_, R>) -> VectorizationAxis {
        let lhs_id = self.matmul.op.lhs.id;
        let rhs_id = self.matmul.op.rhs.id;

        let mut tensor_lhs = None;
        let mut tensor_rhs = None;

        for handle_input in plan.handle_inputs.iter() {
            // Quantization parameters never correspond to a matmul operand.
            let (relative_id, global_id, strides) = match handle_input {
                HandleInput::Normal(normal) => (
                    normal.relative_id,
                    normal.global_ir.id,
                    &normal.handle.strides,
                ),
                HandleInput::QuantValues(values) => (
                    values.relative_id,
                    values.global_ir.id,
                    &values.handle.strides,
                ),
                HandleInput::QuantParams(_) => continue,
            };

            if relative_id == lhs_id {
                tensor_lhs = Some((global_id, strides));
            }
            if relative_id == rhs_id {
                tensor_rhs = Some((global_id, strides));
            }
        }

        let lhs = tensor_lhs.unwrap();
        let rhs = tensor_rhs.unwrap();

        let mut axis = VectorizationAxis::default();

        let operands = [
            (lhs, self.matmul.lhs.scheme()),
            (rhs, self.matmul.rhs.scheme()),
        ];
        for ((global_id, strides), scheme) in operands {
            if let MatrixBatchLayout::MildlyPermuted { transposed, .. } =
                matrix_batch_layout(strides, scheme)
                && transposed
            {
                axis.insert(global_id, strides.len() - 2);
            }
        }

        axis
    }
}
impl<R: Runtime> TraceRunner<R> for FusedMatmulLaunch<'_> {
type Error = FusedMatmulError;
fn run<'a>(
&'a self,
client: &'a ComputeClient<R>,
inputs: GlobalArgsLaunch<'a, R>,
outputs: GlobalArgsLaunch<'a, R>,
configs: &'a [FuseBlockConfig],
) -> Result<(), FusedMatmulError> {
let global_elems = MatmulGlobalElems {
lhs: self.matmul.lhs.precision().into_type(),
rhs: self.matmul.rhs.precision().into_type(),
out: self.matmul.out.precision().into_type(),
};
let dtypes = MatmulElems::from_globals(&global_elems);
self.matmul_fused(client, inputs, outputs, &configs[0], dtypes)
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
/// Which tile matmul to use for accelerated algorithms
pub enum AcceleratedTileKind {
/// Use `CmmaMatmul` tiles (the default).
#[default]
Cmma,
/// Use `MmaMatmul` tiles.
Mma,
}
// Dispatches on an `AcceleratedTileKind` value.
//
// In each arm, the identifier `$T` is declared as a type alias for the matching tile matmul
// (`CmmaMatmul` or `MmaMatmul`) and `$launch` is then called with no arguments, so the
// closure passed as `$launch` can refer to `$T` as a type.
macro_rules! with_tile_kind {
($kind: expr, $T: ident, $launch: expr) => {
match $kind {
AcceleratedTileKind::Cmma => {
type $T = CmmaMatmul;
($launch)()
}
AcceleratedTileKind::Mma => {
type $T = MmaMatmul;
($launch)()
}
}
};
}
impl FusedMatmulLaunch<'_> {
fn matmul_fused<'a, R: Runtime>(
&'a self,
client: &'a ComputeClient<R>,
inputs: GlobalArgsLaunch<'a, R>,
outputs: GlobalArgsLaunch<'a, R>,
config: &'a FuseBlockConfig,
dtypes: MatmulElems,
) -> Result<(), FusedMatmulError> {
let lhs_shape = inputs.shape(self.matmul.lhs.data());
let rhs_shape = inputs.shape(self.matmul.rhs.data());
let out_shape = outputs.shape_ref(&config.ref_layout, config.rank);
let lhs_strides = inputs.strides(self.matmul.lhs.data());
let lhs_scheme = self.matmul.lhs.scheme();
let rhs_strides = inputs.strides(self.matmul.rhs.data());
let rhs_scheme = self.matmul.rhs.scheme();
if matrix_batch_layout(&lhs_strides, lhs_scheme) == MatrixBatchLayout::HighlyPermuted {
return Err(FusedMatmulError::InvalidInput(
"Lhs needs to be contiguous, but can't when fusing.",
));
}
if matrix_batch_layout(&rhs_strides, rhs_scheme) == MatrixBatchLayout::HighlyPermuted {
return Err(FusedMatmulError::InvalidInput(
"Rhs needs to be contiguous, but can't when fusing.",
));
}
let mut line_sizes = MatmulLineSizes {
lhs: inputs.line_size(self.matmul.lhs.data()),
rhs: inputs.line_size(self.matmul.rhs.data()),
out: match &config.ref_layout {
RefLayout::Concrete(arg) => match arg {
FuseArg::Input(..) => inputs.line_size(arg),
FuseArg::Output(..) => outputs.line_size(arg),
_ => panic!("Invalid ref layout"),
},
RefLayout::Virtual(_) => 1,
},
};
let address_type = inputs
.required_address_type()
.max(outputs.required_address_type());
if line_sizes.out == 1 && (line_sizes.lhs > 1 || line_sizes.rhs > 1) {
return Err(FusedMatmulError::InvalidInput(
"Output line size of 1 removes the gain from fusion",
));
}
if let MatmulArg::Quantized { scheme, .. } = self.matmul.lhs {
line_sizes.lhs *= scheme.num_quants();
}
if let MatmulArg::Quantized { scheme, .. } = self.matmul.rhs {
line_sizes.rhs *= scheme.num_quants();
}
let out_strides = MatrixLayout::RowMajor.to_strides(&out_shape);
let problem = MatmulProblem::from_shapes_and_strides(
lhs_shape,
rhs_shape,
out_shape,
lhs_strides,
rhs_strides,
out_strides,
dtypes.as_global_elems(),
address_type,
self.matmul.lhs.scheme(),
self.matmul.rhs.scheme(),
);
match self.selector {
FusedMatmulSelector::Simple {
multi_rows,
tile_matmul,
} => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<
R,
SimpleAlgorithm<Accelerated>,
>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&BlueprintStrategy::Inferred(SimpleArgs { multi_rows }),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
}),
FusedMatmulSelector::DoubleBuffering {
specialized,
tile_matmul,
} => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<
R,
CyclicDoubleBufferingAlgorithm<Accelerated>,
>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized }),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
}),
FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => {
let row_count = match self.matmul.lhs.precision() {
FuseType::F16 | FuseType::BF16 => 8,
_ => 4,
};
with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<
R,
OrderedDoubleBufferingAlgorithm<Accelerated>,
>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&BlueprintStrategy::Inferred(OrderedSelectionArgs {
row_count: Some(row_count),
rows_per_plane: Some(2),
partition_k: Some(2),
}),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
})
}
FusedMatmulSelector::SimpleUnit => {
match launch_inner_fix_dtype::<R, SimpleUnitAlgorithm>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&Default::default(),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
}
}
FusedMatmulSelector::DoubleUnit => {
match launch_inner_fix_dtype::<R, DoubleUnitAlgorithm>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&Default::default(),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
}
}
FusedMatmulSelector::SimpleVecMat => {
match launch_inner_fix_dtype::<R, SimpleVecMatAlgorithm>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&Default::default(),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
}
}
FusedMatmulSelector::DoubleVecMat => {
match launch_inner_fix_dtype::<R, DoubleVecMatAlgorithm>(
client,
FusedMatmulInputLaunch::new(
inputs,
config.clone(),
self.matmul.lhs.clone(),
self.matmul.rhs.clone(),
None,
self.matmul.out.clone(),
),
outputs,
problem,
line_sizes,
&Default::default(),
) {
Ok(_) => Ok(()),
Err(err) => Err(FusedMatmulError::LaunchError(err)),
}
}
}
}
}
/// Launches routine `A` through `launch_kernel_virtual` with `FusedMatmulArgs` as the
/// argument family and `()` as the extra argument.
///
/// Returns the `MatmulSetupError` of the launch unchanged.
fn launch_inner_fix_dtype<'a, R: Runtime, A: Routine<()>>(
client: &ComputeClient<R>,
input: FusedMatmulInputLaunch<'a, R>,
output: GlobalArgsLaunch<'a, R>,
problem: MatmulProblem,
line_sizes: MatmulLineSizes,
blueprint_strategy: &BlueprintStrategy<(), A>,
) -> Result<(), MatmulSetupError> {
launch_kernel_virtual::<FusedMatmulArgs, R, A>(
client,
input,
output,
(),
problem,
line_sizes,
blueprint_strategy,
)
}

View File

@@ -0,0 +1,269 @@
use super::optimization::MatmulOptimizationTuneArg;
use crate::{
CubeFusionHandle,
engine::trace::TuneOutput,
optim::matmul::{AcceleratedTileKind, FusedMatmulSelector},
tune::{TuneContext, TuneInput},
};
use burn_fusion::stream::Context;
use cubecl::{
AutotuneKey, CubeElement, CubeTuneId, Runtime,
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
};
use cubek::matmul::{
definition::MatmulKind,
launch::{MatmulAutotuneKey, MatmulGlobalScale, should_tune_double_buffering},
};
use serde::{Deserialize, Serialize};
/// Autotune key for fused matmul optimizations.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct FusedMatmulAutotuneKey {
/// Key describing the matmul problem itself.
matmul_key: MatmulAutotuneKey,
/// Number of output buffers of the optimization (anchored).
#[autotune(anchor)]
num_out_buffers: usize,
/// Number of operations fused in the optimization (anchored).
#[autotune(anchor)]
num_ops: usize,
}
/// Executes autotune on matmul operations
///
/// Builds the set of tunables (the fallback plus one tunable per fused matmul selector) and
/// runs the local tuner with the given optimization and context.
pub fn fused_matmul_autotune<R: Runtime, BT: CubeElement>(
optimization: MatmulOptimizationTuneArg<R>,
context: &mut Context<CubeFusionHandle<R>>,
) {
static TUNER: LocalTuner<FusedMatmulAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
// Priority levels used by the groups and tunables below.
const PRIORITY_MAX: i8 = 3;
const PRIORITY_HIGH: i8 = 2;
const PRIORITY_MEDIUM: i8 = 1;
const PRIORITY_MIN: i8 = 0;
// Group for tunables using CMMA tiles.
let cmma = TuneGroup::<FusedMatmulAutotuneKey>::new("cmma", |key| {
if matches!(
key.matmul_key.analysis.kind,
MatmulKind::General
// Those variants are just because the unit alternatives aren't very good yet.
| MatmulKind::VecMat | MatmulKind::MatVec
) {
PRIORITY_MAX
} else {
PRIORITY_MEDIUM
}
});
// Group for tunables using MMA tiles.
let mma = TuneGroup::<FusedMatmulAutotuneKey>::new("mma", |key| {
if matches!(
key.matmul_key.analysis.kind,
// General is usually bad, but I think shapes like 16x8196 would be classed as
// general and are very good with MMA
// Should highly degenerated matrices that aren't VecMat have their own class?
MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec
) {
PRIORITY_MAX
} else {
PRIORITY_MEDIUM
}
});
// Group that gets the maximum priority when either operand has a power-of-two factor of 0.
let odd = TuneGroup::<FusedMatmulAutotuneKey>::new("odd", |key| {
if key.matmul_key.definition.lhs_pow2_factor == 0
|| key.matmul_key.definition.rhs_pow2_factor == 0
{
PRIORITY_MAX
} else {
PRIORITY_MIN
}
});
// Group for unit and vec-mat routines: maximum priority for non-general matmuls or a
// small global scale.
let unit = TuneGroup::<FusedMatmulAutotuneKey>::new("unit", |key| {
if !matches!(key.matmul_key.analysis.kind, MatmulKind::General)
|| matches!(
key.matmul_key.analysis.scale_global,
MatmulGlobalScale::Small
)
{
PRIORITY_MAX
} else {
PRIORITY_MIN
}
});
// Returns `max` when double buffering should be tuned for this key, `min` otherwise.
fn double_buffering_priority(key: &FusedMatmulAutotuneKey, max: i8, min: i8) -> i8 {
if should_tune_double_buffering(key.num_out_buffers > 1, &key.matmul_key) {
max
} else {
min
}
}
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>).with(Tunable::new(
"fused_matmul_fallback",
tune_fallback::<R, BT>,
)); // First one should always work.
// Unit matmuls
for (selector, double_buf) in [
(FusedMatmulSelector::SimpleUnit, false),
(FusedMatmulSelector::DoubleUnit, true),
(FusedMatmulSelector::SimpleVecMat, false),
(FusedMatmulSelector::DoubleVecMat, true),
] {
set = set.with(
Tunable::new(selector.name(), move |input| {
tune_fused::<R, BT>(input, selector)
})
.group(&unit, move |key| match double_buf {
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
false => PRIORITY_MAX,
}),
);
}
// Accelerated matmuls
for (tile_matmul, group) in [
(AcceleratedTileKind::Cmma, &cmma),
(AcceleratedTileKind::Mma, &mma),
] {
// Each entry is (selector, whether it is double buffered, optional extra group).
for (selector, double_buf, extra_group) in [
(
FusedMatmulSelector::Simple {
multi_rows: false,
tile_matmul,
},
false,
None,
),
(
FusedMatmulSelector::Simple {
multi_rows: true,
tile_matmul,
},
false,
None,
),
(
FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul },
true,
None,
),
(
FusedMatmulSelector::DoubleBuffering {
specialized: false,
tile_matmul,
},
true,
None,
),
(
FusedMatmulSelector::DoubleBuffering {
specialized: true,
tile_matmul,
},
true,
Some(&odd),
),
] {
let mut tunable = Tunable::new(selector.name(), move |input| {
tune_fused::<R, BT>(input, selector)
})
.group(group, move |key| match double_buf {
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
false => PRIORITY_MAX,
});
if let Some(group) = extra_group {
tunable = tunable.group(group, |_| PRIORITY_MAX);
}
set = set.with(tunable);
}
}
set
});
// The client is cloned because `optimization` is moved into the tune input below.
TUNER.execute(
&CubeTuneId::new(&optimization.info.client, &optimization.info.device),
&optimization.info.client.clone(),
tunables,
TuneInput::new(context, optimization),
);
}
/// Generates the autotune key for a fused matmul.
///
/// The key combines the matmul key (shapes, strides, dtypes and quantization schemes of the
/// operands) with the number of output buffers and the number of fused operations.
///
/// Panics if the tune context is a fork, or if a matmul tensor is missing from the context.
pub(crate) fn create_key<R: Runtime>(
    input: &TuneInput<R, MatmulOptimizationTuneArg<R>>,
) -> FusedMatmulAutotuneKey {
    let opt = input.optimization();
    let TuneContext::Original(context) = input.context() else {
        panic!("Not supported when generating key");
    };

    let matmul = &opt.info.matmul;
    let read_only = burn_ir::TensorStatus::ReadOnly;

    let lhs = context.tensors.get(&matmul.op.lhs.id).unwrap();
    let rhs = context.tensors.get(&matmul.op.rhs.id).unwrap();
    let out = context.tensors.get(&matmul.op.out.id).unwrap();

    let lhs_strides = context.handles.get_handle(&lhs.id, &read_only).strides;
    let rhs_strides = context.handles.get_handle(&rhs.id, &read_only).strides;

    let matmul_key = MatmulAutotuneKey::generate(
        &opt.info.client,
        &lhs.shape,
        &rhs.shape,
        &lhs_strides,
        &rhs_strides,
        lhs.dtype.into(),
        rhs.dtype.into(),
        out.dtype.into(),
        matmul.lhs.scheme(),
        matmul.rhs.scheme(),
    );

    let num_out_buffers = opt.info.num_output_buffers();
    let num_ops = opt.info.num_ops_fused();
    FusedMatmulAutotuneKey::new(matmul_key, num_out_buffers, num_ops)
}
/// Input generator for the tunable set: returns a clone of the given input and ignores the
/// key.
fn input_gen<R: Runtime>(
_key: &FusedMatmulAutotuneKey,
input: &TuneInput<R, MatmulOptimizationTuneArg<R>>,
) -> TuneInput<R, MatmulOptimizationTuneArg<R>> {
input.clone()
}
/// Tunable body that runs the fused matmul with the given selector.
///
/// On the original context, a failed fused execution falls back to `tune_fallback`. On a
/// forked context, the error is returned as its debug string instead.
fn tune_fused<R: Runtime, BT: CubeElement>(
input: TuneInput<R, MatmulOptimizationTuneArg<R>>,
selector: FusedMatmulSelector,
) -> Result<TuneOutput<R>, String> {
let optimization = input.optimization();
let context = input.context();
match context {
TuneContext::Original(context) => match optimization.execute_fused::<BT>(context, selector)
{
Ok(out) => Ok(out),
Err(_) => {
// Early return: the fallback already yields a `String` error type, so it
// bypasses the `map_err` below.
return tune_fallback::<R, BT>(input);
}
},
TuneContext::Fork(mut context_owned) => {
optimization.execute_fused::<BT>(&mut context_owned.as_context(), selector)
}
}
.map_err(|e| format!("{e:?}"))
}
/// Tunable body that runs the fallback path on either the original or a forked context.
///
/// Always returns `Ok`; the fallback execution itself has no error value.
fn tune_fallback<R: Runtime, BT: CubeElement>(
    input: TuneInput<R, MatmulOptimizationTuneArg<R>>,
) -> Result<TuneOutput<R>, String> {
    let optimization = input.optimization();
    let context = input.context();

    let output = match context {
        TuneContext::Fork(mut forked) => {
            optimization.execute_fallback::<BT>(&mut forked.as_context())
        }
        TuneContext::Original(original) => optimization.execute_fallback::<BT>(original),
    };

    Ok(output)
}

View File

@@ -0,0 +1,8 @@
pub mod elemwise;
pub mod matmul;
pub mod reduce;
pub mod reduce_broadcasted;
mod base;
pub use base::*;

View File

@@ -0,0 +1,208 @@
use crate::engine::codegen::{
io::{ref_buffer_len, ref_len, ref_line_size, ref_shape, ref_stride},
ir::{FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsExpand, LocalArgs, LocalArgsExpand},
kernel::{fuse_on_read, fuse_on_write, init_locals},
};
use cubecl::prelude::*;
use cubek::reduce::components::args::{ReduceArgs, ReduceDType};
/// Marker type implementing `ReduceArgs` for fused reductions.
#[derive(Clone)]
pub struct FusedReduceArgs;
/// Input of a fused reduction.
#[derive(CubeType, CubeLaunch)]
pub struct FusedReduceInput {
/// Global arguments of the fused kernel.
pub global: GlobalArgs,
/// Block configuration used when reading the input.
#[cube(comptime)]
pub config: FuseBlockConfig,
/// Argument that is read as the reduction input.
#[cube(comptime)]
pub arg: FuseArg,
}
/// Output of a fused reduction.
#[derive(CubeType, CubeLaunch)]
pub struct FusedReduceOutput {
/// Global arguments of the fused kernel.
pub global: GlobalArgs,
/// Block configuration used when writing the output.
#[cube(comptime)]
pub config: FuseBlockConfig,
/// Argument that receives the reduction result.
#[cube(comptime)]
pub arg: FuseArg,
}
/// State threaded through the `ReduceArgs` callbacks of a fused reduction.
///
/// The global arguments and locals are held as raw pointers; the expand-time counterpart is
/// `FusedReduceStateExpand`.
pub struct FusedReduceState {
// Pointer to the input global arguments.
inputs: *const GlobalArgs,
// Pointer to the output global arguments.
outputs: *mut GlobalArgs,
// Locals used by the read (input) side.
locals_on_read: *mut LocalArgs,
// Locals used by the write (output) side.
locals_on_write: *mut LocalArgs,
// Block configuration of the read side.
config_on_read: FuseBlockConfig,
// Block configuration of the write side.
config_on_write: FuseBlockConfig,
// TODO: Should be a list when multiple blocks are there.
input: FuseArg,
// Argument that receives the reduction result.
out: FuseArg,
}
/// Expand type of `FusedReduceState`: same fields, with the raw pointers replaced by the
/// expand types of the global arguments and locals.
#[derive(Clone)]
pub struct FusedReduceStateExpand {
// Expanded input global arguments.
inputs: GlobalArgsExpand,
// Expanded output global arguments.
outputs: GlobalArgsExpand,
// Expanded locals of the read side.
locals_on_read: LocalArgsExpand,
// Expanded locals of the write side.
locals_on_write: LocalArgsExpand,
// Block configuration of the read side.
config_on_read: FuseBlockConfig,
// Block configuration of the write side.
config_on_write: FuseBlockConfig,
// Argument read as the reduction input.
input: FuseArg,
// Argument that receives the reduction result.
out: FuseArg,
}
// `ReduceArgs` implementation that reads the reduction input through `fuse_on_read` and
// writes the result through `fuse_on_write`.
//
// NOTE(review): every `unsafe` block below dereferences a raw pointer stored in
// `FusedReduceState`. Presumably only the cube expansion of this code is executed, where the
// state is a `FusedReduceStateExpand` holding expand values instead of pointers — confirm.
#[cube]
impl ReduceArgs for FusedReduceArgs {
type Input<E: Numeric> = FusedReduceInput;
type Output<E: Numeric> = FusedReduceOutput;
type State<P: ReduceDType> = FusedReduceState;
// Initializes one set of locals per side (input config for reads, output config for
// writes) and builds the state from them.
fn init_state<P: ReduceDType>(
input: &Self::Input<P::In>,
output: &mut Self::Output<P::Out>,
) -> Self::State<P> {
let mut locals_read = init_locals(&input.global, &mut output.global, &input.config);
let mut locals_write = init_locals(&input.global, &mut output.global, &output.config);
// TODO Add stuff from previous blocks to the local of each block.
FusedReduceState::new(input, output, &mut locals_read, &mut locals_write)
}
// Reads the value of the state's input argument at `index` through `fuse_on_read`.
fn read_input<P: ReduceDType>(state: &Self::State<P>, index: usize) -> Line<P::In> {
let value = fuse_on_read::<P::In>(
unsafe { &(*state.inputs) },
unsafe { &mut (*state.outputs) },
unsafe { &mut (*state.locals_on_read) },
index,
comptime! {
let mut sequence = Sequence::new();
// TODO: Register local arguments from previous blocks.
sequence.push(state.input.clone());
sequence
},
&state.config_on_read,
)[0];
value
}
// Always returns an empty line of size 1; the state and index are ignored.
fn read_output<P: ReduceDType>(_state: &Self::State<P>, _index: usize) -> Line<P::Out> {
Line::empty(1usize)
}
// Registers `value` for the state's output argument and writes it at `index` through
// `fuse_on_write`.
fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: usize, value: Line<P::Out>) {
let mut values = Registry::<FuseArg, Line<P::Out>>::new();
let mut args = comptime![Vec::<FuseArg>::new()];
values.insert(comptime![state.out.clone()], value);
comptime![args.push(state.out.clone())];
fuse_on_write(
unsafe { &(*state.inputs) },
unsafe { &mut (*state.outputs) },
unsafe { &mut (*state.locals_on_write) },
index,
values,
args,
&state.config_on_write,
);
}
// Length of the reference layout on the read side.
fn len_input<P: ReduceDType>(state: &Self::State<P>) -> usize {
ref_len(
unsafe { &(*state.inputs) },
unsafe { &(*state.outputs) },
unsafe { &(*state.locals_on_read) },
&state.config_on_read,
)
}
// Length of the reference layout on the write side.
fn len_output<P: ReduceDType>(state: &Self::State<P>) -> usize {
ref_len(
unsafe { &(*state.inputs) },
unsafe { &(*state.outputs) },
unsafe { &(*state.locals_on_write) },
&state.config_on_write,
)
}
// Buffer length of the reference layout on the read side.
fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> usize {
ref_buffer_len(
unsafe { &(*state.inputs) },
unsafe { &(*state.outputs) },
unsafe { &(*state.locals_on_read) },
&state.config_on_read,
)
}
// Buffer length of the reference layout on the write side.
fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> usize {
ref_buffer_len(
unsafe { &(*state.inputs) },
unsafe { &(*state.outputs) },
unsafe { &(*state.locals_on_write) },
&state.config_on_write,
)
}
// Rank taken from the read-side block configuration.
fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> usize {
state.config_on_read.rank.runtime()
}
// Rank taken from the write-side block configuration.
fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> usize {
state.config_on_write.rank.runtime()
}
// Shape, stride and line size accessors: each reads the reference layout stored in the
// locals of the corresponding side.
fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
ref_shape(unsafe { &(*state.locals_on_read) }, dim)
}
fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
ref_shape(unsafe { &(*state.locals_on_write) }, dim)
}
fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
ref_stride(unsafe { &(*state.locals_on_read) }, dim)
}
fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
ref_stride(unsafe { &(*state.locals_on_write) }, dim)
}
fn line_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(LineSize) {
ref_line_size(unsafe { &(*state.locals_on_read) })
}
fn line_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(LineSize) {
ref_line_size(unsafe { &(*state.locals_on_write) })
}
}
// Constructor of the fused reduce state.
#[cube]
impl FusedReduceState {
/// Builds the state from the reduction input and output and the two sets of locals.
///
/// The global arguments and locals are stored by reference, while the configurations and
/// the input/output arguments are cloned.
pub fn new(
inputs: &FusedReduceInput,
outputs: &mut FusedReduceOutput,
locals_on_read: &mut LocalArgs,
locals_on_write: &mut LocalArgs,
) -> FusedReduceState {
FusedReduceState {
inputs: &inputs.global,
outputs: &mut outputs.global,
locals_on_read,
locals_on_write,
config_on_read: comptime![inputs.config.clone()],
config_on_write: comptime![outputs.config.clone()],
input: comptime![inputs.arg.clone()],
out: comptime![outputs.arg.clone()],
}
}
}
// Declares `FusedReduceStateExpand` as the expand type of `FusedReduceState`.
impl CubeType for FusedReduceState {
type ExpandType = FusedReduceStateExpand;
}
// `into_mut` is the identity for the expanded state: the value is returned unchanged and
// the scope is not used.
impl IntoMut for FusedReduceStateExpand {
fn into_mut(self, _context: &mut Scope) -> Self {
self
}
}
// Uses the default `CubeDebug` behavior for the expanded state.
impl CubeDebug for FusedReduceStateExpand {}

View File

@@ -0,0 +1,328 @@
use super::{
ReduceSettings,
optimization::{FusedReduce, ReduceInstruction, ReduceOptimization},
};
use crate::{
engine::{
codegen::ir::FuseType,
fuser::TraceOperationFuser,
settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},
},
optim::CubeOptimization,
};
use burn_fusion::{FuserStatus, OperationFuser};
use burn_ir::{NumericOperationIr, OperationIr, ReduceDimOpIr};
use burn_std::Shape;
use cubecl::Runtime;
/// Fuses element wise operations around a reduce operation.
pub struct ReduceFuser<R: Runtime> {
/// Main fuser, created with the read settings.
pub(crate) fuser: TraceOperationFuser,
/// Fallback fuser for the read side, created with the default settings.
pub(crate) fuser_read_fallback: TraceOperationFuser,
/// Fallback fuser for the write side, created with the default settings.
fuser_write_fallback: TraceOperationFuser,
/// Fuse settings kept for the write side.
settings_write: FuseSettings,
/// Device used to obtain the compute client.
pub(crate) device: R::Device,
/// The registered reduce; `None` until one has been fused.
pub(crate) reduce: Option<FusedReduce>,
/// Reduce settings given at construction.
settings: ReduceSettings,
}
impl<R: Runtime> Clone for ReduceFuser<R> {
    /// Clones every field; the two settings values are `Copy` and are copied directly.
    fn clone(&self) -> Self {
        let Self {
            fuser,
            fuser_read_fallback,
            fuser_write_fallback,
            settings_write,
            device,
            reduce,
            settings,
        } = self;

        Self {
            fuser: fuser.clone(),
            fuser_read_fallback: fuser_read_fallback.clone(),
            fuser_write_fallback: fuser_write_fallback.clone(),
            settings_write: *settings_write,
            device: device.clone(),
            reduce: reduce.clone(),
            settings: *settings,
        }
    }
}
/// Summary of what a [`ReduceFuser`] currently holds, as returned by `ReduceFuser::reduce_info`.
#[derive(Debug)]
pub enum ReduceFuserInfo {
/// A reduce has been registered: carries the shape of the reduce input and the reduced axis.
FusedReduce { shape_input_id: Shape, axis: usize },
/// No reduce registered yet: carries the current output shape of the read-side fallback fuser.
FusedElemwise { shape_id: Shape },
}
impl<R: Runtime> ReduceFuser<R> {
pub fn new(device: R::Device, bool_precision: FuseType, settings: ReduceSettings) -> Self {
let client = R::client(&device);
let props = client.properties();
let max_bindings = props.hardware.max_bindings;
let settings_read = FuseSettings {
// Inplace would work, but not when we have a concrete output to write too.
inplace: true,
ref_layout: RefLayoutSetting::OnlyContiguous,
broadcast: false,
output_shape_updates: true,
vectorization: VectorizationSetting::Activated,
};
let settings_write = FuseSettings {
inplace: false,
output_shape_updates: false,
vectorization: VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos: 0 },
broadcast: false,
ref_layout: RefLayoutSetting::OnlyContiguous,
};
let settings_fallback = FuseSettings::default();
Self {
fuser: TraceOperationFuser::new(max_bindings, bool_precision, settings_read),
fuser_read_fallback: TraceOperationFuser::new(
max_bindings,
bool_precision,
settings_fallback,
),
fuser_write_fallback: TraceOperationFuser::new(
max_bindings,
bool_precision,
settings_fallback,
),
settings_write,
device,
reduce: None,
settings,
}
}
pub fn reduce_info(&self) -> ReduceFuserInfo {
match &self.reduce {
Some(reduce) => {
let shape_input_id = reduce.op.input.shape.clone();
let axis = reduce.axis;
ReduceFuserInfo::FusedReduce {
shape_input_id,
axis,
}
}
None => {
let shape_id = self.fuser_read_fallback.current_output_shape.clone();
ReduceFuserInfo::FusedElemwise { shape_id }
}
}
}
fn on_reduce(&mut self, op: &ReduceDimOpIr, inst: ReduceInstruction) {
// TODO: Fix: we need to have fuse-on-read with an identity block.
//
// if self.fuser.num_ops == 0 && false {
// self.fuser.current_output_shape = op.input.shape.dims.clone();
// } else if self.fuser.current_output_shape != op.input.shape.dims {
if self.fuser.current_output_shape != op.input.shape {
self.fuser.close();
self.fuser_read_fallback.close();
return;
}
let [input] = self
.fuser
.next_block([&op.input], self.settings_write, false);
let output = self.fuser.output_unhandled(&op.out);
let axis = op.axis;
let fuse_on_write_activated = match self.settings {
ReduceSettings::Always => true,
// We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise
// vectorization is impossible. Only [LineMode::Perpendicular] supports vectorization.
//
// We could still fuse some output operations, but it would probably lead to worse performance.
ReduceSettings::OnlyParallel => axis != op.input.shape.rank() - 1,
ReduceSettings::Never => false,
};
if !fuse_on_write_activated {
self.fuser.close();
}
let acc = match inst {
ReduceInstruction::Mean | ReduceInstruction::Prod | ReduceInstruction::Sum => {
match input.precision() {
FuseType::F16 | FuseType::BF16 => FuseType::F32,
FuseType::I16 | FuseType::I8 => FuseType::I32,
FuseType::U16 | FuseType::U8 => FuseType::U32,
_ => input.precision(),
}
}
_ => input.precision(),
};
self.reduce = Some(FusedReduce {
input,
output,
acc,
axis,
op: op.clone(),
use_planes: false,
shared: false,
inst,
});
self.fuser_read_fallback.close();
}
fn on_elemwise_read(&mut self, operation: &OperationIr) {
let can_register =
self.fuser.can_fuse(operation) && self.fuser_read_fallback.can_fuse(operation);
match can_register {
true => {
self.fuser.fuse(operation);
self.fuser_read_fallback.fuse(operation);
}
false => {
self.fuser.close();
self.fuser_read_fallback.close();
}
};
}
fn on_elemwise_write(&mut self, operation: &OperationIr) {
let can_register =
self.fuser.can_fuse(operation) && self.fuser_write_fallback.can_fuse(operation);
match can_register {
true => {
self.fuser.fuse(operation);
self.fuser_write_fallback.fuse(operation);
}
false => {
self.fuser.close();
self.fuser_write_fallback.close();
}
};
}
}
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for ReduceFuser<R> {
fn fuse(&mut self, operation: &OperationIr) {
if let FuserStatus::Closed = self.fuser.status() {
return;
}
if self.reduce.is_none() {
if let OperationIr::NumericFloat(_, op) = operation {
match op {
NumericOperationIr::SumDim(op) => {
self.on_reduce(op, ReduceInstruction::Sum);
}
NumericOperationIr::MeanDim(op) => {
self.on_reduce(op, ReduceInstruction::Mean);
}
NumericOperationIr::ProdDim(op) => {
self.on_reduce(op, ReduceInstruction::Prod);
}
NumericOperationIr::ArgMax(op) => {
self.on_reduce(op, ReduceInstruction::ArgMax);
}
NumericOperationIr::ArgMin(op) => {
self.on_reduce(op, ReduceInstruction::ArgMin);
}
NumericOperationIr::MinDim(op) => {
self.on_reduce(op, ReduceInstruction::Min);
}
NumericOperationIr::MaxDim(op) => {
self.on_reduce(op, ReduceInstruction::Max);
}
NumericOperationIr::MaxAbsDim(op) => {
self.on_reduce(op, ReduceInstruction::MaxAbs);
}
_ => {
self.on_elemwise_read(operation);
}
};
} else if let OperationIr::NumericInt(_, op) = operation {
match op {
NumericOperationIr::SumDim(op) => {
self.on_reduce(op, ReduceInstruction::Sum);
}
NumericOperationIr::MeanDim(op) => {
self.on_reduce(op, ReduceInstruction::Mean);
}
NumericOperationIr::ProdDim(op) => {
self.on_reduce(op, ReduceInstruction::Prod);
}
NumericOperationIr::ArgMax(op) => {
self.on_reduce(op, ReduceInstruction::ArgMax);
}
NumericOperationIr::ArgMin(op) => {
self.on_reduce(op, ReduceInstruction::ArgMin);
}
NumericOperationIr::MinDim(op) => {
self.on_reduce(op, ReduceInstruction::Min);
}
NumericOperationIr::MaxDim(op) => {
self.on_reduce(op, ReduceInstruction::Max);
}
NumericOperationIr::MaxAbsDim(op) => {
self.on_reduce(op, ReduceInstruction::MaxAbs);
}
_ => {
self.on_elemwise_read(operation);
}
};
} else {
self.on_elemwise_read(operation);
}
} else {
self.on_elemwise_write(operation);
}
}
fn finish(&mut self) -> CubeOptimization<R> {
let client = R::client(&self.device);
let trace = self.fuser.finish();
let trace_read_fallback = self.fuser_read_fallback.finish();
let trace_write_fallback = self.fuser_write_fallback.finish();
let fuse_reduce = self.reduce.as_ref().unwrap();
let reduce = ReduceOptimization::new(
trace,
trace_read_fallback,
trace_write_fallback,
client,
self.device.clone(),
self.len(),
self.fuser_read_fallback.len(),
fuse_reduce.clone(),
self.settings,
);
CubeOptimization::Reduce(reduce)
}
fn reset(&mut self) {
self.fuser.reset();
self.fuser_read_fallback.reset();
self.fuser_write_fallback.reset();
self.reduce = None;
}
fn status(&self) -> burn_fusion::FuserStatus {
self.fuser.status()
}
fn properties(&self) -> burn_fusion::FuserProperties {
let mut properties = self.fuser.properties();
if self.reduce.is_some() {
properties.ready = true;
properties.score += 1;
} else {
properties.ready = false;
};
properties
}
fn len(&self) -> usize {
self.fuser.len() + if self.reduce.is_some() { 1 } else { 0 }
}
fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
Box::new(self.clone())
}
}

View File

@@ -0,0 +1,8 @@
mod fuser;
mod optimization;
pub(crate) mod args;
pub(crate) mod tune;
pub use fuser::*;
pub use optimization::*;

View File

@@ -0,0 +1,492 @@
use super::args::{
FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,
};
#[cfg(feature = "autotune")]
use super::tune::fused_reduce_autotune;
use crate::{
CubeFusionHandle, FallbackOperation,
engine::{
codegen::ir::{
FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout,
multi_block_variables_init,
},
launch::{
FuseTraceLauncher,
runner::{TraceRunner, Vectorization},
},
trace::{FuseTrace, TraceError, TuneOutput},
},
optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},
};
use burn_fusion::stream::Context;
use burn_ir::ReduceDimOpIr;
use burn_std::DType;
use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};
use cubek::reduce::{
LineMode, ReduceDtypes, ReduceError,
components::instructions::ReduceOperationConfig,
init_tensors,
launch::{RoutineStrategy, reduce_kernel_virtual},
routines::{
ReduceBlueprint, ReduceLaunchSettings, ReduceLineSettings, ReduceProblem, Routine,
cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[cfg(not(feature = "autotune"))]
use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};
/// A fused reduce optimization; a thin handle around shared, reference-counted information.
pub struct ReduceOptimization<R: Runtime> {
/// Everything needed to execute the optimization, shared with the tuning arguments.
pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
}
/// Data backing a [`ReduceOptimization`]: the traces, the compute client/device and the reduce
/// description.
pub(crate) struct ReduceOptimizationInfo<R: Runtime> {
/// Fully fused trace, launched by `execute_fused`.
pub(crate) trace: FuseTrace,
/// Element-wise trace launched before the fallback reduce in `execute_fallback`.
trace_read_fallback: FuseTrace,
/// Element-wise trace launched after the fallback reduce in `execute_fallback`.
trace_write_fallback: FuseTrace,
/// Compute client used for launches.
pub(crate) client: ComputeClient<R>,
/// Device the client was created for.
pub(crate) device: R::Device,
/// Total number of fused operations.
pub(crate) len: usize,
/// Number of operations fused on the read side; also the index handed to the fallback factory.
pub(crate) len_read: usize,
/// Description of the fused reduce operation.
pub(crate) reduce: FusedReduce,
/// Fuse-on-write policy the optimization was built with.
settings: ReduceSettings,
}
impl<R: Runtime> ReduceOptimizationInfo<R> {
    /// Rebuilds the info from a serialized `state`, creating the compute client from `device`.
    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
        let client = R::client(device);
        let ReduceOptimizationState {
            trace,
            trace_read_fallback,
            trace_write_fallback,
            reduce,
            len,
            len_read,
            settings,
        } = state;

        Self {
            trace,
            trace_read_fallback,
            trace_write_fallback,
            client,
            device: device.clone(),
            len,
            len_read,
            reduce,
            settings,
        }
    }

    /// Snapshots everything except the client and the device into a serializable state.
    pub fn to_state(&self) -> ReduceOptimizationState {
        let Self {
            trace,
            trace_read_fallback,
            trace_write_fallback,
            len,
            len_read,
            reduce,
            settings,
            ..
        } = self;

        ReduceOptimizationState {
            trace: trace.clone(),
            trace_read_fallback: trace_read_fallback.clone(),
            trace_write_fallback: trace_write_fallback.clone(),
            len: *len,
            len_read: *len_read,
            reduce: reduce.clone(),
            settings: *settings,
        }
    }
}
/// Policy controlling whether operations that follow the reduce may be fused with it
/// (fuse-on-write).
#[derive(Serialize, Deserialize, Copy, Clone)]
pub enum ReduceSettings {
/// Fuse-on-write is always activated.
Always,
/// We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise
/// vectorization is impossible. Only [LineMode::Perpendicular] supports vectorization.
///
/// We could still fuse some output operations, but it would probably lead to worse performance.
OnlyParallel,
/// Fuse-on-write is never activated.
Never,
}
/// Argument passed to the execution/tuning functions: the shared optimization info plus the
/// fallback reduce operation.
pub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {
/// Shared optimization data.
pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
/// Unfused reduce operation, run between the read and write fallback traces.
pub(crate) fallback: Arc<Box<dyn FallbackOperation<R>>>,
}
impl<R: Runtime> Clone for ReduceOptimizationTuneArg<R> {
fn clone(&self) -> Self {
Self {
info: self.info.clone(),
fallback: self.fallback.clone(),
}
}
}
/// Kind of reduction performed by a fused reduce.
///
/// Each variant is mapped one-to-one to the `ReduceOperationConfig` variant of the same name
/// when the kernel is launched.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ReduceInstruction {
ArgMax,
ArgMin,
Mean,
Prod,
Sum,
Max,
Min,
MaxAbs,
}
/// A reduce fallback that can be run against a fusion context.
///
/// NOTE(review): `run` has the same signature as `FallbackOperation::run`, and this trait is not
/// referenced elsewhere in this file — confirm whether it is still needed.
pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
/// Executes the fallback using the given context.
fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
}
/// Serializable snapshot of a reduce optimization.
///
/// It holds the same data as `ReduceOptimizationInfo` except the compute client and the device,
/// which are recreated by `from_state`.
#[derive(Serialize, Deserialize)]
pub struct ReduceOptimizationState {
pub(crate) trace: FuseTrace,
pub(crate) trace_read_fallback: FuseTrace,
pub(crate) trace_write_fallback: FuseTrace,
pub(crate) reduce: FusedReduce,
pub(crate) len: usize,
pub(crate) len_read: usize,
pub(crate) settings: ReduceSettings,
}
/// Compact debug output: only the two operation counts are printed, not the traces.
impl core::fmt::Debug for ReduceOptimizationState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (len_read, len_total) = (self.len_read, self.len);
        write!(f, "{{ len_read: {}, len_total: {} }}", len_read, len_total)
    }
}
/// Description of a reduce operation registered inside a fused trace.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusedReduce {
/// Fuse argument from which the reduce reads its input.
pub(crate) input: FuseArg,
/// Fuse argument to which the reduce writes its output.
pub(crate) output: FuseArg,
/// Precision used for accumulation.
pub(crate) acc: FuseType,
/// Axis being reduced.
pub(crate) axis: usize,
/// The original reduce operation (input/output tensor descriptions and axis).
pub(crate) op: ReduceDimOpIr,
// NOTE(review): `use_planes` and `shared` are not read anywhere in this file — confirm usage.
pub(crate) use_planes: bool,
pub(crate) shared: bool,
/// The reduction instruction to execute.
pub(crate) inst: ReduceInstruction,
}
/// Trace runner that launches a fused reduce with a given routine strategy.
///
/// The `new(reduce, strategy)` constructor is generated by `#[derive(new)]`.
#[derive(new)]
pub struct FusedReduceLaunch<'a> {
/// The reduce to launch.
reduce: &'a FusedReduce,
/// Routine (unit, plane or cube) and blueprint strategy to use.
strategy: RoutineStrategy,
}
/// Errors that can be reported when launching a fused reduce.
#[derive(Debug)]
pub enum FusedReduceError {
/// Error coming from the reduce routines or from the kernel launch.
Reduce(ReduceError),
/// A selection was invalid; carries a static description.
InvalidSelection(Box<&'static str>),
/// The provided input cannot be handled.
InvalidInput,
}
impl From<ReduceError> for FusedReduceError {
fn from(value: ReduceError) -> Self {
Self::Reduce(value)
}
}
impl<R: Runtime> ReduceOptimizationTuneArg<R> {
/// Runs the fully fused trace (`info.trace`) with the given reduce routine `strategy`.
///
/// The launcher's result is returned unchanged, so a failed launch surfaces as a `TraceError`.
pub fn execute_fused<BT: CubeElement>(
&self,
context: &mut Context<'_, CubeFusionHandle<R>>,
strategy: RoutineStrategy,
) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {
let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);
let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
launcher.launch::<BT>(&self.info.client, &self.info.device, context)
}
/// Runs the unfused path: the read-side element-wise trace, then the fallback reduce
/// operation, then the write-side element-wise trace. The two trace outputs are merged.
///
/// # Panics
///
/// Panics if either element-wise trace fails to launch (both results are unwrapped).
pub fn execute_fallback<BT: CubeElement>(
&self,
context: &mut Context<'_, CubeFusionHandle<R>>,
) -> TuneOutput<R> {
let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);
#[allow(unused_mut)] // It is used when `autotune-checks` is activated.
let mut output_read = launcher
.launch::<BT>(&self.info.client, &self.info.device, context)
.unwrap();
self.fallback.run(context);
// With `autotune-checks`, the reduce output produced by the fallback is added to the
// checked handles so that it is part of the returned output.
#[cfg(feature = "autotune-checks")]
if let TuneOutput::Checked { handles } = &mut output_read {
let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();
let handle_out = context
.handles
.get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);
handles.insert(
self.info.reduce.op.out.id,
(out_desc.shape.dims.clone(), handle_out.clone()),
);
}
let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);
let output_write = launcher
.launch::<BT>(&self.info.client, &self.info.device, context)
.unwrap();
output_read.merge(output_write)
}
}
#[allow(clippy::too_many_arguments)]
impl<R: Runtime> ReduceOptimization<R> {
pub fn new(
trace: FuseTrace,
trace_read_fallback: FuseTrace,
trace_write_fallback: FuseTrace,
client: ComputeClient<R>,
device: R::Device,
len: usize,
len_read: usize,
reduce: FusedReduce,
settings: ReduceSettings,
) -> Self {
let info = ReduceOptimizationInfo {
trace,
trace_read_fallback,
trace_write_fallback,
client,
device,
len,
len_read,
reduce,
settings,
};
Self {
info: Arc::new(info),
}
}
/// Execute the optimization.
pub fn execute<BT: CubeElement>(
&mut self,
context: &mut Context<'_, CubeFusionHandle<R>>,
fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
) {
// The index of the fallback reduce is the number of ops fused as read.
let fallback = fallback(self.info.len_read);
let arg = ReduceOptimizationTuneArg {
info: self.info.clone(),
fallback: Arc::new(fallback),
};
#[cfg(feature = "autotune")]
fused_reduce_autotune::<R, BT>(arg, context);
#[cfg(not(feature = "autotune"))]
if arg
.execute_fused::<BT>(
context,
RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
)
.is_err()
{
arg.execute_fallback::<BT>(context);
}
}
pub fn num_output_buffers(&self) -> usize {
self.info.trace_read_fallback.resources.outputs.len()
}
pub fn to_state(&self) -> ReduceOptimizationState {
ReduceOptimizationState {
trace: self.info.trace.clone(),
trace_read_fallback: self.info.trace_read_fallback.clone(),
trace_write_fallback: self.info.trace_write_fallback.clone(),
reduce: self.info.reduce.clone(),
len: self.info.len,
len_read: self.info.len_read,
settings: self.info.settings,
}
}
pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
let client = R::client(device);
let info = ReduceOptimizationInfo {
trace: state.trace,
trace_read_fallback: state.trace_read_fallback,
trace_write_fallback: state.trace_write_fallback,
reduce: state.reduce,
len: state.len,
len_read: state.len_read,
client,
device: device.clone(),
settings: state.settings,
};
Self {
info: Arc::new(info),
}
}
/// Returns the number of output buffers added by fusion.
pub fn num_ops_fused(&self) -> usize {
self.info.len
}
}
// TODO: Implement better vectorization here.
// For now the trait's default behavior is used as-is (nothing is overridden).
impl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}
impl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {
    type Error = FusedReduceError;

    /// Prepares and launches the fused reduce kernel.
    ///
    /// `configs` must contain the read block configuration followed by the write block
    /// configuration. The reference shape is taken from the outputs when the read block's
    /// reference layout is a concrete output, and from the inputs otherwise.
    ///
    /// # Errors
    ///
    /// - [`FusedReduceError::InvalidInput`] when fewer than two block configurations are given.
    /// - [`FusedReduceError::Reduce`] when the selected routine cannot be prepared or when the
    ///   kernel launch fails.
    fn run<'a>(
        &'a self,
        client: &'a ComputeClient<R>,
        inputs: GlobalArgsLaunch<'a, R>,
        outputs: GlobalArgsLaunch<'a, R>,
        configs: &'a [FuseBlockConfig],
    ) -> Result<(), FusedReduceError> {
        // Report a malformed trace as an error instead of panicking on an out-of-bounds index.
        let [config_read, config_write, ..] = configs else {
            return Err(FusedReduceError::InvalidInput);
        };

        let shape = match &config_read.ref_layout {
            RefLayout::Concrete(FuseArg::Output(..)) => {
                outputs.shape_ref(&config_read.ref_layout, config_read.rank)
            }
            _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),
        };

        // Number of independent reductions: product of every dimension except the reduced one.
        let reduce_count: usize = shape
            .iter()
            .enumerate()
            .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })
            .product();

        // `axis + 1 == rank` is `axis == rank - 1` without the underflow on a rank of zero.
        let line_mode = if self.reduce.axis + 1 == config_read.rank {
            LineMode::Parallel
        } else {
            LineMode::Perpendicular
        };

        let address_type = inputs
            .required_address_type()
            .max(outputs.required_address_type());

        let settings = ReduceLineSettings {
            line_mode,
            line_size_input: config_read.width,
            line_size_output: config_write.width,
        };
        let problem = ReduceProblem {
            vector_size: shape[self.reduce.axis],
            vector_count: reduce_count,
            axis: self.reduce.axis,
            dtypes: ReduceDtypes {
                input: self.reduce.op.input.dtype.into(),
                output: self.reduce.op.out.dtype.into(),
                accumulation: self.reduce.acc.into_elem().into(),
            },
            address_type,
        };

        let (blueprint, settings) = match self.strategy.clone() {
            RoutineStrategy::Unit(strategy) => {
                let routine = UnitRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Plane(strategy) => {
                let routine = PlaneRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Cube(strategy) => {
                let routine = CubeRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
        };

        let kwargs = ReduceKwArgs {
            client,
            inputs,
            outputs,
            axis: self.reduce.axis,
            config_fuse_read: config_read.clone(),
            config_fuse_write: config_write.clone(),
            input: self.reduce.input.clone(),
            output: self.reduce.output.clone(),
            blueprint,
            settings,
        };

        launch_reduce_mixed_precision(
            kwargs,
            self.reduce.inst,
            self.reduce.op.input.dtype,
            self.reduce.op.out.dtype,
            DType::from(self.reduce.acc.into_elem()),
        )
        .map_err(|err| FusedReduceError::Reduce(ReduceError::Launch(err)))
    }
}
/// Bundle of everything needed to launch the fused reduce kernel, passed by value to the
/// launch helpers.
struct ReduceKwArgs<'a, 'b, Run: Runtime> {
/// Client used for the launch.
client: &'b ComputeClient<Run>,
/// Global input arguments of the fused trace.
inputs: GlobalArgsLaunch<'a, Run>,
/// Global output arguments of the fused trace.
outputs: GlobalArgsLaunch<'a, Run>,
/// Axis being reduced.
axis: usize,
/// Blueprint returned by the routine's `prepare`.
blueprint: ReduceBlueprint,
/// Launch settings returned by the routine's `prepare` (cube count, cube dim, address type).
settings: ReduceLaunchSettings,
/// Configuration of the read block.
config_fuse_read: FuseBlockConfig,
/// Configuration of the write block.
config_fuse_write: FuseBlockConfig,
/// Fuse argument of the reduce input.
input: FuseArg,
/// Fuse argument of the reduce output.
output: FuseArg,
}
fn launch_reduce_mixed_precision<Run: Runtime>(
kwargs: ReduceKwArgs<'_, '_, Run>,
instruction: ReduceInstruction,
dtype_input: DType,
dtype_output: DType,
dtype_acc: DType,
) -> Result<(), LaunchError> {
let config = match instruction {
ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
ReduceInstruction::Prod => ReduceOperationConfig::Prod,
ReduceInstruction::Mean => ReduceOperationConfig::Mean,
ReduceInstruction::Sum => ReduceOperationConfig::Sum,
ReduceInstruction::Max => ReduceOperationConfig::Max,
ReduceInstruction::Min => ReduceOperationConfig::Min,
ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
};
launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)
}
/// Launches `reduce_kernel_fused` with the prepared blueprint and launch settings.
///
/// The three data types are converted and passed as the kernel's input, output and
/// accumulation types.
fn launch_reduce<Run: Runtime>(
kwargs: ReduceKwArgs<'_, '_, Run>,
inst: ReduceOperationConfig,
dtype_input: DType,
dtype_output: DType,
dtype_acc: DType,
) -> Result<(), LaunchError> {
// SAFETY: NOTE(review) — the launch is unchecked; presumably the cube count/dim computed by the
// routine's `prepare` keep every access in bounds. Confirm the invariant relied upon here.
unsafe {
reduce_kernel_fused::launch_unchecked::<Run>(
kwargs.client,
kwargs.settings.cube_count,
kwargs.settings.cube_dim,
kwargs.settings.address_type,
FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),
FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),
ScalarArg::new(kwargs.axis),
kwargs.blueprint,
inst,
dtype_input.into(),
dtype_output.into(),
dtype_acc.into(),
)
}
}
// Kernel entry point of the fused reduce.
//
// `In`, `Out` and `Acc` are defined from the three `StorageType` arguments; `blueprint` and
// `config` are compile-time values; `axis_reduce` is the axis to reduce.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn reduce_kernel_fused<In: Numeric, Out: Numeric, Acc: Numeric>(
input: &FusedReduceInput,
output: &mut FusedReduceOutput,
axis_reduce: usize,
#[comptime] blueprint: ReduceBlueprint,
#[comptime] config: ReduceOperationConfig,
#[define(In)] _input_dtype: StorageType,
#[define(Out)] _output_dtype: StorageType,
#[define(Acc)] _acc_dtype: StorageType,
) {
// Both block configurations initialize their variables in `output.global.variables`.
multi_block_variables_init(&input.config, &mut output.global.variables);
multi_block_variables_init(&output.config, &mut output.global.variables);
// Wrap the fused arguments as virtual tensors, then run the generic reduce kernel on them.
let (input, mut output) = init_tensors::<FusedReduceArgs, In, Out>(input, output);
reduce_kernel_virtual::<In, Out, Acc>(&input, &mut output, axis_reduce, blueprint, config);
}

View File

@@ -0,0 +1,196 @@
use super::optimization::ReduceOptimizationTuneArg;
use crate::{
CubeFusionHandle,
engine::trace::TuneOutput,
tune::{TuneContext, TuneInput},
};
use burn_fusion::stream::Context;
use cubecl::{
AutotuneKey, CubeElement, CubeTuneId, Runtime,
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
};
use cubek::reduce::{
launch::{RoutineStrategy, tune_key::ReduceAutotuneKey},
routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy},
};
use serde::{Deserialize, Serialize};
/// Autotune key for standard fused reduction operations.
///
/// Records metadata about the fusion graph (IO and ops) alongside
/// the core reduction parameters to ensure stable kernel selection.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct FusedReduceAutotuneKey {
// Key of the underlying reduction (dtypes, shape, axis).
reduce_key: ReduceAutotuneKey,
// Total number of reads over the read and write blocks.
#[autotune(anchor)]
fuse_num_reads: usize,
// Total number of writes over the read and write blocks.
#[autotune(anchor)]
fuse_num_writes: usize,
// Total number of operations over the read and write blocks.
#[autotune(anchor)]
fuse_num_ops: usize,
}
/// Executes autotuning for fused reduction operations.
///
/// This tuner evaluates different hardware-specific strategies (Plane, Cube, Unit)
/// and assigns priorities based on the `vector_count` of the reduction.
///
/// The tunable set is built once (through `TUNER.init`) and contains the unfused fallback plus
/// one tunable per fused strategy.
pub fn fused_reduce_autotune<R: Runtime, BT: CubeElement>(
arg: ReduceOptimizationTuneArg<R>,
context: &mut Context<CubeFusionHandle<R>>,
) {
static TUNER: LocalTuner<FusedReduceAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
const PRIORITY_MAX: i8 = 2;
const PRIORITY_MIN: i8 = 1;
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);
// The group itself always has the maximum priority; individual tunables refine it below.
let group = TuneGroup::<FusedReduceAutotuneKey>::new("fused_reduce", |_key| PRIORITY_MAX);
// Fallback implementation for robustness.
set = set.with(Tunable::new(
"fused_reduce_fallback",
tune_fallback::<R, BT>,
));
// Define properties to categorize hardware strategies.
enum ReduceProps {
GreatWithLowReduceCount,
GreatWithHighReduceCount,
Balanced,
}
let strategies = [
(
"fused_unit",
RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
ReduceProps::GreatWithHighReduceCount,
),
(
"fused_plane",
RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy {
independent: true,
})),
ReduceProps::Balanced,
),
(
"fused_cube",
RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy {
// Two steps reduction doesn't work with fuse-on-write, we can't activate plane
// when using the cube algo.
use_planes: false,
})),
ReduceProps::GreatWithLowReduceCount,
),
];
for (name, strategy, props) in strategies {
let tunable = Tunable::new(name, move |input| tune_reduce::<R, BT>(input, &strategy))
.group(&group, move |key| match props {
// Maximum priority only when there are fewer than 128 vectors to reduce.
ReduceProps::GreatWithLowReduceCount => {
if key.reduce_key.vector_count < 128 {
PRIORITY_MAX
} else {
PRIORITY_MIN
}
}
// Maximum priority only when there are more than 64 vectors to reduce.
ReduceProps::GreatWithHighReduceCount => {
if key.reduce_key.vector_count > 64 {
PRIORITY_MAX
} else {
PRIORITY_MIN
}
}
ReduceProps::Balanced => PRIORITY_MAX,
});
set = set.with(tunable);
}
set
});
// The client is cloned because `arg` is moved into the tune input within the same call.
TUNER.execute(
&CubeTuneId::new(&arg.info.client, &arg.info.device),
&arg.info.client.clone(),
tunables,
TuneInput::new(context, arg),
);
}
/// Builds the autotune key of a fused reduce from the reduce tensors' metadata and from the
/// read/write/operation counts of the first two blocks of the fused trace.
///
/// # Panics
///
/// Panics when called with a forked context, when the reduce input or output tensor is missing
/// from the context, or when the trace has fewer than two blocks.
pub(crate) fn create_key<R: Runtime>(
    input: &TuneInput<R, ReduceOptimizationTuneArg<R>>,
) -> FusedReduceAutotuneKey {
    let opt = input.optimization();
    let TuneContext::Original(context) = input.context() else {
        panic!("Forked context not supported for key generation");
    };

    let reduce = &opt.info.reduce;
    let input_tensor = context.tensors.get(&reduce.op.input.id).unwrap();
    let out_tensor = context.tensors.get(&reduce.op.out.id).unwrap();
    let acc = reduce.acc.into_elem();
    let is_last_axis = reduce.axis == input_tensor.shape.rank() - 1;

    let reduce_key = ReduceAutotuneKey::generate(
        input_tensor.dtype.into(),
        out_tensor.dtype.into(),
        acc,
        &input_tensor.shape,
        is_last_axis,
        reduce.axis,
    );

    // Assume the fusion contains at least a read and a write block.
    let blocks = &opt.info.trace.blocks;
    let (read_block, write_block) = (&blocks[0], &blocks[1]);

    let num_reads = read_block.reads.len() + write_block.reads.len();
    let num_writes = read_block.writes.len() + write_block.writes.len();
    let num_ops = read_block.ops.len() + write_block.ops.len();

    FusedReduceAutotuneKey::new(reduce_key, num_reads, num_writes, num_ops)
}
/// Identity generator for tuning inputs.
fn input_gen<R: Runtime>(
_key: &FusedReduceAutotuneKey,
input: &TuneInput<R, ReduceOptimizationTuneArg<R>>,
) -> TuneInput<R, ReduceOptimizationTuneArg<R>> {
input.clone()
}
/// Runs the fused reduce with `strategy` on either the original or a forked context.
///
/// A launch error is converted to its `Debug` representation.
fn tune_reduce<R: Runtime, BT: CubeElement>(
    input: TuneInput<R, ReduceOptimizationTuneArg<R>>,
    strategy: &RoutineStrategy,
) -> Result<TuneOutput<R>, String> {
    let optimization = input.optimization();

    let outcome = match input.context() {
        TuneContext::Original(context) => {
            optimization.execute_fused::<BT>(context, strategy.clone())
        }
        TuneContext::Fork(mut context_owned) => {
            let mut forked = context_owned.as_context();
            optimization.execute_fused::<BT>(&mut forked, strategy.clone())
        }
    };

    outcome.map_err(|e| format!("{e:?}"))
}
/// Executes the fallback path for a reduction optimization.
///
/// The output produced by the fallback is returned to the tuner instead of being discarded, so
/// that checked outputs (when enabled) include what the fallback actually computed.
///
/// The fallback itself panics on launch failure, hence this function never returns `Err`.
fn tune_fallback<R: Runtime, BT: CubeElement>(
    input: TuneInput<R, ReduceOptimizationTuneArg<R>>,
) -> Result<TuneOutput<R>, String> {
    let optimization = input.optimization();

    let output = match input.context() {
        TuneContext::Original(context) => optimization.execute_fallback::<BT>(context),
        TuneContext::Fork(mut context_owned) => {
            optimization.execute_fallback::<BT>(&mut context_owned.as_context())
        }
    };

    Ok(output)
}

View File

@@ -0,0 +1,375 @@
use crate::{
engine::codegen::ir::FuseType,
optim::{
CubeOptimization,
reduce::{ReduceFuser, ReduceFuserInfo, ReduceSettings},
reduce_broadcasted::{
ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationInfo,
fuser::{
block::{ReduceBlockFuser, ReduceBlockFusionAnalysis, ReduceBroadcastedStatus},
full::ReduceBroadcastedFullFuser,
full_analyzer::FullFuserAnalyzer,
},
},
},
};
use burn_fusion::{FuserProperties, FuserStatus, OperationFuser};
use burn_ir::OperationIr;
use cubecl::Runtime;
use std::sync::Arc;
/// Fuses element wise operations around a reduce operation.
///
/// Operations are distributed over a sequence of reduce blocks; a new block is started whenever
/// the current one asks for it.
pub struct ReduceBroadcastedFuser<R: Runtime> {
/// Blocks built so far; the last one receives new operations. Never empty.
blocks: Vec<ReduceBlockFuser<R>>,
/// Pristine reduce fuser cloned to start each new block.
fuser_default: ReduceFuser<R>,
/// Number of operations accepted so far.
num_ops: usize,
/// Current state of the fuser.
state: ReduceBroadcastedStatus,
/// Maximum number of bindings, taken from the default fuser at construction.
max_bindings: u32,
/// Precision used for booleans, forwarded to the fusers.
bool_precision: FuseType,
}
/// Field-by-field clone, written by hand so that only the fields (not `R` itself) need to be
/// cloneable.
impl<R: Runtime> Clone for ReduceBroadcastedFuser<R> {
    fn clone(&self) -> Self {
        let Self {
            blocks,
            fuser_default,
            num_ops,
            state,
            max_bindings,
            bool_precision,
        } = self;

        Self {
            blocks: blocks.clone(),
            fuser_default: fuser_default.clone(),
            num_ops: *num_ops,
            state: state.clone(),
            max_bindings: *max_bindings,
            bool_precision: *bool_precision,
        }
    }
}
impl<R: Runtime> ReduceBroadcastedFuser<R> {
    /// Creates a fuser for `device` holding a single, empty block.
    ///
    /// The default reduce fuser always allows fuse-on-write (`ReduceSettings::Always`); it is
    /// kept around to be cloned for every new block.
    pub fn new(device: R::Device, bool_precision: FuseType) -> Self {
        let fuser_default = ReduceFuser::new(device, bool_precision, ReduceSettings::Always);

        Self {
            blocks: vec![ReduceBlockFuser::new(fuser_default.clone())],
            max_bindings: fuser_default.fuser.max_bindings,
            num_ops: 0,
            state: ReduceBroadcastedStatus::Starting,
            bool_precision,
            fuser_default,
        }
    }
}
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for ReduceBroadcastedFuser<R> {
/// Registers `operation` in the last block, or in a freshly created block when the last one
/// requires it. Nothing happens once the fuser is closed or aborted.
fn fuse(&mut self, operation: &OperationIr) {
if matches!(
&self.state,
ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort
) {
return;
}
let block = self.blocks.last_mut().unwrap();
let analyze = block.analyze(operation, &self.state, &self.fuser_default);
let info = match analyze {
ReduceBlockFusionAnalysis::Accept => {
block.fuse(operation);
self.num_ops += 1;
block.fuser.reduce_info()
}
ReduceBlockFusionAnalysis::Refuse => {
self.state = ReduceBroadcastedStatus::Closed;
return;
}
ReduceBlockFusionAnalysis::NewBlockRequired => {
// The info is read from the previous block, before the new block is created.
let info = block.fuser.reduce_info();
let mut block = ReduceBlockFuser::new(self.fuser_default.clone());
block.fuse(operation);
self.num_ops += 1;
self.blocks.push(block);
info
}
};
match info {
ReduceFuserInfo::FusedReduce {
shape_input_id,
axis,
} => {
// Only support last axis for now.
if axis != shape_input_id.len() - 1 {
self.state = ReduceBroadcastedStatus::Abort;
} else {
self.state = ReduceBroadcastedStatus::Init {
shape_id: shape_input_id,
axis,
};
}
}
// No reduce registered yet: the state is left untouched.
ReduceFuserInfo::FusedElemwise { .. } => {}
}
}
/// Finishes every block and assembles the broadcasted optimization together with the
/// per-block fallbacks. The operation count is accumulated by the blocks while finishing.
fn finish(&mut self) -> CubeOptimization<R> {
let analyzer = FullFuserAnalyzer::new(&self.blocks);
let mut full =
ReduceBroadcastedFullFuser::new(self.max_bindings, self.bool_precision, analyzer);
let mut num_ops = 0;
let fallbacks = self
.blocks
.iter_mut()
.map(|block| block.finish(&mut num_ops, &mut full))
.collect::<Vec<_>>();
let broadcasted = Arc::new(full.finish());
let info = Arc::new(ReduceBroadcastedOptimizationInfo {
fallbacks,
broadcasted,
});
CubeOptimization::ReduceBroadcasted(ReduceBroadcastedOptimization { info, num_ops })
}
/// Goes back to a single empty block, zero operations and the `Starting` state.
fn reset(&mut self) {
let block = ReduceBlockFuser::new(self.fuser_default.clone());
self.blocks = vec![block];
self.num_ops = 0;
self.state = ReduceBroadcastedStatus::Starting;
}
/// Closed when the state is `Closed` or `Abort`; otherwise the status of the last block's
/// fuser.
fn status(&self) -> FuserStatus {
match self.state {
ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort => {
return FuserStatus::Closed;
}
_ => {}
};
let fuser = self.blocks.last().unwrap();
fuser.fuser.status()
}
/// Sums the score of every block; the fuser is ready only if the state allows it and every
/// block is ready.
fn properties(&self) -> FuserProperties {
let ready = match self.state {
ReduceBroadcastedStatus::Starting | ReduceBroadcastedStatus::Abort => false,
ReduceBroadcastedStatus::Closed => {
// A single block that is purely element-wise is not ready.
if self.blocks.len() == 1 {
!self.blocks[0].is_elemwise()
} else {
true
}
}
_ => true,
};
let mut props = FuserProperties { score: 0, ready };
for block in self.blocks.iter() {
let p = block.properties();
props.score += p.score;
props.ready = p.ready && props.ready;
}
props
}
/// Number of operations accepted so far.
fn len(&self) -> usize {
self.num_ops
}
fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
use burn_ir::{
BaseOperationIr, BinaryOpIr, CreationOpIr, ReduceDimOpIr, TensorId, TensorIr, TensorStatus,
};
use burn_std::{DType, Shape};
use super::*;
type Run = cubecl::TestRuntime;
/// Fuses `ones -> sum_dim -> add -> add -> sum_dim` and checks that every operation is accepted.
///
/// After each fused operation the fuser is expected to stay open, to be ready, and to report a
/// length and a score equal to the number of operations fused so far.
#[test]
fn reduce_broadcast_workflow_1() {
    let device: <Run as Runtime>::Device = Default::default();
    let mut fuser = ReduceBroadcastedFuser::<Run>::new(device, FuseType::I32);

    let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite);
    let (tensor2_out, tensor2) = tensor(1, &[1, 0], TensorStatus::ReadWrite);

    fuser.fuse(&OperationIr::BaseFloat(BaseOperationIr::Ones(
        CreationOpIr { out: tensor1_out },
    )));
    fuser.fuse(&OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr {
            input: tensor1,
            out: tensor2_out,
            axis: 1,
        }),
    ));

    let status = fuser.status();
    assert_eq!(2, fuser.len());
    assert_eq!(status, FuserStatus::Open);
    assert_eq!(
        fuser.properties(),
        FuserProperties {
            score: 2,
            ready: true
        }
    );

    // An existing tensor
    let (_tensor3_out, tensor3) = tensor(2, &[1, 0], TensorStatus::ReadWrite);
    // A new tensor
    let (tensor4_out, tensor4) = tensor(3, &[1, 0], TensorStatus::ReadWrite);

    fuser.fuse(&OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::Add(BinaryOpIr {
            lhs: tensor2,
            rhs: tensor3,
            out: tensor4_out,
        }),
    ));

    let status = fuser.status();
    assert_eq!(3, fuser.len());
    assert_eq!(status, FuserStatus::Open);
    assert_eq!(
        fuser.properties(),
        FuserProperties {
            score: 3,
            ready: true
        }
    );

    // An existing tensor
    let (_tensor5_out, tensor5) = tensor(4, &[1, 2], TensorStatus::ReadWrite);
    // A new tensor
    let (tensor6_out, tensor6) = tensor(5, &[1, 2], TensorStatus::ReadWrite);

    fuser.fuse(&OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::Add(BinaryOpIr {
            lhs: tensor4,
            rhs: tensor5,
            out: tensor6_out,
        }),
    ));

    let status = fuser.status();
    assert_eq!(4, fuser.len());
    assert_eq!(status, FuserStatus::Open);
    assert_eq!(
        fuser.properties(),
        FuserProperties {
            score: 4,
            ready: true
        }
    );

    let (tensor7_out, _tensor7) = tensor(6, &[1, 0], TensorStatus::ReadWrite);

    fuser.fuse(&OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr {
            input: tensor6,
            out: tensor7_out,
            axis: 1,
        }),
    ));

    // Query the status again: the value read before the last `fuse` says nothing about the
    // state reached after the second reduction.
    let status = fuser.status();
    assert_eq!(5, fuser.len());
    assert_eq!(status, FuserStatus::Open);
    assert_eq!(
        fuser.properties(),
        FuserProperties {
            score: 5,
            ready: true
        }
    );

    let _optimization = fuser.finish();
}
/// Fuses `ones -> add -> sum_dim -> add`, where the second `add` reads the reduce output together
/// with a tensor that already existed before the fusion started.
#[test]
fn reduce_broadcast_workflow_2() {
    let device: <Run as Runtime>::Device = Default::default();
    let mut fuser = ReduceBroadcastedFuser::<Run>::new(device, FuseType::I32);

    let (ones_out, ones) = tensor(0, &[1, 2], TensorStatus::ReadWrite);
    // A tensor that exists before the fused operations.
    let (_existing_out, mut existing) = tensor(2, &[1, 2], TensorStatus::ReadOnly);
    let (summed_out, summed) = tensor(3, &[1, 2], TensorStatus::ReadWrite);
    // Output of the first reduction.
    let (reduced_out, reduced) = tensor(1, &[1, 0], TensorStatus::ReadWrite);

    let create_ones =
        OperationIr::BaseFloat(BaseOperationIr::Ones(CreationOpIr { out: ones_out }));
    fuser.fuse(&create_ones);

    let add_existing = OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::Add(BinaryOpIr {
            lhs: ones,
            rhs: existing.clone(),
            out: summed_out,
        }),
    );
    fuser.fuse(&add_existing);

    let reduce = OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr {
            input: summed,
            out: reduced_out,
            axis: 1,
        }),
    );
    fuser.fuse(&reduce);

    let status_after_reduce = fuser.status();
    assert_eq!(3, fuser.len());
    assert_eq!(status_after_reduce, FuserStatus::Open);
    assert_eq!(
        fuser.properties(),
        FuserProperties {
            score: 3,
            ready: true
        }
    );

    // A tensor created by the fused operations.
    let (broadcasted_out, _broadcasted) = tensor(5, &[1, 2], TensorStatus::ReadWrite);
    // This is the last use of the pre-existing tensor.
    existing.status = TensorStatus::ReadWrite;

    let add_reduced = OperationIr::NumericFloat(
        DType::F32,
        burn_ir::NumericOperationIr::Add(BinaryOpIr {
            lhs: reduced,
            rhs: existing,
            out: broadcasted_out,
        }),
    );
    fuser.fuse(&add_reduced);

    let status_after_add = fuser.status();
    assert_eq!(4, fuser.len());
    assert_eq!(status_after_add, FuserStatus::Open);
    assert_eq!(
        fuser.properties(),
        FuserProperties {
            score: 4,
            ready: true
        }
    );

    let _optimization = fuser.finish();
}
/// Builds two `F32` descriptions of the same tensor.
///
/// The first one has the [TensorStatus::NotInit] status and is meant to be used as an operation
/// output; the second one carries the given `status` and is meant to be used as an input.
fn tensor(id: u64, shape: &[usize], status: TensorStatus) -> (TensorIr, TensorIr) {
    let not_init = TensorIr {
        id: TensorId::new(id),
        shape: Shape::from(shape),
        status: TensorStatus::NotInit,
        dtype: DType::F32,
    };
    let initialized = TensorIr {
        status,
        ..not_init.clone()
    };

    (not_init, initialized)
}
}

View File

@@ -0,0 +1,214 @@
use crate::optim::{
CubeOptimization,
elemwise::ElemwiseOptimization,
reduce::{FusedReduce, ReduceFuser, ReduceFuserInfo},
reduce_broadcasted::{ReduceBlockOptimInfo, fuser::full::ReduceBroadcastedFullFuser},
};
use burn_fusion::{FuserProperties, OperationFuser};
use burn_ir::OperationIr;
use burn_std::Shape;
use cubecl::Runtime;
use std::sync::Arc;
/// Responsible for fusing a single reduce block or elementwise block.
///
/// When the block kind is reduce, it supports fuse-on-read and fuse-on-write fusion.
/// Broadcasting isn't supported; another block should handle it instead.
pub struct ReduceBlockFuser<R: Runtime> {
/// We use [ReduceFuser] for both elementwise and reduce blocks, keeping only the
/// fuse-on-read trace if the block is tagged as elementwise.
///
/// # Notes
///
/// A single elementwise block can only exist at the end of a full [ReduceBlockFuser],
/// otherwise the optimization will be included in the reduce fusion block.
pub fuser: ReduceFuser<R>,
/// Every operation passed to [Self::fuse()], in fusion order.
pub(crate) ops: Vec<OperationIr>,
/// Whether the block is still purely elementwise or already contains a reduction.
///
/// Starts as [ReduceBlockKind::Elemwise] and is switched to [ReduceBlockKind::Reduce] by
/// [Self::fuse()] as soon as the wrapped fuser holds a reduction.
pub(crate) kind: ReduceBlockKind,
}
/// The current state of the fusion process.
#[derive(Debug, Clone)]
pub enum ReduceBroadcastedStatus {
/// Fusion is starting; no reduction has been fused yet.
Starting,
/// Fusion is initialized with at least one reduce operation.
///
/// `shape_id` and `axis` are the reference values that [ReduceBlockFuser::analyze()] compares
/// with the shape and axis reported for an operation that would open a new block.
///
/// # Notes
///
/// Subsequent reduce operations must be compatible with the previous reduction to fuse.
Init { shape_id: Shape, axis: usize },
/// No more operations can be fused.
Closed,
/// Invalid axis.
Abort,
}
/// The [ReduceBlockFuser] capacity to accept an [OperationIr].
///
/// This is the result of [ReduceBlockFuser::analyze()].
#[derive(Clone, Copy, Debug)]
pub enum ReduceBlockFusionAnalysis {
/// The operation can be fused; call [ReduceBlockFuser::fuse()].
Accept,
/// The operation cannot be fused; the optimization should close.
Refuse,
/// The operation can be fused, but requires a new block.
NewBlockRequired,
}
impl<R: Runtime> ReduceBlockFuser<R> {
/// Creates a new block that wraps `fuser`.
///
/// The block starts without any recorded operation and is tagged as elementwise; the tag is
/// updated by [Self::fuse()] once the wrapped fuser holds a reduction.
pub fn new(fuser: ReduceFuser<R>) -> Self {
    Self {
        // `fuser` is received by value, so it is moved instead of being cloned.
        fuser,
        ops: Vec::new(),
        kind: ReduceBlockKind::Elemwise,
    }
}
/// Tells whether the block is currently tagged as elementwise, i.e. no reduction was fused yet.
pub fn is_elemwise(&self) -> bool {
    match self.kind {
        ReduceBlockKind::Elemwise => true,
        ReduceBlockKind::Reduce { .. } => false,
    }
}
/// Checks how `op` could be fused, without modifying this block.
///
/// The operation is first tried on a copy of this block's fuser: if the copy grows, the
/// operation fits in the current block. Otherwise, and only when this block already contains a
/// reduction, the operation is tried on a copy of `default_node`; a new block is required when
/// that copy grows and the reported shape (and axis, for a reduction) match the ones stored in
/// `status`.
pub fn analyze(
    &self,
    op: &OperationIr,
    status: &ReduceBroadcastedStatus,
    default_node: &ReduceFuser<R>,
) -> ReduceBlockFusionAnalysis {
    // Dry run in the current block.
    let mut current = self.fuser.clone();
    let len_before = current.len();
    current.fuse(op);
    if current.len() > len_before {
        return ReduceBlockFusionAnalysis::Accept;
    }

    // Can't create a new block if the previous one was not a reduction.
    if self.fuser.reduce.is_none() {
        return ReduceBlockFusionAnalysis::Refuse;
    }

    // Dry run in a fresh block.
    let mut fresh = default_node.clone();
    let len_before = fresh.len();
    fresh.fuse(op);
    if fresh.len() <= len_before {
        return ReduceBlockFusionAnalysis::Refuse;
    }

    let compatible = match (fresh.reduce_info(), status) {
        (
            ReduceFuserInfo::FusedReduce {
                shape_input_id,
                axis,
            },
            ReduceBroadcastedStatus::Init {
                shape_id,
                axis: axis_init,
            },
        ) => shape_id == &shape_input_id && axis_init == &axis,
        (
            ReduceFuserInfo::FusedElemwise { shape_id },
            ReduceBroadcastedStatus::Init {
                shape_id: shape_init,
                ..
            },
        ) => &shape_id == shape_init,
        _ => false,
    };

    if compatible {
        ReduceBlockFusionAnalysis::NewBlockRequired
    } else {
        ReduceBlockFusionAnalysis::Refuse
    }
}
/// Adds `op` to this block and records it in the list of operations.
///
/// # Warning
///
/// [Self::analyze()] must have accepted the operation before this function is called.
pub fn fuse(&mut self, op: &OperationIr) {
    self.fuser.fuse(op);
    self.ops.push(op.clone());

    // Only an elementwise block can be promoted to a reduce block.
    if !matches!(self.kind, ReduceBlockKind::Elemwise) {
        return;
    }

    if let Some(reduce) = &self.fuser.reduce {
        // The reduction is the operation that was just recorded.
        self.kind = ReduceBlockKind::Reduce {
            ops_index: self.ops.len() - 1,
            reduce: Box::new(reduce.clone()),
        };
    }
}
/// Computes the fuser properties.
pub fn properties(&self) -> FuserProperties {
let mut properties = self.fuser.properties();
if let ReduceBlockKind::Elemwise = &self.kind {
// Elementwise traces are always ready to run.
properties.ready = true;
}
properties
}
/// Registers this block in `full`, then finalizes it into its own optimization.
///
/// `num_ops` is incremented by the length of the fuser that gets finished: the fuse-on-read
/// fallback fuser for an elementwise block, the whole reduce fuser for a reduce block.
///
/// # Panics
///
/// Panics if a reduce block doesn't produce a [CubeOptimization::Reduce] optimization.
pub fn finish(
    &mut self,
    num_ops: &mut usize,
    full: &mut ReduceBroadcastedFullFuser,
) -> ReduceBlockOptimInfo<R> {
    full.register(self);

    if matches!(self.kind, ReduceBlockKind::Elemwise) {
        let num_fused = self.fuser.fuser_read_fallback.len();
        let device = self.fuser.device.clone();
        *num_ops += num_fused;

        let trace = self.fuser.fuser_read_fallback.finish();
        let client = R::client(&device);
        let optimization = ElemwiseOptimization::new(trace, client, device, num_fused);

        return ReduceBlockOptimInfo::Elemwise(Arc::new(optimization));
    }

    *num_ops += self.fuser.len();

    match self.fuser.finish() {
        CubeOptimization::Reduce(optim) => ReduceBlockOptimInfo::Reduce(optim.info),
        _ => unreachable!("Expected Reduce optimization"),
    }
}
}
/// The kind of work fused in a [ReduceBlockFuser].
#[derive(Clone, Debug)]
pub enum ReduceBlockKind {
/// Only elementwise operations have been fused so far.
Elemwise,
/// A reduction has been fused.
Reduce {
/// Position of the reduce operation in the block's list of operations.
ops_index: usize,
/// The fused reduction.
reduce: Box<FusedReduce>,
},
}
// Implemented by hand, field by field, with `R: Runtime` as the only bound.
impl<R: Runtime> Clone for ReduceBlockFuser<R> {
    fn clone(&self) -> Self {
        let Self { fuser, ops, kind } = self;

        Self {
            fuser: fuser.clone(),
            ops: ops.clone(),
            kind: kind.clone(),
        }
    }
}

View File

@@ -0,0 +1,163 @@
use crate::{
engine::{
codegen::ir::FuseType,
fuser::TraceOperationFuser,
settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},
},
optim::{
reduce::{FusedReduce, ReduceInstruction},
reduce_broadcasted::{
ReduceBroadcastedInfo,
fuser::{
block::{ReduceBlockFuser, ReduceBlockKind},
full_analyzer::FullFuserAnalyzer,
},
launch::ReduceBroadcastedFuseBlock,
},
},
};
use burn_fusion::OperationFuser;
use cubecl::Runtime;
use cubek::reduce::components::instructions::ReduceOperationConfig;
/// Responsible for fusing a single trace for all operations involved in this optimization.
pub struct ReduceBroadcastedFullFuser {
/// Trace fuser that accumulates the operations of every registered block.
pub(crate) fuser: TraceOperationFuser,
/// Provides, each time a trace block is opened, the tensors reusable from earlier blocks.
analyzer: FullFuserAnalyzer,
/// Kind of each registered block, in registration order.
blocks: Vec<ReduceBlockKind>,
/// Settings of the initial trace block and of the trace blocks opened between registered
/// blocks.
settings_read: FuseSettings,
/// Settings of the trace block opened right after a reduction.
settings_write: FuseSettings,
}
impl ReduceBroadcastedFullFuser {
/// Creates an empty fuser.
///
/// `max_bindings` and `bool_precision` are forwarded to the underlying trace fuser, while
/// `analyzer` is queried each time a new trace block is opened during registration.
pub fn new(max_bindings: u32, bool_precision: FuseType, analyzer: FullFuserAnalyzer) -> Self {
    // Used for the initial trace block: broadcasting and vectorization are allowed.
    let settings_read = FuseSettings {
        broadcast: true,
        output_shape_updates: true,
        inplace: false,
        vectorization: VectorizationSetting::Activated,
        ref_layout: RefLayoutSetting::OnlyContiguous,
    };
    // Used for the trace block that follows a reduction.
    let settings_write = FuseSettings {
        broadcast: false,
        output_shape_updates: false,
        inplace: false,
        // Deactivated for now, but would be cool to support vectorization of the output.
        vectorization: VectorizationSetting::Deactivated,
        ref_layout: RefLayoutSetting::OnlyContiguous,
    };

    Self {
        fuser: TraceOperationFuser::new(max_bindings, bool_precision, settings_read),
        analyzer,
        blocks: Vec::new(),
        settings_read,
        settings_write,
    }
}
/// Consumes the fuser and builds the information needed to launch the fused kernel.
///
/// One launch block is produced per registered reduce block; elementwise blocks only
/// contribute to the trace. The returned `reduce_axis` is the axis of the last registered
/// reduce block, or `0` when there is none.
pub fn finish(mut self) -> ReduceBroadcastedInfo {
    let mut reduce_axis = 0;
    let mut blocks = Vec::new();

    for kind in &self.blocks {
        if let ReduceBlockKind::Reduce { reduce, .. } = kind {
            let op = match reduce.inst {
                ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
                ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
                ReduceInstruction::Prod => ReduceOperationConfig::Prod,
                ReduceInstruction::Mean => ReduceOperationConfig::Mean,
                ReduceInstruction::Sum => ReduceOperationConfig::Sum,
                ReduceInstruction::Max => ReduceOperationConfig::Max,
                ReduceInstruction::Min => ReduceOperationConfig::Min,
                ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
            };

            reduce_axis = reduce.axis;
            blocks.push(ReduceBroadcastedFuseBlock {
                op,
                input: reduce.input.clone(),
                output: reduce.output.clone(),
            });
        }
    }

    ReduceBroadcastedInfo {
        blocks,
        trace: self.fuser.finish(),
        reduce_axis,
    }
}
/// Registers a [ReduceBlockFuser] to build the trace.
///
/// Blocks are expected to be registered in execution order, since the analyzer results are
/// consumed sequentially, one per trace block opened here.
///
/// For a reduce block, the operations recorded before the reduction are fused in the current
/// trace block, a new trace block is then opened for the reduction, and the operations recorded
/// after the reduction are fused in that new trace block.
pub fn register<R: Runtime>(&mut self, block: &ReduceBlockFuser<R>) {
// Something was already fused: open a new trace block whose vectorization and reference
// layout are tied to the first trace block.
if !self.fuser.is_empty() {
let mut settings = self.settings_read;
settings.vectorization = VectorizationSetting::EqualThanPreviousBlock { block_pos: 0 };
settings.ref_layout = RefLayoutSetting::SameAsBlock { block_pos: 0 };
self.fuser.next_block([], settings, false);
// Declare the tensors that the new trace block can reuse from earlier ones.
let analysis = self.analyzer.retrieve_next();
for (tensor, block_pos) in analysis.inputs {
self.fuser.block_local_input(&tensor, block_pos, false);
}
}
match &block.kind {
ReduceBlockKind::Elemwise => {
for op in &block.ops {
self.fuser.fuse(op);
}
self.blocks.push(ReduceBlockKind::Elemwise);
}
ReduceBlockKind::Reduce { ops_index, reduce } => {
// Operations fused before the reduction (the reduction itself is excluded).
for op in &block.ops[0..*ops_index] {
self.fuser.fuse(op);
}
// Open the trace block that follows the reduction; the reduce input is handed over to
// it and the reduce output is registered as an unhandled output.
let [input] = self
.fuser
.next_block([&reduce.op.input], self.settings_write, false);
let output = self.fuser.output_unhandled(&reduce.op.out);
let analysis = self.analyzer.retrieve_next();
// Can be broadcasted so the generated buffer can be global.
for (tensor, block_pos) in analysis.inputs {
self.fuser.block_local_input(&tensor, block_pos, false);
}
// Same reduction as the block's one, but with the arguments of the full trace.
let fused_reduce = FusedReduce {
input,
output,
acc: reduce.acc,
axis: reduce.axis,
op: reduce.op.clone(),
use_planes: reduce.use_planes,
shared: reduce.shared,
inst: reduce.inst,
};
self.blocks.push(ReduceBlockKind::Reduce {
ops_index: *ops_index,
reduce: Box::new(fused_reduce),
});
// Operations fused after the reduction.
for op in &block.ops[*ops_index + 1..block.ops.len()] {
self.fuser.fuse(op);
}
}
}
}
}

View File

@@ -0,0 +1,159 @@
use super::block::ReduceBlockKind;
use crate::optim::reduce_broadcasted::fuser::block::ReduceBlockFuser;
use burn_ir::{TensorId, TensorIr};
use cubecl::Runtime;
use std::collections::BTreeMap;
/// Precomputed analysis of the tensors that each block can reuse from earlier blocks.
#[derive(Debug)]
pub struct FullFuserAnalyzer {
// For each remaining block, the tensors it can reuse, each paired with the position of the
// block where the tensor was made available. Entries are consumed front to back.
analyses: Vec<Vec<(TensorIr, usize)>>,
}
impl FullFuserAnalyzer {
/// Analyzes `blocks`, in order, to find the tensors that can be reused across blocks.
///
/// Operations of elementwise blocks and operations placed before a reduction are registered as
/// [BlockKind::Full]; operations placed after a reduction are registered as
/// [BlockKind::Single]. A reduction, as well as the end of each block, closes the analysis in
/// progress.
pub fn new<R: Runtime>(blocks: &[ReduceBlockFuser<R>]) -> Self {
    let mut state = AnalysisState::default();

    for block in blocks {
        for (pos, op) in block.ops.iter().enumerate() {
            // `None` stands for the reduce operation itself.
            let scope = match &block.kind {
                ReduceBlockKind::Elemwise => Some(BlockKind::Full),
                ReduceBlockKind::Reduce { ops_index, .. } if pos < *ops_index => {
                    Some(BlockKind::Full)
                }
                ReduceBlockKind::Reduce { ops_index, .. } if pos > *ops_index => {
                    Some(BlockKind::Single)
                }
                ReduceBlockKind::Reduce { .. } => None,
            };

            match scope {
                Some(kind) => state.register(op.inputs(), op.outputs(), kind),
                None => state.next_block(),
            }
        }

        state.next_block();
    }

    // The first analysis is never retrieved.
    state.analyses.remove(0);

    Self {
        analyses: state.analyses,
    }
}
/// Removes and returns the oldest analysis that wasn't retrieved yet.
///
/// # Panics
///
/// Panics if every analysis was already retrieved.
pub fn retrieve_next(&mut self) -> FullFuserAnalysis {
    FullFuserAnalysis {
        inputs: self.analyses.remove(0),
    }
}
}
/// Analysis of a single block, as returned by [FullFuserAnalyzer::retrieve_next()].
#[derive(Debug)]
pub struct FullFuserAnalysis {
/// The tensors received from previous blocks, each paired with the position of the block it
/// comes from.
pub inputs: Vec<(TensorIr, usize)>,
}
/// Working state used by [FullFuserAnalyzer::new()] while it walks over the blocks.
#[derive(Default)]
struct AnalysisState {
/// Pool of tensors that are available in the fuse-on-write part of a reduce (not
/// broadcasted), mapped to the earliest block position where they were seen.
available_from_previous_single: BTreeMap<TensorId, usize>,
/// Pool of tensors that are available in the fuse-on-read part of a reduce and in the
/// element-wise broadcasted part, mapped to the earliest block position where they were seen.
available_from_previous_full: BTreeMap<TensorId, usize>,
/// Reusable tensors found so far for the analysis in progress.
block_data: Vec<(TensorIr, usize)>,
/// Completed analyses; `next_block` appends one entry per call.
analyses: Vec<Vec<(TensorIr, usize)>>,
/// Tensors registered as [BlockKind::Full] since the last call to `next_block`.
current_full: Vec<TensorIr>,
/// Tensors registered as [BlockKind::Single] since the last call to `next_block`.
current_single: Vec<TensorIr>,
}
/// Part of a block in which an operation is registered during the analysis.
enum BlockKind {
/// Operations of an elementwise block, or operations fused before a reduction.
Full,
/// Operations fused after a reduction.
Single,
}
impl AnalysisState {
/// Closes the analysis in progress and starts a new one.
///
/// The reusable tensors collected so far are stored as a completed analysis, and the tensors
/// seen since the previous call become available to the following blocks.
fn next_block(&mut self) {
    // Position of the analysis being closed.
    let block_pos = self.analyses.len();
    self.analyses.push(core::mem::take(&mut self.block_data));

    // `or_insert` keeps the earliest block position for a tensor that is already known.
    for tensor in self.current_single.drain(..) {
        self.available_from_previous_single
            .entry(tensor.id)
            .or_insert(block_pos);
    }

    for tensor in self.current_full.drain(..) {
        self.available_from_previous_full
            .entry(tensor.id)
            .or_insert(block_pos);
    }
}
/// Registers the inputs and outputs of one operation in the analysis in progress.
///
/// Every input that is available from the fuse-on-write part of a previous reduce is recorded
/// as reusable, for both kinds. All inputs and outputs are then remembered in the pool
/// matching `kind`, so that `next_block` can make them available to the following blocks.
fn register<'a>(
    &mut self,
    potential_from_previous_blocks: impl Iterator<Item = &'a TensorIr>,
    potential_to_next_blocks: impl Iterator<Item = &'a TensorIr>,
    kind: BlockKind,
) {
    // Tensors from `available_from_previous_full` are never reused since they are not in the
    // same scope.
    //
    // TODO: Find a way to merge multiple reduce loops.
    let current = match kind {
        BlockKind::Full => &mut self.current_full,
        BlockKind::Single => &mut self.current_single,
    };

    for tensor in potential_from_previous_blocks {
        // For a full block this is possible since it's a broadcast.
        if let Some(block_pos) = self.available_from_previous_single.get(&tensor.id) {
            self.block_data.push((tensor.clone(), *block_pos));
        }

        // Can reuse the read.
        current.push(tensor.clone());
    }

    current.extend(potential_to_next_blocks.cloned());
}
}

View File

@@ -0,0 +1,6 @@
mod base;
mod block;
mod full;
mod full_analyzer;
pub use base::*;

View File

@@ -0,0 +1,139 @@
use crate::{
engine::{
codegen::ir::{FuseArg, FuseBlockConfig, GlobalArgsLaunch, RefLayout},
launch::runner::{TraceRunner, Vectorization},
},
optim::reduce_broadcasted::unit::{
ElemwiseFuseBlockLaunch, ReduceFuseBlockLaunch, reduce_kernel_broadcasted,
},
};
use cubecl::{
Runtime,
ir::{ElemType, FloatKind, StorageType},
prelude::*,
server::LaunchError,
};
use cubek::reduce::{
LineMode, ReduceDtypes,
components::instructions::ReduceOperationConfig,
launch::RoutineStrategy,
routines::{
BlueprintStrategy, GlobalReduceBlueprint, ReduceLineSettings, ReduceProblem, Routine,
unit::{UnitRoutine, UnitStrategy},
},
};
use serde::{Deserialize, Serialize};
/// Serializable description of one reduce block of a broadcasted reduce optimization.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ReduceBroadcastedFuseBlock {
/// The reduce operation to perform.
pub(crate) op: ReduceOperationConfig,
/// The argument of the fused trace that is reduced.
pub(crate) input: FuseArg,
/// The argument of the fused trace that receives the reduction result.
pub(crate) output: FuseArg,
}
/// Runner that launches the broadcasted reduce kernel for a fused trace.
#[derive(new)]
pub struct FusedReduceBroadcastedLaunch<'a> {
/// The reduce blocks to execute, in order.
blocks: &'a Vec<ReduceBroadcastedFuseBlock>,
/// The axis passed to the kernel as the reduction axis.
reduce_axis: usize,
// TODO: Support multiple strategies.
// The strategy is stored but not read yet.
_strategy: RoutineStrategy,
}
// Nothing is overridden: the launch relies on the default methods of `Vectorization`.
impl<R: Runtime> Vectorization<R> for FusedReduceBroadcastedLaunch<'_> {}
impl<R: Runtime> TraceRunner<R> for FusedReduceBroadcastedLaunch<'_> {
    type Error = LaunchError;

    /// Launches the broadcasted reduce kernel over the fused trace.
    ///
    /// The problem size is derived from the reference layout of the first block config: the
    /// length of the reduce axis is the vector size, and the number of remaining elements is
    /// the vector count.
    ///
    /// `configs` is read two entries at a time, one pair per reduce block (input side, then
    /// output side). When entries remain after the last pair, the last config is used for a
    /// trailing elementwise block.
    ///
    /// # Errors
    ///
    /// Returns the [LaunchError] reported by the kernel launch.
    ///
    /// # Panics
    ///
    /// Panics if `configs` is empty or holds fewer than two entries per reduce block, if the
    /// reduce axis has a length of zero, if the unit routine can't be prepared, or if the
    /// prepared blueprint isn't a parallel unit blueprint.
    fn run<'a>(
        &'a self,
        client: &'a ComputeClient<R>,
        inputs: GlobalArgsLaunch<'a, R>,
        outputs: GlobalArgsLaunch<'a, R>,
        configs: &'a [FuseBlockConfig],
    ) -> Result<(), Self::Error> {
        let routine = UnitRoutine;
        let first_config = &configs[0];

        // The reference shape lives in the outputs only when the reference layout is an output.
        let shape = match &first_config.ref_layout {
            RefLayout::Concrete(FuseArg::Output(..)) => {
                outputs.shape_ref(&first_config.ref_layout, first_config.rank)
            }
            _ => inputs.shape_ref(&first_config.ref_layout, first_config.rank),
        };
        let vector_size = shape[self.reduce_axis];
        let vector_count = shape.iter().product::<usize>() / vector_size;
        let address_type = inputs
            .required_address_type()
            .max(outputs.required_address_type());

        // NOTE(review): the dtypes given to the routine are hard-coded to f32, whatever the
        // tensors of the trace are — confirm this is only used to size the launch.
        let (blueprint, settings) = routine
            .prepare::<R>(
                client,
                ReduceProblem {
                    vector_size,
                    vector_count,
                    axis: self.reduce_axis,
                    dtypes: ReduceDtypes {
                        input: StorageType::Scalar(ElemType::Float(FloatKind::F32)),
                        output: StorageType::Scalar(ElemType::Float(FloatKind::F32)),
                        accumulation: StorageType::Scalar(ElemType::Float(FloatKind::F32)),
                    },
                    address_type,
                },
                ReduceLineSettings {
                    line_mode: LineMode::Parallel,
                    line_size_input: first_config.width,
                    line_size_output: 1,
                },
                BlueprintStrategy::Inferred(UnitStrategy),
            )
            .expect("the unit routine should be able to prepare the broadcasted reduce problem");
        assert_eq!(blueprint.line_mode, LineMode::Parallel);

        let mut blocks = SequenceArg::new();
        // Position in `configs` of the input config of the next reduce block.
        let mut index = 0;

        for block in self.blocks {
            let arg = ReduceFuseBlockLaunch::new(
                block.op,
                configs[index].clone(),
                configs[index + 1].clone(),
                block.input.clone(),
                block.output.clone(),
                match blueprint.global {
                    GlobalReduceBlueprint::Unit(bpt) => bpt,
                    _ => panic!("the unit routine should produce a unit reduce blueprint"),
                },
            );
            index += 2;
            blocks.push(arg);
        }

        // Configs left after the reduce blocks belong to a trailing elementwise block.
        let block_end = if configs.len() > index {
            OptionArgs::Some(ElemwiseFuseBlockLaunch::new(
                configs.last().cloned().unwrap(),
            ))
        } else {
            OptionArgs::None
        };

        // TODO: Ensure parallel is selected.
        // NOTE(review): `launch_unchecked` is unsafe and no safety argument is written down
        // here — confirm the invariants that the fused trace guarantees.
        unsafe {
            reduce_kernel_broadcasted::launch_unchecked::<R>(
                client,
                settings.cube_count,
                settings.cube_dim,
                settings.address_type,
                inputs,
                outputs,
                ScalarArg::new(self.reduce_axis),
                blocks,
                block_end,
            )?;
        }

        Ok(())
    }
}

View File

@@ -0,0 +1,9 @@
mod fuser;
mod optimization;
pub(crate) mod launch;
pub(crate) mod tune;
pub(crate) mod unit;
pub use fuser::*;
pub use optimization::*;

View File

@@ -0,0 +1,222 @@
#[cfg(feature = "autotune")]
use crate::optim::reduce::tune::fused_reduce_autotune;
use crate::{
CubeFusionHandle, FallbackOperation,
engine::{
launch::FuseTraceLauncher,
trace::{FuseTrace, TraceError, TuneOutput},
},
optim::{
elemwise::{ElemwiseOptimization, ElemwiseOptimizationState},
reduce::{ReduceOptimizationInfo, ReduceOptimizationState, ReduceOptimizationTuneArg},
reduce_broadcasted::{
launch::{FusedReduceBroadcastedLaunch, ReduceBroadcastedFuseBlock},
tune::fused_broadcasted_reduce_autotune,
},
},
};
use burn_fusion::stream::Context;
use cubecl::{Runtime, prelude::*};
use cubek::reduce::launch::RoutineStrategy;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Optimization made of several blocks that can be executed either by a single fused kernel or
/// block by block through their fallbacks.
pub struct ReduceBroadcastedOptimization<R: Runtime> {
/// Data shared with the tuning arguments built at execution time.
pub(crate) info: Arc<ReduceBroadcastedOptimizationInfo<R>>,
/// Number of operations fused in this optimization.
pub(crate) num_ops: usize,
}
/// Everything needed to execute a [ReduceBroadcastedOptimization].
pub(crate) struct ReduceBroadcastedOptimizationInfo<R: Runtime> {
/// One standalone optimization per block, used when the blocks are executed separately.
pub(crate) fallbacks: Vec<ReduceBlockOptimInfo<R>>,
/// Description of the single kernel that executes all blocks at once.
pub(crate) broadcasted: Arc<ReduceBroadcastedInfo>,
}
/// Serializable description of the fused kernel that executes all blocks at once.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct ReduceBroadcastedInfo {
/// The reduce blocks, in execution order.
pub(crate) blocks: Vec<ReduceBroadcastedFuseBlock>,
/// The trace that contains the operations of every block.
pub(crate) trace: FuseTrace,
/// The axis given to the kernel as the reduction axis.
pub(crate) reduce_axis: usize,
}
/// Standalone optimization of a single block.
pub(crate) enum ReduceBlockOptimInfo<R: Runtime> {
/// The block contains a reduction.
Reduce(Arc<ReduceOptimizationInfo<R>>),
/// The block only contains elementwise operations.
Elemwise(Arc<ElemwiseOptimization<R>>),
}
impl<R: Runtime> ReduceBlockOptimInfo<R> {
    /// Rebuilds the optimization of a block from its serializable state, for the given device.
    pub fn from_state(device: &R::Device, state: ReduceBlockState) -> Self {
        match state {
            ReduceBlockState::Reduce(reduce) => {
                let info = ReduceOptimizationInfo::from_state(device, reduce);
                Self::Reduce(Arc::new(info))
            }
            ReduceBlockState::Elemwise(elemwise) => {
                let optimization = ElemwiseOptimization::from_state(device, elemwise);
                Self::Elemwise(Arc::new(optimization))
            }
        }
    }

    /// Converts the optimization of a block into its serializable state.
    pub fn to_state(&self) -> ReduceBlockState {
        match self {
            Self::Reduce(reduce) => ReduceBlockState::Reduce(reduce.to_state()),
            Self::Elemwise(elemwise) => ReduceBlockState::Elemwise(elemwise.to_state()),
        }
    }
}
/// Argument built for one execution of a [ReduceBroadcastedOptimization].
pub(crate) struct ReduceBroadcastedOptimizationTuneArg<R: Runtime> {
/// The blocks, each one executable on its own.
pub(crate) fallbacks: Vec<ReduceBlockOptimArg<R>>,
/// Description of the single kernel that executes all blocks at once.
pub(crate) broadcasted: Arc<ReduceBroadcastedInfo>,
/// The client used to launch the fused kernel.
pub(crate) client: ComputeClient<R>,
/// The device used to launch the fused kernel.
pub(crate) device: R::Device,
}
/// A single block that can be executed on its own.
pub(crate) enum ReduceBlockOptimArg<R: Runtime> {
/// A reduce block, together with its own fallback operation.
Reduce(ReduceOptimizationTuneArg<R>),
/// An elementwise block.
Elemwise(Arc<ElemwiseOptimization<R>>),
}
impl<R: Runtime> ReduceBlockOptimArg<R> {
/// Executes this block on its own.
///
/// A reduce block goes through the fused reduce autotune when the `autotune` feature is
/// enabled, and through its own fallback otherwise. An elementwise block is executed directly.
///
/// Only the reduce fallback executed without the `autotune` feature returns an output; every
/// other path returns `None`.
pub fn execute_fallback<BT: CubeElement>(
&self,
context: &mut Context<'_, CubeFusionHandle<R>>,
) -> Option<TuneOutput<R>> {
match self {
ReduceBlockOptimArg::Reduce(reduce) => {
#[cfg(feature = "autotune")]
{
// Whatever the autotune produces isn't forwarded to the caller.
fused_reduce_autotune::<R, BT>(reduce.clone(), context);
None
}
#[cfg(not(feature = "autotune"))]
Some(reduce.execute_fallback::<BT>(context))
}
ReduceBlockOptimArg::Elemwise(elem) => {
elem.execute::<BT>(context);
None
}
}
}
}
/// Serializable state of a [ReduceBroadcastedOptimization].
#[derive(Serialize, Deserialize, Debug)]
pub struct ReduceBroadcastedOptimizationState {
/// State of each block's standalone optimization.
fallbacks: Vec<ReduceBlockState>,
/// Description of the single kernel that executes all blocks at once.
broadcasted: ReduceBroadcastedInfo,
/// Number of operations fused in the optimization.
num_ops: usize,
}
/// Serializable state of a single block's standalone optimization.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)] // Only for serialization.
pub enum ReduceBlockState {
/// State of a reduce block.
Reduce(ReduceOptimizationState),
/// State of an elementwise block.
Elemwise(ElemwiseOptimizationState),
}
impl<R: Runtime> ReduceBroadcastedOptimizationTuneArg<R> {
    /// Executes all blocks with the single fused kernel.
    ///
    /// # Errors
    ///
    /// Any launch error is returned as a [TraceError::RunnerError] that holds its debug
    /// representation.
    pub fn execute_fused<BT: CubeElement>(
        &self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        strategy: RoutineStrategy,
    ) -> Result<TuneOutput<R>, TraceError<String>> {
        let broadcasted = &self.broadcasted;
        let runner = FusedReduceBroadcastedLaunch::new(
            &broadcasted.blocks,
            broadcasted.reduce_axis,
            strategy,
        );
        let launcher = FuseTraceLauncher::new(&broadcasted.trace, &runner);

        launcher
            .launch::<BT>(&self.client, &self.device, context)
            .map_err(|err| TraceError::RunnerError(format!("{err:?}")))
    }

    /// Executes the blocks one after the other, each one on its own.
    ///
    /// The output returned by each block is discarded.
    pub fn execute_fallback<BT: CubeElement>(
        &self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
    ) {
        for block in &self.fallbacks {
            block.execute_fallback::<BT>(context);
        }
    }
}
#[allow(clippy::too_many_arguments)]
impl<R: Runtime> ReduceBroadcastedOptimization<R> {
/// Execute the optimization.
///
/// One executable argument is built per block. With the `autotune` feature, the autotune then
/// chooses between the single fused kernel and the block-by-block execution; without it, the
/// blocks are executed one after the other.
///
/// `fallback` receives the index of a reduce operation among the fused operations and returns
/// the operation to run when that reduction can't be fused.
///
/// # Panics
///
/// Panics if the optimization doesn't hold any reduce block, since the client and the device
/// are taken from the reduce blocks.
pub fn execute<BT: CubeElement>(
    &mut self,
    context: &mut Context<'_, CubeFusionHandle<R>>,
    fallback: impl Fn(usize) -> Box<dyn FallbackOperation<R>>,
) {
    // Index of the first operation of the reduce block being visited. Elementwise blocks
    // don't advance it.
    let mut current_index = 0;
    let mut client = None;
    let mut device = None;

    let fallbacks = self
        .info
        .fallbacks
        .iter()
        .map(|block| match block {
            ReduceBlockOptimInfo::Reduce(reduce) => {
                // The index of the fallback reduce is the number of ops fused as read.
                let fallback_op = fallback(current_index + reduce.len_read);
                client = Some(reduce.client.clone());
                device = Some(reduce.device.clone());

                let arg = ReduceOptimizationTuneArg {
                    info: reduce.clone(),
                    fallback: Arc::new(fallback_op),
                };
                current_index += reduce.len;

                ReduceBlockOptimArg::Reduce(arg)
            }
            ReduceBlockOptimInfo::Elemwise(op) => ReduceBlockOptimArg::Elemwise(op.clone()),
        })
        .collect();

    let arg = ReduceBroadcastedOptimizationTuneArg {
        fallbacks,
        client: client
            .expect("a broadcasted reduce optimization should hold at least one reduce block"),
        device: device
            .expect("a broadcasted reduce optimization should hold at least one reduce block"),
        broadcasted: self.info.broadcasted.clone(),
    };

    #[cfg(feature = "autotune")]
    fused_broadcasted_reduce_autotune::<R, BT>(arg, context);

    #[cfg(not(feature = "autotune"))]
    arg.execute_fallback::<BT>(context);
}
pub fn to_state(&self) -> ReduceBroadcastedOptimizationState {
ReduceBroadcastedOptimizationState {
fallbacks: self
.info
.fallbacks
.iter()
.map(|info| info.to_state())
.collect(),
broadcasted: self.info.broadcasted.as_ref().clone(),
num_ops: self.num_ops,
}
}
pub fn from_state(device: &R::Device, state: ReduceBroadcastedOptimizationState) -> Self {
Self {
info: Arc::new(ReduceBroadcastedOptimizationInfo {
fallbacks: state
.fallbacks
.into_iter()
.map(|state| ReduceBlockOptimInfo::from_state(device, state))
.collect(),
broadcasted: Arc::new(state.broadcasted),
}),
num_ops: state.num_ops,
}
}
/// Returns the number of operations fused in this optimization.
pub fn num_ops_fused(&self) -> usize {
self.num_ops
}
}

View File

@@ -0,0 +1,180 @@
use super::optimization::ReduceBroadcastedOptimizationTuneArg;
use crate::{
CubeFusionHandle,
engine::trace::TuneOutput,
optim::{reduce::ReduceOptimizationInfo, reduce_broadcasted::ReduceBlockOptimArg},
tune::{TuneContext, TuneInput},
};
use burn_fusion::stream::Context;
use cubecl::{
AutotuneKey, CubeElement, CubeTuneId, Runtime,
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
};
use cubek::reduce::{
launch::{RoutineStrategy, tune_key::ReduceAutotuneKey},
routines::{BlueprintStrategy, unit::UnitStrategy},
};
use serde::{Deserialize, Serialize};
/// Autotune key for fused broadcasted reduction operations.
///
/// Captures the characteristics of the fusion (reads, writes, ops) to ensure
/// the best kernel is selected for specific fused graph shapes.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct FusedBroadcastedReduceAutotuneKey {
/// Key of the reduction held by the first block.
reduce_key: ReduceAutotuneKey,
/// Number of reads, summed over the blocks of the fused trace.
#[autotune(anchor)]
fuse_num_reads: usize,
/// Number of writes, summed over the blocks of the fused trace.
#[autotune(anchor)]
fuse_num_writes: usize,
/// Number of operations, summed over the blocks of the fused trace.
#[autotune(anchor)]
fuse_num_ops: usize,
/// Number of trace blocks.
///
/// NOTE(review): `create_key` fills this with the block count of the first reduce's own trace,
/// not of the fused trace used for the three counters above — confirm this is intended.
fuse_num_blocks: usize,
}
/// Executes the autotuning process for fused reduction operations.
///
/// This function initializes a local tuner and attempts multiple strategies
/// (fallback vs. unit strategy) to find the most efficient execution path.
///
/// The set of tunables is built by the closure given to the tuner's `init`; the tuner is then
/// executed for the client and device held by `arg`.
pub fn fused_broadcasted_reduce_autotune<R: Runtime, BT: CubeElement>(
arg: ReduceBroadcastedOptimizationTuneArg<R>,
context: &mut Context<CubeFusionHandle<R>>,
) {
static TUNER: LocalTuner<FusedBroadcastedReduceAutotuneKey, CubeTuneId> = local_tuner!();
let tunables = TUNER.init(|| {
const PRIORITY_MAX: i8 = 2;
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);
// Single group for the fused kernels, always given the maximum priority.
let group = TuneGroup::<FusedBroadcastedReduceAutotuneKey>::new(
"fused_reduce_broadcasted",
|_key| PRIORITY_MAX,
);
// Standard fallback implementation - guaranteed to work.
// It is registered without any group.
set = set.with(Tunable::new(
"fused_reduce_broadcasted_fallback",
tune_fallback::<R, BT>,
));
// Specialized unit strategy for fused reductions.
set = set.with(
Tunable::new("fused_reduce_broadcasted_unit", move |input| {
tune_reduce::<R, BT>(
input,
&RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
)
})
.group(&group, |_| PRIORITY_MAX),
);
set
});
// The client is cloned because `arg` is moved into the tune input within the same call.
TUNER.execute(
&CubeTuneId::new(&arg.client, &arg.device),
&arg.client.clone(),
tunables,
TuneInput::new(context, arg),
);
}
/// Builds the autotune key of a broadcasted reduce optimization.
///
/// The key combines the reduce key of the first block with the number of reads, writes and
/// operations summed over the blocks of the fused trace.
///
/// # Panics
///
/// Panics if the context is forked, if there is no block, or if the first block isn't a
/// reduction.
pub(crate) fn create_key<R: Runtime>(
    input: &TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
) -> FusedBroadcastedReduceAutotuneKey {
    let opt = input.optimization();
    let context = match input.context() {
        TuneContext::Original(context) => context,
        TuneContext::Fork(_) => unreachable!("Forked context not supported for key generation"),
    };

    // The fusion must start with a reduction block to be valid here.
    let info = match &opt.fallbacks[0] {
        ReduceBlockOptimArg::Elemwise(_) => {
            unreachable!("Fusion must start with a reduction block")
        }
        ReduceBlockOptimArg::Reduce(reduce) => &reduce.info,
    };
    let reduce_key = generate_reduce_autotune_key(info, context);

    // Complexity metrics, summed over all blocks of the fused trace.
    let trace_blocks = &opt.broadcasted.trace.blocks;
    let num_reads: usize = trace_blocks.iter().map(|block| block.reads.len()).sum();
    let num_writes: usize = trace_blocks.iter().map(|block| block.writes.len()).sum();
    let num_ops: usize = trace_blocks.iter().map(|block| block.ops.len()).sum();

    FusedBroadcastedReduceAutotuneKey::new(
        reduce_key,
        num_reads,
        num_writes,
        num_ops,
        info.trace.blocks.len(),
    )
}
/// Builds the base reduce key from the tensors registered in `context`.
///
/// The dtypes and the shape are read from the context, while the accumulation type and the axis
/// come from the reduction stored in `info`.
///
/// # Panics
///
/// Panics if the input or the output of the reduction isn't registered in the context.
fn generate_reduce_autotune_key<R: Runtime>(
    info: &ReduceOptimizationInfo<R>,
    context: &Context<CubeFusionHandle<R>>,
) -> ReduceAutotuneKey {
    let reduce = &info.reduce;
    let input = context.tensors.get(&reduce.op.input.id).unwrap();
    let out = context.tensors.get(&reduce.op.out.id).unwrap();
    let acc = reduce.acc.into_elem();
    let axis = reduce.axis;

    ReduceAutotuneKey::generate(
        input.dtype.into(),
        out.dtype.into(),
        acc,
        &input.shape,
        // Whether the reduction is performed on the last dimension.
        axis == input.shape.rank() - 1,
        axis,
    )
}
/// Generates the input given to a tunable: a clone of the original input, whatever the key is.
fn input_gen<R: Runtime>(
    _key: &FusedBroadcastedReduceAutotuneKey,
    input: &TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
) -> TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>> {
    Clone::clone(input)
}
/// Runs the single fused kernel with the given routine strategy.
///
/// Works with both the original context and a forked one. Errors are returned as their debug
/// representation.
fn tune_reduce<R: Runtime, BT: CubeElement>(
    input: TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
    strategy: &RoutineStrategy,
) -> Result<TuneOutput<R>, String> {
    let optimization = input.optimization();
    let strategy = strategy.clone();

    let outcome = match input.context() {
        TuneContext::Original(context) => optimization.execute_fused::<BT>(context, strategy),
        TuneContext::Fork(mut context_owned) => {
            optimization.execute_fused::<BT>(&mut context_owned.as_context(), strategy)
        }
    };

    outcome.map_err(|e| format!("{e:?}"))
}
/// Runs the blocks one after the other instead of the single fused kernel.
///
/// Works with both the original context and a forked one, and always succeeds with an
/// unchecked output.
fn tune_fallback<R: Runtime, BT: CubeElement>(
    input: TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
) -> Result<TuneOutput<R>, String> {
    let optimization = input.optimization();

    match input.context() {
        TuneContext::Original(context) => optimization.execute_fallback::<BT>(context),
        TuneContext::Fork(mut context_owned) => {
            let mut forked = context_owned.as_context();
            optimization.execute_fallback::<BT>(&mut forked)
        }
    };

    // The fallback doesn't expose any output to check.
    Ok(TuneOutput::UnChecked(std::marker::PhantomData))
}

View File

@@ -0,0 +1,202 @@
use crate::{
engine::codegen::{
ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, multi_block_variables_init},
kernel::{fuse_on_write, init_locals},
},
optim::reduce::args::{FusedReduceArgs, FusedReduceInput, FusedReduceOutput},
};
use cubecl::{Runtime, prelude::*, std::tensor::r#virtual::VirtualTensor};
use cubek::reduce::{
LineMode, ReduceInstruction, ReducePrecision,
components::{
global::unit::GlobalFullUnitReduce,
instructions::{ReduceOperation, ReduceOperationConfig},
},
init_tensors,
routines::UnitReduceBlueprint,
};
/// A configuration block for a reduction operation within a fused kernel.
///
/// This struct holds all the compile-time information needed to perform a
/// reduction, including the operation type (Sum, Max, etc.) and the layout
/// configuration for both input and output.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct ReduceFuseBlock {
/// The reduce operation to perform.
#[cube(comptime)]
op: ReduceOperationConfig,
/// Configuration of the trace block on the input side of the reduction.
#[cube(comptime)]
config_input: FuseBlockConfig,
/// Configuration of the trace block on the output side of the reduction.
#[cube(comptime)]
config_output: FuseBlockConfig,
/// The fused argument that is reduced.
#[cube(comptime)]
input: FuseArg,
/// The fused argument that receives the reduction result.
#[cube(comptime)]
output: FuseArg,
/// The unit reduce blueprint used for the reduction.
#[cube(comptime)]
blueprint: UnitReduceBlueprint,
}
/// A configuration block for an elementwise operation that follows a reduction.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct ElemwiseFuseBlock {
/// Compile-time configuration of the elementwise trace block.
#[cube(comptime)]
config: FuseBlockConfig,
}
/// The entry point for a broadcasted reduction kernel.
///
/// This kernel initializes local variables for multiple reduction blocks and then
/// executes the reduction sequence.
///
/// # Arguments
///
/// * `inputs` - Global arguments containing input tensor handles.
/// * `outputs` - Global arguments containing output tensor handles.
/// * `reduce_axis` - The dimension along which the reduction is performed.
/// * `blocks` - A sequence of reduction operations to execute.
/// * `block_end` - An optional elementwise block to execute after reductions are complete.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn reduce_kernel_broadcasted(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
reduce_axis: usize,
blocks: Sequence<ReduceFuseBlock>,
block_end: Option<ElemwiseFuseBlock>,
) {
// Initialize the multi-block variables for the input and the output configuration of
// every block. Both are registered on `outputs.variables`.
#[unroll]
for i in 0..blocks.len() {
let block = blocks.index(i);
multi_block_variables_init(&block.config_input, &mut outputs.variables);
multi_block_variables_init(&block.config_output, &mut outputs.variables);
}
// All the actual work (reductions and the optional trailing block) happens here.
reduce_many(inputs, outputs, reduce_axis, blocks, block_end);
}
// Identifiers of the three numeric types used by the fused reduce. Each one parameterizes
// a `NumericExpand` alias below; the concrete precision behind each alias is selected per
// block by `set_polyfill_block`.
const REDUCE_INPUT: u8 = 0;
const REDUCE_ACC: u8 = 1;
const REDUCE_OUT: u8 = 2;
/// Numeric type bound to the input precision of the current reduce block.
type In = NumericExpand<REDUCE_INPUT>;
/// Numeric type bound to the accumulator precision of the current reduce block.
type Acc = NumericExpand<REDUCE_ACC>;
/// Numeric type bound to the output precision of the current reduce block.
type Out = NumericExpand<REDUCE_OUT>;
/// Configures the precision polyfills for the reduction based on the block's `FuseType`.
///
/// `In` and `Out` take the precision of the block's `input` and `output` arguments.
/// `Acc` is derived from the input precision: 64-bit and 32-bit types are kept as is,
/// while `Flex32`, the 16-bit floats and the 8/16-bit integers are mapped to the 32-bit
/// type of the same family. `Bool` accumulates as `I32`.
#[cube]
fn set_polyfill_block(block: &ReduceFuseBlock) {
let input_precision = comptime!(block.input.precision());
let output_precision = comptime!(block.output.precision());
// The accumulator precision depends only on the input precision. The match lists every
// variant explicitly, without a catch-all arm.
let acc_precision = comptime!(match input_precision {
FuseType::F64 => FuseType::F64,
FuseType::F32 => FuseType::F32,
FuseType::Flex32 => FuseType::F32,
FuseType::F16 => FuseType::F32,
FuseType::BF16 => FuseType::F32,
FuseType::I64 => FuseType::I64,
FuseType::I32 => FuseType::I32,
FuseType::I16 => FuseType::I32,
FuseType::I8 => FuseType::I32,
FuseType::U64 => FuseType::U64,
FuseType::U32 => FuseType::U32,
FuseType::U16 => FuseType::U32,
FuseType::U8 => FuseType::U32,
FuseType::Bool => FuseType::I32,
});
set_polyfill::<In>(comptime!(input_precision.into_type()));
set_polyfill::<Out>(comptime!(output_precision.into_type()));
set_polyfill::<Acc>(comptime!(acc_precision.into_type()));
}
/// Internal logic for executing a sequence of reduction blocks followed by an optional
/// trailing elementwise block.
///
/// # Arguments
///
/// * `inputs` - Global input arguments, cloned into the `FusedReduceInput` of each block.
/// * `outputs` - Global output arguments, cloned into the `FusedReduceOutput` of each block.
/// * `reduce_axis` - Axis forwarded to every `reduce_step` call.
/// * `blocks` - Reduce blocks, executed in sequence order.
/// * `block_end` - Optional elementwise block executed after the last reduce block.
#[cube]
#[allow(clippy::clone_on_copy)]
fn reduce_many(
inputs: &GlobalArgs,
outputs: &mut GlobalArgs,
reduce_axis: usize,
blocks: Sequence<ReduceFuseBlock>,
block_end: Option<ElemwiseFuseBlock>,
) {
// Overwritten by every reduce step, so after the loop it holds the axis size reported by
// the last block. It stays 0 when `blocks` is empty, in which case the trailing loop
// below runs zero iterations.
let mut axis_size = 0;
#[unroll]
for i in 0..blocks.len() {
let block = blocks.index(i);
let input = FusedReduceInput {
global: inputs.clone(),
config: comptime!(block.config_input.clone()),
arg: comptime!(block.input.clone()),
};
let global = outputs.clone();
let config = comptime!(block.config_output.clone());
let arg = comptime!(block.output.clone());
let mut output = FusedReduceOutput {
global,
config,
arg,
};
// Bind `In`, `Acc` and `Out` to this block's precisions before the tensors are built.
set_polyfill_block(block);
let (input, mut output) = init_tensors::<FusedReduceArgs, In, Out>(&input, &mut output);
axis_size = reduce_step::<(In, Acc), Out, ReduceOperation>(
&input,
&mut output,
reduce_axis,
block.op,
comptime!(block.blueprint.clone()),
);
}
if let Some(block) = block_end {
let global_index = ABSOLUTE_POS;
let width = comptime!(block.config.width as u32);
// Integer division: any remainder of `axis_size / width` is dropped.
// NOTE(review): presumably `axis_size` is always a multiple of `width` - confirm.
let num_iter = axis_size / usize::cast_from(width);
for i in 0..num_iter {
// Register block local inputs.
// Both the registry and the argument list are left empty here.
let values = Registry::<FuseArg, Line<f32>>::new();
let args = comptime![Vec::<FuseArg>::new()];
// Each unit covers `num_iter` consecutive indices starting at `global_index * num_iter`.
let index = global_index * num_iter + i;
let mut locals = init_locals(inputs, outputs, &block.config);
fuse_on_write::<f32>(
inputs,
outputs,
&mut locals,
index,
values,
args,
&block.config.clone(),
)
}
}
}
#[cube]
/// Executes a single reduction step using a specified instruction and blueprint.
///
/// Returns the size of the axis that was reduced.
///
/// # Arguments
///
/// * `input` - Tensor to reduce.
/// * `output` - Tensor receiving the reduction result.
/// * `reduce_axis` - Axis of `input` to reduce; also the axis whose size is returned.
/// * `config` - Compile-time configuration used to build the reduce instruction.
/// * `blueprint` - Compile-time blueprint forwarded to `GlobalFullUnitReduce::execute`.
fn reduce_step<P: ReducePrecision, Out: Numeric, I: ReduceInstruction<P>>(
input: &VirtualTensor<P::EI>,
output: &mut VirtualTensor<Out, ReadWrite>,
reduce_axis: usize,
#[comptime] config: I::Config,
#[comptime] blueprint: UnitReduceBlueprint,
) -> usize {
let inst = I::from_config(config);
// The axis size is read from the input shape before the reduction is executed.
let axis_size = input.shape(reduce_axis);
// The line mode is fixed to `LineMode::Parallel` for every fused reduce step.
GlobalFullUnitReduce::execute::<P, Out, I>(
input,
output,
reduce_axis,
&inst,
LineMode::Parallel,
comptime!(blueprint),
);
axis_size
}

View File

@@ -0,0 +1,107 @@
use crate::CubeFusionHandle;
use burn_fusion::stream::{Context, ContextOwned};
use cubecl::Runtime;
use std::sync::Arc;
/// Fusion context used when tuning kernels.
///
/// Either the original context is returned or a fork of the original.
/// The fork is only given when performing autotuning, and not when actually performing the
/// operation.
pub enum TuneContext<'a, R: Runtime> {
/// Mutable borrow of the original context.
Original(&'a mut Context<'a, CubeFusionHandle<R>>),
/// Owned, boxed fork of the context.
Fork(Box<ContextOwned<CubeFusionHandle<R>>>),
}
/// Fusion input wrapper containing the context and the optimization.
///
/// # Safety
///
/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions
/// are made based on its behavior.
pub struct TuneInput<R: Runtime, O> {
/// Lifetime-erased handle on the fusion context; see `UnsafeTuneContext` for the contract.
context: UnsafeTuneContext<R>,
/// The optimization, reference counted so that cloning the input only clones the `Arc`.
optimization: Arc<O>,
}
/// Unsafe wrapper around the context.
///
/// # Safety
///
/// The wrapper removes the context lifetime.
///
/// For it to be correct, the context must not be used after the invocation of the
/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are
/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find
/// the best kernel to use, which can be async.
enum UnsafeTuneContext<R: Runtime> {
/// Raw pointer to the original context, with its lifetime parameter cast to `'static`.
Original(*mut Context<'static, CubeFusionHandle<R>>),
/// Owned fork of the context; this is the variant produced when the wrapper is cloned.
Fork(Box<ContextOwned<CubeFusionHandle<R>>>),
}
// SAFETY: NOTE(review): the `Original` variant stores a raw pointer, which is why `Send` is not
// derived automatically. Soundness presumably relies on the usage contract documented on
// `UnsafeTuneContext` - confirm.
unsafe impl<R: Runtime> Send for UnsafeTuneContext<R> {}
// SAFETY: NOTE(review): this impl places no bound on `O` although the struct stores an
// `Arc<O>`; presumably every optimization type used with the tuner is itself thread-safe -
// confirm against the callers.
unsafe impl<R: Runtime, O> Send for TuneInput<R, O> {}
impl<R: Runtime, O> TuneInput<R, O> {
    /// Builds an autotune input out of a borrowed [context](Context) and an optimization.
    ///
    /// The context is stored without its lifetime and the optimization is moved into an `Arc`.
    pub fn new(context: &mut Context<CubeFusionHandle<R>>, optimization: O) -> Self {
        Self {
            context: UnsafeTuneContext::new(context),
            optimization: Arc::new(optimization),
        }
    }

    /// Returns the [autotune context](TuneContext) associated with this input.
    pub fn context(&self) -> TuneContext<'static, R> {
        self.context.get()
    }

    /// Returns a reference to the optimization carried by this input.
    pub fn optimization(&self) -> &O {
        &*self.optimization
    }
}
impl<R: Runtime> UnsafeTuneContext<R> {
    /// Wraps a borrowed context, erasing its lifetime by going through a raw pointer.
    fn new(context: &mut Context<'_, CubeFusionHandle<R>>) -> Self {
        let raw = core::ptr::from_mut(context);
        // The cast only changes the lifetime parameter, which clippy reports as unnecessary
        // even though it is required here.
        #[allow(clippy::unnecessary_cast)]
        let erased = raw as *mut Context<'static, CubeFusionHandle<R>>;
        Self::Original(erased)
    }

    /// Produces the [`TuneContext`] matching the stored variant.
    ///
    /// A stored fork is forked again and boxed; a stored pointer is turned back into a
    /// mutable reference.
    fn get(&self) -> TuneContext<'static, R> {
        match self {
            Self::Fork(owned) => TuneContext::Fork(Box::new(owned.fork())),
            Self::Original(raw) => {
                // SAFETY: the pointer was created from a `&mut Context` in `new`, so it is
                // never null. Its validity relies on the usage contract documented on
                // `UnsafeTuneContext`.
                let original = unsafe { raw.as_mut().unwrap() };
                TuneContext::Original(original)
            }
        }
    }
}
impl<R: Runtime, O> Clone for TuneInput<R, O> {
fn clone(&self) -> Self {
Self {
context: self.context.clone(),
optimization: self.optimization.clone(),
}
}
}
impl<R: Runtime> Clone for UnsafeTuneContext<R> {
    /// Cloning never copies the raw pointer: the result is always a `Fork` holding a fresh
    /// fork of the context this wrapper refers to.
    fn clone(&self) -> Self {
        let forked = match self {
            Self::Fork(owned) => owned.fork(),
            Self::Original(raw) => {
                // SAFETY: the pointer was created from a `&mut Context` in `new`, so it is
                // never null. Its validity relies on the usage contract documented on
                // `UnsafeTuneContext`.
                let original: &mut Context<'static, CubeFusionHandle<R>> =
                    unsafe { raw.as_mut().unwrap() };
                original.fork()
            }
        };
        Self::Fork(Box::new(forked))
    }
}