feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,165 @@
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_std::{DType, Strides, quantization::QParamTensor, strides};
|
||||
use cubecl::{
|
||||
CubeElement, Runtime,
|
||||
client::ComputeClient,
|
||||
ir::{AddressType, ElemType},
|
||||
prelude::{TensorArg, TensorHandleRef},
|
||||
};
|
||||
use cubecl::{
|
||||
ir::LineSize,
|
||||
quant::scheme::{QuantParam, QuantScheme},
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Defines a fallback operation when fusion isn't possible.
|
||||
pub trait FallbackOperation<R: Runtime>: Send + Sync {
|
||||
/// Executes the fallback procedure.
|
||||
fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
|
||||
}
|
||||
|
||||
/// Runtime parameters for quantization. Can be used to construct a scales handle from the base
|
||||
/// tensor handle.
|
||||
pub type QParams = burn_std::quantization::QParams<QParamTensor>;
|
||||
|
||||
/// Handle to be used when fusing operations.
|
||||
pub struct CubeFusionHandle<R: Runtime> {
|
||||
/// Compute client for jit.
|
||||
pub client: ComputeClient<R>,
|
||||
/// The buffer where the data are stored.
|
||||
pub handle: cubecl::server::Handle,
|
||||
/// The device of the current tensor.
|
||||
pub device: R::Device,
|
||||
/// The element type of the tensor.
|
||||
pub dtype: DType,
|
||||
/// The strides of the tensor.
|
||||
pub strides: Strides,
|
||||
/// Quantization runtime parameters, if applicable
|
||||
pub qparams: Option<QParams>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> core::fmt::Debug for CubeFusionHandle<R> {
    /// Formats the handle, printing only its device and the runtime name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let runtime_name = R::name(&self.client);

        write!(
            f,
            "CubeFusionHandle {{ device: {:?}, runtime: {}}}",
            self.device, runtime_name,
        )
    }
}
|
||||
|
||||
impl<R: Runtime> Clone for CubeFusionHandle<R> {
    /// Clones every field of the handle; `dtype` is copied.
    fn clone(&self) -> Self {
        let Self {
            client,
            handle,
            device,
            dtype,
            strides,
            qparams,
        } = self;

        Self {
            client: client.clone(),
            handle: handle.clone(),
            device: device.clone(),
            strides: strides.clone(),
            dtype: *dtype,
            qparams: qparams.clone(),
        }
    }
}
|
||||
|
||||
// NOTE(review): manual `Send` impl — sound only if every field of `CubeFusionHandle` may be
// moved across threads; that is not established by the code visible here. Confirm.
unsafe impl<R: Runtime> Send for CubeFusionHandle<R> {}
|
||||
// NOTE(review): manual `Sync` impl — sound only if every field of `CubeFusionHandle` may be
// shared by reference across threads; that is not established by the code visible here. Confirm.
unsafe impl<R: Runtime> Sync for CubeFusionHandle<R> {}
|
||||
|
||||
impl<R: Runtime> CubeFusionHandle<R> {
|
||||
/// Return the reference to a tensor handle.
|
||||
pub fn as_handle_ref<'a>(&'a self, shape: &'a [usize]) -> TensorHandleRef<'a, R> {
|
||||
TensorHandleRef {
|
||||
handle: &self.handle,
|
||||
strides: &self.strides,
|
||||
shape,
|
||||
runtime: PhantomData,
|
||||
elem_size: self.dtype.size(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn required_address_type(&self) -> AddressType {
|
||||
match self.dtype {
|
||||
DType::QFloat(scheme) => {
|
||||
let len = self.handle.size() as usize * 8 / scheme.size_bits_value();
|
||||
AddressType::from_len(len)
|
||||
}
|
||||
_ => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the reference to a tensor argument.
|
||||
pub fn as_tensor_arg<'a>(
|
||||
&'a self,
|
||||
shape: &'a [usize],
|
||||
line_size: LineSize,
|
||||
) -> TensorArg<'a, R> {
|
||||
let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape);
|
||||
|
||||
unsafe {
|
||||
TensorArg::from_raw_parts_and_size(
|
||||
handle.handle,
|
||||
handle.strides,
|
||||
handle.shape,
|
||||
line_size,
|
||||
self.dtype.size(),
|
||||
)
|
||||
}
|
||||
}
|
||||
/// Construct a separate tensor for the quantization scales, if present
|
||||
pub fn params(&self, scheme: QuantScheme) -> Option<Self> {
|
||||
let qparams = self.qparams.as_ref()?;
|
||||
let mut handle = self.handle.clone();
|
||||
handle.offset_start = Some(qparams.scales.offset_start as u64);
|
||||
handle.offset_end = Some(qparams.scales.offset_end as u64);
|
||||
|
||||
Some(Self {
|
||||
client: self.client.clone(),
|
||||
handle,
|
||||
device: self.device.clone(),
|
||||
dtype: match scheme.param {
|
||||
QuantParam::F32 => DType::F32,
|
||||
QuantParam::F16 => DType::F16,
|
||||
QuantParam::BF16 => DType::BF16,
|
||||
QuantParam::UE8M0 | QuantParam::UE4M3 => unimplemented!("Not yet supported"),
|
||||
},
|
||||
strides: qparams.scales.metadata.strides().clone(),
|
||||
qparams: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Strides {
|
||||
let mut strides = strides![0; shape.len()];
|
||||
|
||||
let mut current = 1;
|
||||
shape.iter().enumerate().rev().for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
});
|
||||
|
||||
strides
|
||||
}
|
||||
|
||||
pub(crate) fn elem_dtype<E: CubeElement>() -> DType {
|
||||
match E::cube_type().elem_type() {
|
||||
ElemType::Float(kind) => match kind {
|
||||
cubecl::ir::FloatKind::F64 => DType::F64,
|
||||
cubecl::ir::FloatKind::F16 => DType::F16,
|
||||
cubecl::ir::FloatKind::BF16 => DType::BF16,
|
||||
cubecl::ir::FloatKind::F32 => DType::F32,
|
||||
_ => todo!(),
|
||||
},
|
||||
ElemType::Int(kind) => match kind {
|
||||
cubecl::ir::IntKind::I64 => DType::I64,
|
||||
cubecl::ir::IntKind::I32 => DType::I32,
|
||||
cubecl::ir::IntKind::I16 => DType::I16,
|
||||
cubecl::ir::IntKind::I8 => DType::I8,
|
||||
},
|
||||
ElemType::UInt(kind) => match kind {
|
||||
cubecl::ir::UIntKind::U64 => DType::U64,
|
||||
cubecl::ir::UIntKind::U32 => DType::U32,
|
||||
cubecl::ir::UIntKind::U16 => DType::U16,
|
||||
cubecl::ir::UIntKind::U8 => DType::U8,
|
||||
},
|
||||
ElemType::Bool => DType::Bool,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
/// Element type ID reserved for the dynamic element type while expanding a fused kernel.
pub(crate) const DYN_ELEM_ID: u8 = u8::MAX;
/// Element type ID reserved for the quantization store element type while expanding a fused
/// kernel.
pub(crate) const Q_STORE_DYN_ELEM_ID: u8 = u8::MAX - 1;
/// Element type ID reserved for the quantization param element type while expanding a fused
/// kernel.
pub(crate) const Q_PARAM_DYN_ELEM_ID: u8 = u8::MAX - 2;
|
||||
@@ -0,0 +1,792 @@
|
||||
//! This module declares input-output primitives to read and write values during kernel expansion.
|
||||
use super::{DYN_ELEM_ID, ir::*, tensor::GlobalTensor};
|
||||
use burn_std::quantization::QuantScheme;
|
||||
use cubecl::quant::scheme::QuantLevel;
|
||||
use cubecl::{
|
||||
intrinsic,
|
||||
ir::{ExpandElement, Variable},
|
||||
prelude::*,
|
||||
std::{FastDivmod, tensor::View},
|
||||
};
|
||||
use cubek::quantization::layout::{BlockScaledLayout, PerTensorLayout, ScalesLayout};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Define how a tensor might be transformed at runtime.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub enum Transform {
|
||||
/// A reshape operation has been registered on a tensor.
|
||||
///
|
||||
/// This enum entry contains a sequence of [arguments](FuseArg) that points to global scalars representing the
|
||||
/// new shape for the current tensor.
|
||||
Reshape(Vec<FuseArg>),
|
||||
/// Two axes have been swapped on a tensor.
|
||||
///
|
||||
/// The enum entry contains those two axes.
|
||||
SwapDims(usize, usize),
|
||||
}
|
||||
|
||||
/// Reads the value from the [arg](FuseArg) and cast it to the generic cube primitive.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The [global arguments](GlobalArgs) for both inputs and outputs as well as the
|
||||
/// [local arguments](LocalArgs) need to be passed to this function.
|
||||
///
|
||||
/// This is because the [argument](FuseArg) might point to a global input, output or local variable
|
||||
/// created during kernel expansion.
|
||||
#[cube]
|
||||
pub fn read<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
ref_pos: usize,
|
||||
#[comptime] arg: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> Line<C> {
|
||||
match arg {
|
||||
FuseArg::Input(pos, _precision, layout) => {
|
||||
let global = inputs.tensors.index(pos);
|
||||
let line_size = global.tensor.line_size();
|
||||
|
||||
if comptime![!global.broadcasted && line_size != config.width] {
|
||||
read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None)
|
||||
} else {
|
||||
read_input(inputs, locals, pos, ref_pos, layout, config, None)
|
||||
}
|
||||
}
|
||||
FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => {
|
||||
Line::cast_from(outputs.variables.read(key))
|
||||
}
|
||||
FuseArg::Output(pos, _precision, layout) => {
|
||||
read_output(inputs, outputs, locals, pos, ref_pos, layout, config)
|
||||
}
|
||||
FuseArg::BlockLocal { pos, ty } => match comptime![ty] {
|
||||
FuseType::F64 => Line::cast_from(locals.l_f64.find(pos)),
|
||||
FuseType::F32 | FuseType::Flex32 => Line::cast_from(locals.l_f32.find(pos)),
|
||||
FuseType::F16 => Line::cast_from(locals.l_f16.find(pos)),
|
||||
FuseType::BF16 => Line::cast_from(locals.l_bf16.find(pos)),
|
||||
FuseType::U64 => Line::cast_from(locals.l_u64.find(pos)),
|
||||
FuseType::U32 => Line::cast_from(locals.l_u32.find(pos)),
|
||||
FuseType::U16 => Line::cast_from(locals.l_u16.find(pos)),
|
||||
FuseType::U8 => Line::cast_from(locals.l_u8.find(pos)),
|
||||
FuseType::I64 => Line::cast_from(locals.l_i64.find(pos)),
|
||||
FuseType::I32 => Line::cast_from(locals.l_i32.find(pos)),
|
||||
FuseType::I16 => Line::cast_from(locals.l_i16.find(pos)),
|
||||
FuseType::I8 => Line::cast_from(locals.l_i8.find(pos)),
|
||||
FuseType::Bool => Line::cast_from(locals.l_bool.find(pos)),
|
||||
},
|
||||
FuseArg::Scalar(..) => {
|
||||
let scalar = read_scalar::<C>(inputs, arg);
|
||||
Line::new(scalar)
|
||||
}
|
||||
FuseArg::ScalarShape(_) => {
|
||||
let scalar = read_scalar_shape(inputs, arg);
|
||||
Line::cast_from(scalar)
|
||||
}
|
||||
FuseArg::Literal(val, _precision) => Line::new(from_const_int::<C>(val)),
|
||||
FuseArg::InputReshaped {
|
||||
original,
|
||||
shape,
|
||||
broadcasted,
|
||||
} => match comptime![original.as_ref().clone()] {
|
||||
FuseArg::Input(pos, _precision, layout) => {
|
||||
let global = inputs.tensors.index(pos);
|
||||
let line_size = global.tensor.line_size();
|
||||
|
||||
if comptime![!broadcasted && line_size != config.width] {
|
||||
read_input_aligned(
|
||||
inputs,
|
||||
locals,
|
||||
pos,
|
||||
ref_pos,
|
||||
layout,
|
||||
config,
|
||||
comptime![Some(Transform::Reshape(shape))],
|
||||
)
|
||||
} else {
|
||||
read_input(
|
||||
inputs,
|
||||
locals,
|
||||
pos,
|
||||
ref_pos,
|
||||
layout,
|
||||
config,
|
||||
comptime![Some(Transform::Reshape(shape))],
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => comptime![panic!("Only input can be reshaped")],
|
||||
},
|
||||
FuseArg::InputSwapDims {
|
||||
original,
|
||||
dims,
|
||||
broadcasted,
|
||||
} => match comptime![original.as_ref().clone()] {
|
||||
FuseArg::Input(pos, _precision, layout) => {
|
||||
let global = inputs.tensors.index(pos);
|
||||
let line_size = global.tensor.line_size();
|
||||
|
||||
if comptime![!broadcasted && line_size != config.width] {
|
||||
read_input_aligned(
|
||||
inputs,
|
||||
locals,
|
||||
pos,
|
||||
ref_pos,
|
||||
layout,
|
||||
config,
|
||||
comptime![Some(Transform::SwapDims(dims.0, dims.1))],
|
||||
)
|
||||
} else {
|
||||
read_input(
|
||||
inputs,
|
||||
locals,
|
||||
pos,
|
||||
ref_pos,
|
||||
layout,
|
||||
config,
|
||||
comptime![Some(Transform::SwapDims(dims.0, dims.1))],
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => comptime![panic!("Only input can be swapped dims")],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the offset for the current global tensor with a quantized layout.
|
||||
///
|
||||
/// The offset can be used to fetch the correct data from the quantized tensor as if it was in a
|
||||
/// linear contiguous format.
|
||||
#[cube]
|
||||
fn index_offset_with_quant_layout(
|
||||
tensor: &GlobalTensor,
|
||||
locals: &LocalArgs,
|
||||
index: usize,
|
||||
#[comptime] rank: usize,
|
||||
#[comptime] scheme: QuantScheme,
|
||||
) -> usize {
|
||||
let (start, end) = (0, rank - 1);
|
||||
let num_quants = scheme.num_quants();
|
||||
|
||||
let offset_ref = index * locals.ref_line_size;
|
||||
let mut offset = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in start..end {
|
||||
let ogwl = offset_ref / locals.ref_strides[i];
|
||||
offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i);
|
||||
}
|
||||
|
||||
// Handle packed representation in last dim
|
||||
let ogwl = offset_ref / locals.ref_strides[end];
|
||||
let shape_last = tensor.tensor.shape(end).div_ceil(num_quants);
|
||||
let stride_last = tensor.tensor.stride(end);
|
||||
offset += (ogwl.div_ceil(num_quants)) % shape_last * stride_last;
|
||||
|
||||
offset / tensor.tensor.line_size()
|
||||
}
|
||||
|
||||
/// Reads a global quantized tensor at the given position.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The values returned in the [Line] are not dequantized.
|
||||
#[cube]
|
||||
pub fn read_quantized<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
ref_pos: usize,
|
||||
#[comptime] arg: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
#[comptime] scheme: QuantScheme,
|
||||
) -> Line<C> {
|
||||
match arg {
|
||||
FuseArg::Input(pos, _precision, _layout) => {
|
||||
let global = inputs.tensors.index(pos);
|
||||
|
||||
let offset =
|
||||
index_offset_with_quant_layout(global, locals, ref_pos, config.rank, scheme);
|
||||
let val = global.tensor[offset];
|
||||
Line::cast_from(val)
|
||||
}
|
||||
_ => panic!("Not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads a global scalar.
|
||||
#[cube]
|
||||
pub fn read_scalar<C: CubePrimitive>(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> C {
|
||||
match arg {
|
||||
FuseArg::Scalar(pos, _precision) => {
|
||||
let scalar = inputs.scalars.index(pos);
|
||||
scalar.get::<C>()
|
||||
}
|
||||
_ => comptime![panic!("Not a scalar")],
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads a global scalar that is used as a reshape position.
|
||||
#[cube]
|
||||
pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> usize {
|
||||
match arg {
|
||||
FuseArg::ScalarShape(pos) => inputs.reshapes[pos],
|
||||
_ => comptime![panic!("Not a scalar shape")],
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads an input tensor.
|
||||
#[cube]
|
||||
pub fn read_input<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
#[comptime] pos: usize,
|
||||
ref_pos: usize,
|
||||
#[comptime] layout: LayoutInfo,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
#[comptime] transform: Option<Transform>,
|
||||
) -> Line<C> {
|
||||
let tensor = inputs.tensors.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, transform),
|
||||
};
|
||||
Line::cast_from(tensor.tensor[offset])
|
||||
}
|
||||
|
||||
/// Returns a slice of data in the asked precision of the input tensor at the given position.
|
||||
#[cube]
|
||||
pub fn read_input_window<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
#[comptime] pos: usize,
|
||||
start: usize,
|
||||
end: usize,
|
||||
) -> Slice<C> {
|
||||
let tensor = inputs.tensors.index(pos);
|
||||
let slice = tensor.tensor.slice(start, end);
|
||||
slice.downcast()
|
||||
}
|
||||
|
||||
/// Returns the input as a slice.
|
||||
#[cube]
|
||||
pub fn input_as_slice<C: CubePrimitive>(inputs: &GlobalArgs, #[comptime] pos: usize) -> Slice<C> {
|
||||
let tensor = inputs.tensors.index(pos);
|
||||
let slice = tensor.tensor.to_slice();
|
||||
slice.downcast()
|
||||
}
|
||||
|
||||
/// Returns the input tensor as a quantized scale view.
|
||||
#[cube]
|
||||
pub fn input_as_scales_view<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
#[comptime] pos: usize,
|
||||
#[comptime] tensor_pos: usize,
|
||||
#[comptime] level: QuantLevel,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> View<C, usize> {
|
||||
set_polyfill_typed::<C, NumericExpand<DYN_ELEM_ID>>();
|
||||
let tensor = inputs.tensors.index(tensor_pos);
|
||||
let scales = inputs.tensors.index(pos);
|
||||
let tensor_len = tensor.tensor.len();
|
||||
let rank = config.rank;
|
||||
let layout = match level {
|
||||
QuantLevel::Tensor => ScalesLayout::new_PerTensor(PerTensorLayout::new(tensor_len)),
|
||||
QuantLevel::Block(block_size) => {
|
||||
let block_size = comptime![block_size.to_dim_vec(rank)];
|
||||
let mut tensor_shape = Sequence::new();
|
||||
let mut scales_strides = Sequence::new();
|
||||
#[unroll]
|
||||
for i in 0..rank {
|
||||
tensor_shape.push(FastDivmod::new_Fallback(tensor.tensor.shape(i)));
|
||||
scales_strides.push(scales.tensor.stride(i));
|
||||
}
|
||||
let line_size = scales.tensor.line_size();
|
||||
let layout = BlockScaledLayout::new(
|
||||
tensor_shape,
|
||||
tensor_len,
|
||||
scales_strides,
|
||||
block_size,
|
||||
line_size,
|
||||
);
|
||||
ScalesLayout::new_BlockScaled(layout)
|
||||
}
|
||||
};
|
||||
View::new::<Slice<C>, usize>(&scales.tensor.to_slice().downcast(), layout)
|
||||
}
|
||||
|
||||
/// Reads the input tensor aligned.
|
||||
#[cube]
|
||||
pub fn read_input_aligned<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
#[comptime] pos: usize,
|
||||
ref_pos: usize,
|
||||
#[comptime] layout: LayoutInfo,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
#[comptime] transform: Option<Transform>,
|
||||
) -> Line<C> {
|
||||
let mut result: Line<C> = Line::<C>::empty(config.width);
|
||||
let tensor = inputs.tensors.index(pos);
|
||||
|
||||
match transform.clone() {
|
||||
Some(Transform::Reshape(shape)) => {
|
||||
// Very brute force, not really efficient, but not easy to optimize and not a very
|
||||
// frequent workflow.
|
||||
let ref_pos = ref_pos * config.width;
|
||||
#[unroll]
|
||||
for i in 0..config.width {
|
||||
let index = reshaped_index(
|
||||
inputs,
|
||||
locals,
|
||||
ref_pos + i,
|
||||
config.rank,
|
||||
comptime![shape.clone()],
|
||||
);
|
||||
let index = reshaped_index_to_original_index(&tensor.tensor, index, config.rank);
|
||||
result[i] = C::cast_from(tensor.tensor[index][0])
|
||||
}
|
||||
}
|
||||
Some(Transform::SwapDims(dim1, dim2)) => {
|
||||
let offset =
|
||||
get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform);
|
||||
let i = comptime![swap_dims_transform(config.rank - 1, (dim1, dim2))];
|
||||
let stride = tensor.tensor.stride(i);
|
||||
|
||||
#[unroll]
|
||||
for i in 0..config.width {
|
||||
let index = offset + i * stride;
|
||||
result[i] = C::cast_from(tensor.tensor[index][0])
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let offset =
|
||||
get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform);
|
||||
let stride = tensor.tensor.stride(config.rank - 1);
|
||||
#[unroll]
|
||||
for i in 0..config.width {
|
||||
let index = offset + i * stride;
|
||||
result[i] = C::cast_from(tensor.tensor[index][0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Computes the offset of the given [GlobalTensor] at on the reference position with a linear
|
||||
/// layout.
|
||||
#[cube]
|
||||
pub fn get_offset_aligned(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
tensor: &GlobalTensor,
|
||||
ref_pos: usize,
|
||||
#[comptime] layout: LayoutInfo,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
#[comptime] transform: Option<Transform>,
|
||||
) -> usize {
|
||||
match layout {
|
||||
LayoutInfo::SameAsRef | LayoutInfo::IsRef => {
|
||||
(ref_pos * locals.ref_line_size) / tensor.tensor.line_size()
|
||||
}
|
||||
LayoutInfo::Unknown => get_offset(
|
||||
inputs,
|
||||
locals,
|
||||
tensor,
|
||||
ref_pos,
|
||||
None,
|
||||
config,
|
||||
comptime!(transform.clone()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads an output tensor.
|
||||
#[cube]
|
||||
pub fn read_output<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
#[comptime] pos: usize,
|
||||
ref_pos: usize,
|
||||
#[comptime] layout: LayoutInfo,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> Line<C> {
|
||||
let tensor = outputs.tensors.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, None),
|
||||
};
|
||||
Line::cast_from(tensor.tensor[offset])
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Write the given value at the [arg](Arg) position.
|
||||
pub fn write<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
ref_pos: usize,
|
||||
value: Line<C>,
|
||||
#[comptime] arg: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
match arg {
|
||||
FuseArg::Output(pos, precision, layout) => {
|
||||
let tensor = outputs.tensors.index(pos);
|
||||
let offset = match layout {
|
||||
LayoutInfo::SameAsRef => ref_pos,
|
||||
LayoutInfo::IsRef => ref_pos,
|
||||
LayoutInfo::Unknown => {
|
||||
get_offset(inputs, locals, tensor, ref_pos, None, config, None)
|
||||
}
|
||||
};
|
||||
let tensor = outputs.tensors.index_mut(pos);
|
||||
set_polyfill::<NumericExpand<DYN_ELEM_ID>>(comptime![precision.into_type()]);
|
||||
|
||||
tensor.tensor[offset] = Line::cast_from(value);
|
||||
}
|
||||
FuseArg::BlockLocal { .. } => write_scalar::<C>(locals, value, arg),
|
||||
FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => {
|
||||
outputs.variables.write(key, Line::cast_from(value))
|
||||
}
|
||||
_ => comptime![panic!("Can't write into inputs and scalars")],
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Write the given value at the [arg](Arg) position.
|
||||
pub fn write_scalar<C: CubePrimitive>(
|
||||
locals: &mut LocalArgs,
|
||||
value: Line<C>,
|
||||
#[comptime] arg: FuseArg,
|
||||
) {
|
||||
match arg {
|
||||
FuseArg::BlockLocal { pos, ty } => match comptime![ty] {
|
||||
FuseType::F64 => locals.l_f64.insert(pos, Line::cast_from(value)),
|
||||
FuseType::F32 | FuseType::Flex32 => locals.l_f32.insert(pos, Line::cast_from(value)),
|
||||
FuseType::F16 => locals.l_f16.insert(pos, Line::cast_from(value)),
|
||||
FuseType::BF16 => locals.l_bf16.insert(pos, Line::cast_from(value)),
|
||||
FuseType::U64 => locals.l_u64.insert(pos, Line::cast_from(value)),
|
||||
FuseType::U32 => locals.l_u32.insert(pos, Line::cast_from(value)),
|
||||
FuseType::U16 => locals.l_u16.insert(pos, Line::cast_from(value)),
|
||||
FuseType::U8 => locals.l_u8.insert(pos, Line::cast_from(value)),
|
||||
FuseType::I64 => locals.l_i64.insert(pos, Line::cast_from(value)),
|
||||
FuseType::I32 => locals.l_i32.insert(pos, Line::cast_from(value)),
|
||||
FuseType::I16 => locals.l_i16.insert(pos, Line::cast_from(value)),
|
||||
FuseType::I8 => locals.l_i8.insert(pos, Line::cast_from(value)),
|
||||
FuseType::Bool => locals.l_bool.insert(pos, Line::cast_from(value)),
|
||||
},
|
||||
_ => comptime![panic!("Can't write into something else than scalars")],
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) fn global_offset(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
index: usize,
|
||||
#[comptime] arg: FuseArg,
|
||||
#[comptime] range: Option<(usize, usize)>,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> usize {
|
||||
match arg {
|
||||
FuseArg::Input(pos, _precision, _layout) => {
|
||||
let tensor = inputs.tensors.index(pos);
|
||||
get_offset(inputs, locals, tensor, index, range, config, None)
|
||||
}
|
||||
FuseArg::Output(pos, _precision, _layout) => {
|
||||
let tensor = outputs.tensors.index(pos);
|
||||
get_offset(inputs, locals, tensor, index, range, config, None)
|
||||
}
|
||||
_ => panic!("Only input and output tensors have global offset."),
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn get_offset(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
tensor: &GlobalTensor,
|
||||
ref_pos: usize,
|
||||
#[comptime] range: Option<(usize, usize)>,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
#[comptime] transform: Option<Transform>,
|
||||
) -> usize {
|
||||
index_offset_with_layout(
|
||||
inputs,
|
||||
tensor,
|
||||
locals,
|
||||
ref_pos,
|
||||
range,
|
||||
config.rank,
|
||||
transform,
|
||||
)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the line size for a global tensor.
|
||||
pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: usize) -> comptime_type!(LineSize) {
|
||||
let tensor = global.tensors.index(pos);
|
||||
tensor.tensor.line_size()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the rank for a global tensor.
|
||||
pub fn global_rank(global: &GlobalArgs, #[comptime] pos: usize) -> usize {
|
||||
let tensor = global.tensors.index(pos);
|
||||
tensor.tensor.rank()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the length for a global tensor.
|
||||
pub fn global_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize {
|
||||
let tensor = global.tensors.index(pos);
|
||||
tensor.tensor.len()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the buffer length for a global tensor.
|
||||
pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize {
|
||||
let tensor = global.tensors.index(pos);
|
||||
tensor.tensor.buffer_len()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the reference tensor length.
|
||||
pub fn ref_len(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> usize {
|
||||
match config.ref_layout.clone() {
|
||||
RefLayout::Concrete(arg) => match comptime![arg] {
|
||||
FuseArg::Input(index, _, _) => global_len(inputs, index),
|
||||
FuseArg::Output(index, _, _) => global_len(outputs, index),
|
||||
_ => panic!("Invalid concrete ref layout."),
|
||||
},
|
||||
RefLayout::Virtual(..) => num_elements(locals, config),
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the reference buffer tensor length.
|
||||
pub fn ref_buffer_len(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> usize {
|
||||
match config.ref_layout.clone() {
|
||||
RefLayout::Concrete(arg) => match comptime![arg] {
|
||||
FuseArg::Input(index, _, _) => global_buffer_len(inputs, index),
|
||||
FuseArg::Output(index, _, _) => global_buffer_len(outputs, index),
|
||||
_ => panic!("Invalid concrete ref layout."),
|
||||
},
|
||||
RefLayout::Virtual(VirtualLayout::SwapDims(arg, ..)) => match arg {
|
||||
FuseArg::Input(index, _, _) => global_buffer_len(inputs, index),
|
||||
FuseArg::Output(index, _, _) => global_buffer_len(outputs, index),
|
||||
_ => panic!("Invalid concrete ref layout."),
|
||||
},
|
||||
RefLayout::Virtual(VirtualLayout::Reshaped { .. }) => num_elements(locals, config),
|
||||
RefLayout::Virtual(VirtualLayout::Shape(..)) => num_elements(locals, config),
|
||||
RefLayout::Virtual(VirtualLayout::Runtime { .. }) => num_elements(locals, config),
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the reference number of elements.
|
||||
pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> usize {
|
||||
let mut length = 1;
|
||||
|
||||
for i in 0..config.rank {
|
||||
length *= locals.ref_shape[i];
|
||||
}
|
||||
|
||||
length
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the reference axis shape.
|
||||
pub fn ref_shape(locals: &LocalArgs, axis: usize) -> usize {
|
||||
locals.ref_shape[axis]
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the reference axis stride.
|
||||
pub fn ref_stride(locals: &LocalArgs, axis: usize) -> usize {
|
||||
locals.ref_strides[axis]
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the reference line size.
|
||||
pub fn ref_line_size(locals: &LocalArgs) -> comptime_type!(LineSize) {
|
||||
comptime![locals.ref_line_size]
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the given tensor axis shape.
|
||||
pub fn global_shape(global: &GlobalArgs, axis: usize, #[comptime] pos: usize) -> usize {
|
||||
let tensor = global.tensors.index(pos);
|
||||
tensor.tensor.shape(axis)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Gets the given tensor axis stride.
|
||||
pub fn global_stride(global: &GlobalArgs, dim: usize, #[comptime] pos: usize) -> usize {
|
||||
let tensor = global.tensors.index(pos);
|
||||
tensor.tensor.stride(dim)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn index_offset_with_layout(
|
||||
inputs: &GlobalArgs,
|
||||
tensor: &GlobalTensor,
|
||||
locals: &LocalArgs,
|
||||
index: usize,
|
||||
#[comptime] range: Option<(usize, usize)>,
|
||||
#[comptime] rank: usize,
|
||||
#[comptime] transform: Option<Transform>,
|
||||
) -> usize {
|
||||
match comptime![transform.clone()] {
|
||||
Some(Transform::Reshape(shape)) => {
|
||||
comptime![assert!(
|
||||
range.is_none(),
|
||||
"Can't get a range on a reshaped tensor."
|
||||
)];
|
||||
|
||||
let index = index * locals.ref_line_size;
|
||||
let index = reshaped_index(inputs, locals, index, rank, shape);
|
||||
reshaped_index_to_original_index(&tensor.tensor, index, rank)
|
||||
}
|
||||
Some(Transform::SwapDims(dim1, dim2)) => {
|
||||
let (start, end) = comptime! {match range {
|
||||
Some(range) => range,
|
||||
None => (0, rank),
|
||||
}};
|
||||
|
||||
let offset_ref = index * locals.ref_line_size;
|
||||
let mut offset = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in start..end {
|
||||
let index = comptime![swap_dims_transform(i, (dim1, dim2))];
|
||||
let ogwl = offset_ref / locals.ref_strides[i];
|
||||
offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index);
|
||||
}
|
||||
|
||||
offset / tensor.tensor.line_size()
|
||||
}
|
||||
None => {
|
||||
let (start, end) = comptime! {match range {
|
||||
Some(range) => range,
|
||||
None => (0, rank),
|
||||
}};
|
||||
|
||||
let offset_ref = index * locals.ref_line_size;
|
||||
let mut offset = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in start..end {
|
||||
let ogwl = offset_ref / locals.ref_strides[i];
|
||||
offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i);
|
||||
}
|
||||
|
||||
offset / tensor.tensor.line_size()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps axis `i` through a swap of the two axes in `dims`.
///
/// Returns `dims.1` when `i == dims.0`, `dims.0` when `i == dims.1`, and `i` unchanged
/// otherwise. When both entries of `dims` are equal the function is the identity.
///
/// Declared `const` so the mapping can also be evaluated in constant contexts; regular
/// callers are unaffected.
pub(crate) const fn swap_dims_transform(i: usize, dims: (usize, usize)) -> usize {
    if i == dims.0 {
        dims.1
    } else if i == dims.1 {
        dims.0
    } else {
        i
    }
}
|
||||
|
||||
#[cube]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
/// The index the input tensor would be at if it was contiguous.
|
||||
fn reshaped_index(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
index: usize,
|
||||
#[comptime] rank: usize,
|
||||
#[comptime] shape: Vec<FuseArg>,
|
||||
) -> usize {
|
||||
let mut offset = 0;
|
||||
let mut stride_curr = 1;
|
||||
|
||||
#[unroll]
|
||||
for r in 0..rank {
|
||||
let i = reverse_index(rank, r).comptime();
|
||||
let arg = shape[i].clone();
|
||||
let shape_i = read_scalar_shape(inputs, arg);
|
||||
let ogwl = index / locals.ref_strides[i];
|
||||
|
||||
offset += ogwl % shape_i * stride_curr;
|
||||
|
||||
stride_curr *= shape_i;
|
||||
}
|
||||
|
||||
offset
|
||||
}
|
||||
|
||||
#[allow(unreachable_code)]
|
||||
#[cube]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
fn reshaped_index_to_original_index<C: CubePrimitive>(
|
||||
original: &Tensor<Line<C>>,
|
||||
index_reshaped: usize,
|
||||
#[comptime] rank: usize,
|
||||
) -> usize {
|
||||
let mut remaining = index_reshaped;
|
||||
let mut offset = 0;
|
||||
|
||||
#[unroll]
|
||||
for r in 0..rank {
|
||||
let i = reverse_index(rank, r);
|
||||
let shape = original.shape(i);
|
||||
let stride = original.stride(i);
|
||||
|
||||
let coordinate = remaining % shape;
|
||||
|
||||
remaining /= shape;
|
||||
offset += coordinate * stride;
|
||||
}
|
||||
|
||||
offset / original.line_size()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn reverse_index(
|
||||
#[comptime] rank: usize,
|
||||
#[comptime] iter: usize,
|
||||
) -> comptime_type!(usize) {
|
||||
rank - iter - 1
|
||||
}
|
||||
|
||||
/// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion.
|
||||
#[allow(unused_variables)]
|
||||
#[cube]
|
||||
fn from_const_int<C: CubePrimitive>(#[comptime] value: usize) -> C {
|
||||
intrinsic!(|scope| {
|
||||
ExpandElement::Plain(Variable::constant(value.into(), C::as_type(scope))).into()
|
||||
})
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(clippy::extra_unused_type_parameters)]
|
||||
fn set_polyfill_typed<C: CubePrimitive, Dyn: CubePrimitive>() {
|
||||
intrinsic!(|scope| {
|
||||
let elem_type = C::as_type(scope);
|
||||
set_polyfill::expand::<Dyn>(scope, elem_type);
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,917 @@
|
||||
use super::tensor::GlobalTensor;
|
||||
use crate::engine::codegen::DYN_ELEM_ID;
|
||||
use burn_std::{
|
||||
DType, Shape, Strides, bf16, f16,
|
||||
quantization::{QuantScheme, QuantStore, QuantValue},
|
||||
strides,
|
||||
};
|
||||
use core::fmt::Display;
|
||||
use cubecl::{
|
||||
ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind},
|
||||
prelude::*,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
/// Argument to a [fuse operation](FuseOp).
|
||||
pub enum FuseArg {
|
||||
/// A readonly input tensor.
|
||||
Input(usize, FuseType, LayoutInfo),
|
||||
/// A readwrite output tensor.
|
||||
Output(usize, FuseType, LayoutInfo),
|
||||
/// A temporary local variable within a single [block](FuseBlockConfig).
|
||||
BlockLocal {
|
||||
/// The position of the current variable relative to all local variables within a single block.
|
||||
pos: usize,
|
||||
/// The type of the current variable.
|
||||
ty: FuseType,
|
||||
},
|
||||
/// A variable shared between multiple [block](FuseBlockConfig) that must have a compatible
|
||||
/// scope.
|
||||
MultiBlockLocal(MultiBlockPos, FuseType),
|
||||
/// A variable shared between multiple [blocks](FuseBlockConfig) within a global accessible
|
||||
/// scope.
|
||||
MultiBlockGlobal(MultiBlockPos, FuseType),
|
||||
/// A global scalar.
|
||||
Scalar(usize, FuseType),
|
||||
/// A global scalar used in a reshape operation.
|
||||
///
|
||||
/// This is not a scalar defined by a user for computation, but a scalar defined as part of
|
||||
/// a reshape operation.
|
||||
ScalarShape(usize),
|
||||
/// Only constant that can be encoded into an u32 can be used as literal.
|
||||
Literal(usize, FuseType),
|
||||
/// A readonly input tensor that is reshaped.
|
||||
InputReshaped {
|
||||
original: Box<FuseArg>,
|
||||
shape: Vec<FuseArg>,
|
||||
broadcasted: bool,
|
||||
},
|
||||
/// A readonly input tensor with swapped dimensions.
|
||||
InputSwapDims {
|
||||
original: Box<FuseArg>,
|
||||
dims: (usize, usize),
|
||||
broadcasted: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// Metadata of a variable shared between blocks.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct MultiBlockPos {
|
||||
/// The block position in all blocks included in a fused trace.
|
||||
pub block_pos: usize,
|
||||
/// The [FuseArg::BlockLocal] position in the block where the variable is first initialized.
|
||||
pub block_local_pos: usize,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord,
|
||||
)]
|
||||
/// Layout information.
|
||||
pub enum LayoutInfo {
|
||||
/// The layout if the same as the reference.
|
||||
SameAsRef,
|
||||
/// The reference layout.
|
||||
IsRef,
|
||||
/// The layout if unknown.
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl FuseArg {
|
||||
pub fn precision(&self) -> FuseType {
|
||||
*match self {
|
||||
FuseArg::Input(_, p, _) => p,
|
||||
FuseArg::BlockLocal { ty, .. } => ty,
|
||||
FuseArg::MultiBlockLocal(_, p) => p,
|
||||
FuseArg::MultiBlockGlobal(_, p) => p,
|
||||
FuseArg::Output(_, p, _) => p,
|
||||
FuseArg::Scalar(_, p) => p,
|
||||
FuseArg::Literal(_, p) => p,
|
||||
FuseArg::ScalarShape(_) => return FuseType::U32,
|
||||
FuseArg::InputReshaped { original, .. } => return original.precision(),
|
||||
FuseArg::InputSwapDims { original, .. } => return original.precision(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A [FuseArg] is its own expand type.
impl CubeType for FuseArg {
    type ExpandType = Self;
}
|
||||
|
||||
impl IntoMut for FuseArg {
|
||||
fn into_mut(self, _context: &mut Scope) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoRuntime for FuseArg {
|
||||
fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Uses the default methods of [CubeDebug].
impl CubeDebug for FuseArg {}
|
||||
|
||||
#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// Operations that can be executed and fused automatically using a fuse-on-read and/or
|
||||
/// fuse-on-write strategy.
|
||||
pub enum FuseOp {
|
||||
Add(BinaryFuseArgs),
|
||||
Sub(BinaryFuseArgs),
|
||||
Mul(BinaryFuseArgs),
|
||||
Div(BinaryFuseArgs),
|
||||
Powf(BinaryFuseArgs),
|
||||
Abs(UnaryFuseArgs),
|
||||
Exp(UnaryFuseArgs),
|
||||
Log(UnaryFuseArgs),
|
||||
Log1p(UnaryFuseArgs),
|
||||
Cos(UnaryFuseArgs),
|
||||
Sin(UnaryFuseArgs),
|
||||
Tanh(UnaryFuseArgs),
|
||||
Erf(UnaryFuseArgs),
|
||||
Sqrt(UnaryFuseArgs),
|
||||
Recip(UnaryFuseArgs),
|
||||
Assign(UnaryFuseArgs),
|
||||
Equal(BinaryFuseArgs),
|
||||
Lower(BinaryFuseArgs),
|
||||
Greater(BinaryFuseArgs),
|
||||
LowerEqual(BinaryFuseArgs),
|
||||
Rem(BinaryFuseArgs),
|
||||
GreaterEqual(BinaryFuseArgs),
|
||||
Clamp {
|
||||
input: FuseArg,
|
||||
min: FuseArg,
|
||||
max: FuseArg,
|
||||
out: FuseArg,
|
||||
},
|
||||
ConditionalAssign {
|
||||
cond: FuseArg,
|
||||
lhs: FuseArg,
|
||||
rhs: FuseArg,
|
||||
out: FuseArg,
|
||||
},
|
||||
Gather {
|
||||
input: FuseArg,
|
||||
indices: FuseArg,
|
||||
output: FuseArg,
|
||||
dim: usize,
|
||||
},
|
||||
Select {
|
||||
input: FuseArg,
|
||||
indices: FuseArg,
|
||||
output: FuseArg,
|
||||
dim: usize,
|
||||
},
|
||||
Dequantize {
|
||||
values: FuseArg,
|
||||
params: FuseArg,
|
||||
output: FuseArg,
|
||||
scheme: QuantSchemeFuse,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for FuseOp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
FuseOp::Add(args) => write!(f, "{} = {} + {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Sub(args) => write!(f, "{} = {} - {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Mul(args) => write!(f, "{} = {} * {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Div(args) => write!(f, "{} = {} / {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Powf(args) => write!(f, "{} = powf({}, {})", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Abs(args) => write!(f, "{} = abs({})", args.out, args.input),
|
||||
FuseOp::Exp(args) => write!(f, "{} = exp({})", args.out, args.input),
|
||||
FuseOp::Log(args) => write!(f, "{} = log({})", args.out, args.input),
|
||||
FuseOp::Log1p(args) => write!(f, "{} = log1p({})", args.out, args.input),
|
||||
FuseOp::Cos(args) => write!(f, "{} = cos({})", args.out, args.input),
|
||||
FuseOp::Sin(args) => write!(f, "{} = sin({})", args.out, args.input),
|
||||
FuseOp::Tanh(args) => write!(f, "{} = tanh({})", args.out, args.input),
|
||||
FuseOp::Erf(args) => write!(f, "{} = erf({})", args.out, args.input),
|
||||
FuseOp::Sqrt(args) => write!(f, "{} = sqrt({})", args.out, args.input),
|
||||
FuseOp::Recip(args) => write!(f, "{} = recip({})", args.out, args.input),
|
||||
FuseOp::Assign(args) => write!(f, "{} = {}", args.out, args.input),
|
||||
FuseOp::Equal(args) => write!(f, "{} = {} == {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Lower(args) => write!(f, "{} = {} < {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Greater(args) => write!(f, "{} = {} > {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::LowerEqual(args) => write!(f, "{} = {} <= {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Rem(args) => write!(f, "{} = {} % {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::GreaterEqual(args) => write!(f, "{} = {} >= {}", args.out, args.lhs, args.rhs),
|
||||
FuseOp::Clamp {
|
||||
input,
|
||||
min,
|
||||
max,
|
||||
out,
|
||||
} => write!(f, "{} = clamp({}, min={}, max={})", out, input, min, max),
|
||||
FuseOp::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
} => write!(
|
||||
f,
|
||||
"{} = select(cond={}, lhs={}, rhs={})",
|
||||
out, cond, lhs, rhs
|
||||
),
|
||||
FuseOp::Gather {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim,
|
||||
} => write!(
|
||||
f,
|
||||
"{} = gather(input={}, indices={}, dim={})",
|
||||
output, input, indices, dim
|
||||
),
|
||||
FuseOp::Select {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim,
|
||||
} => write!(
|
||||
f,
|
||||
"{} = select(input={}, indices={}, dim={})",
|
||||
output, input, indices, dim
|
||||
),
|
||||
FuseOp::Dequantize {
|
||||
values,
|
||||
params,
|
||||
output,
|
||||
scheme: _,
|
||||
} => write!(
|
||||
f,
|
||||
"{} = dequantize(values={}, params={})",
|
||||
output, values, params
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
CubeType, CubeLaunch, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord,
|
||||
)]
|
||||
pub struct QuantSchemeFuse {
|
||||
#[cube(comptime)]
|
||||
pub(crate) scheme: QuantScheme,
|
||||
}
|
||||
|
||||
impl FuseOp {
|
||||
/// Element type used for the computation.
|
||||
pub(crate) fn cmp_elem(&self) -> ElemType {
|
||||
match self {
|
||||
FuseOp::Add(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Sub(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Mul(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Div(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Powf(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Abs(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Exp(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Log(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Log1p(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Cos(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Sin(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Tanh(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Erf(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Recip(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Sqrt(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Assign(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Equal(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Lower(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::Greater(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::LowerEqual(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::GreaterEqual(op) => op.lhs.precision().into_elem(),
|
||||
FuseOp::ConditionalAssign { out, .. } => out.precision().into_elem(),
|
||||
FuseOp::Gather { output, .. } => output.precision().into_elem(),
|
||||
FuseOp::Select { output, .. } => output.precision().into_elem(),
|
||||
FuseOp::Dequantize { output, .. } => output.precision().into_elem(),
|
||||
FuseOp::Rem(op) => op.out.precision().into_elem(),
|
||||
FuseOp::Clamp { out, .. } => out.precision().into_elem(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Element type used for the computation.
|
||||
pub(crate) fn cmp_type(&self) -> StorageType {
|
||||
self.cmp_elem().into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(CubeType, CubeLaunch, Default, Clone)]
|
||||
/// Global arguments that are used for fusing [element wise operations](ElemTypewiseOp).
|
||||
pub struct GlobalArgs {
|
||||
/// Tensors that are stored in global memory.
|
||||
pub tensors: Sequence<GlobalTensor>,
|
||||
/// Scalars that are stored in global memory.
|
||||
pub scalars: Sequence<InputScalar>,
|
||||
/// To be used to perform reshape inside a fused kernel.
|
||||
pub reshapes: Sequence<usize>,
|
||||
/// When there are no metadata as a reference layout, we provide runtime shape/strides in this
|
||||
/// sequence instead.
|
||||
pub runtime_layouts: Sequence<usize>,
|
||||
/// Variables shared between blocks.
|
||||
pub variables: MultiBlockVariables,
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> GlobalArgsLaunch<'a, R> {
|
||||
pub fn required_address_type(&self) -> AddressType {
|
||||
self.tensors
|
||||
.values
|
||||
.iter()
|
||||
.map(|it| it.address_type)
|
||||
.max()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Variables shared between blocks.
|
||||
#[derive(CubeType, Default, Clone)]
|
||||
pub struct MultiBlockVariables {
|
||||
variables: Registry<usize, Registry<usize, RuntimeCell<Line<NumericExpand<DYN_ELEM_ID>>>>>,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl MultiBlockVariables {
|
||||
/// Initializes the variable with the given key and line size.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The type of [`NumericExpand<DYN_ELEM_ID>`] must be set before calling this function.
|
||||
pub fn init(&mut self, #[comptime] key: MultiBlockPos, #[comptime] line_size: usize) {
|
||||
let mut registers = Registry::<
|
||||
usize,
|
||||
Registry<usize, RuntimeCell<Line<NumericExpand<DYN_ELEM_ID>>>>,
|
||||
>::find_or_default::<usize>(&mut self.variables, key.block_pos);
|
||||
let cell = RuntimeCell::new(Line::empty(line_size));
|
||||
registers.insert(key.block_local_pos, cell);
|
||||
}
|
||||
|
||||
/// Read the variable using the provided key.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The variable must be initialized.
|
||||
pub fn read(&self, #[comptime] key: MultiBlockPos) -> Line<NumericExpand<DYN_ELEM_ID>> {
|
||||
let registers = self.variables.find(key.block_pos);
|
||||
let cell = registers.find(key.block_local_pos);
|
||||
cell.read()
|
||||
}
|
||||
|
||||
/// Write to the variable using the provided key and value.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The variable must be initialized.
|
||||
pub fn write(
|
||||
&mut self,
|
||||
#[comptime] key: MultiBlockPos,
|
||||
value: Line<NumericExpand<DYN_ELEM_ID>>,
|
||||
) {
|
||||
let registers = self.variables.find(key.block_pos);
|
||||
// Try find for local(visibility) registers.
|
||||
let cell = registers.find(key.block_local_pos);
|
||||
cell.store(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Because we only create it DURING compilation, not as a real launch arg.
|
||||
unsafe impl Send for MultiBlockVariables {}
|
||||
unsafe impl Sync for MultiBlockVariables {}
|
||||
|
||||
impl LaunchArg for MultiBlockVariables {
|
||||
type RuntimeArg<'a, R: Runtime> = ();
|
||||
type CompilationArg = ();
|
||||
|
||||
fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
|
||||
}
|
||||
|
||||
fn expand(
|
||||
_arg: &Self::CompilationArg,
|
||||
_builder: &mut KernelBuilder,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
MultiBlockVariablesExpand {
|
||||
variables: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> Default for GlobalArgsLaunch<'_, R> {
    /// Creates launch arguments where every field holds its default value.
    fn default() -> Self {
        Self {
            tensors: Default::default(),
            scalars: Default::default(),
            reshapes: Default::default(),
            variables: Default::default(),
            runtime_layouts: Default::default(),
            _phantom_runtime: std::marker::PhantomData,
            _phantom_a: std::marker::PhantomData,
        }
    }
}
|
||||
|
||||
impl<R: Runtime> core::fmt::Debug for GlobalArgsLaunch<'_, R> {
    /// Only the tensors are included in the debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "({:?})", self.tensors.values)
    }
}
|
||||
|
||||
impl<R: Runtime> GlobalArgsLaunch<'_, R> {
|
||||
/// Get the shape of the given [argument](Arg).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the argument doesn't have an handle.
|
||||
pub fn shape(&self, arg: &FuseArg) -> Shape {
|
||||
match self.resolve_arg(arg) {
|
||||
TensorArg::Handle { handle, .. } => handle.shape.into(),
|
||||
TensorArg::Alias { .. } => panic!("Unsupported yet"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Shape used by the reference tensor.
|
||||
pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Shape {
|
||||
match ref_layout {
|
||||
RefLayout::Concrete(arg) => self.shape(arg),
|
||||
RefLayout::Virtual(layout) => match layout {
|
||||
VirtualLayout::SwapDims(original, dims) => {
|
||||
let mut shape = self.shape(original);
|
||||
shape.swap(dims.0, dims.1);
|
||||
shape
|
||||
}
|
||||
VirtualLayout::Reshaped { reshape_pos, .. } => {
|
||||
let start = *reshape_pos * rank;
|
||||
let end = start + rank;
|
||||
self.reshapes.values[start..end]
|
||||
.iter()
|
||||
.map(|s| s.elem)
|
||||
.collect()
|
||||
}
|
||||
VirtualLayout::Shape(original, _) => self.shape(original),
|
||||
VirtualLayout::Runtime { pos } => {
|
||||
let start = (*pos * 2) * rank;
|
||||
let end = start + rank;
|
||||
self.runtime_layouts.values[start..end]
|
||||
.iter()
|
||||
.map(|s| s.elem)
|
||||
.collect()
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the strides of the given [argument](Arg).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the argument doesn't have an handle.
|
||||
pub fn strides(&self, arg: &FuseArg) -> Strides {
|
||||
match self.resolve_arg(arg) {
|
||||
TensorArg::Handle { handle, .. } => handle.strides.into(),
|
||||
TensorArg::Alias { .. } => panic!("Unsupported yet"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Strides {
|
||||
match ref_layout {
|
||||
RefLayout::Concrete(arg) => self.strides(arg),
|
||||
// When not concrete, we operate on the contiguous layout.
|
||||
_ => {
|
||||
let shape = self.shape_ref(ref_layout, rank);
|
||||
let mut strides = strides![0; shape.len()];
|
||||
|
||||
let mut current = 1;
|
||||
shape.iter().enumerate().rev().for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
});
|
||||
|
||||
strides
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the line size of the given [argument](Arg).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the argument doesn't have an handle.
|
||||
pub fn line_size(&self, arg: &FuseArg) -> LineSize {
|
||||
match self.resolve_arg(arg) {
|
||||
TensorArg::Handle { line_size, .. } => *line_size,
|
||||
TensorArg::Alias { .. } => panic!("Unsupported yet"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the [argument](Arg) to a [tensor argument](TensorArg).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the argument isn't a global input or output tensor.
|
||||
pub fn resolve_arg(&self, arg: &FuseArg) -> &TensorArg<'_, R> {
|
||||
match arg {
|
||||
FuseArg::Input(pos, _, _) => &self.tensors.values[*pos].tensor,
|
||||
FuseArg::Output(pos, _, _) => &self.tensors.values[*pos].tensor,
|
||||
other => panic!("Arg not found: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(CubeType, Clone)]
|
||||
/// Keep track of all local variables that are used as argument in fused
|
||||
/// [element wise operations](ElemwiseOp).
|
||||
pub struct LocalArgs {
|
||||
pub l_f64: Registry<usize, Line<f64>>,
|
||||
pub l_f32: Registry<usize, Line<f32>>,
|
||||
pub l_f16: Registry<usize, Line<f16>>,
|
||||
pub l_bf16: Registry<usize, Line<bf16>>,
|
||||
pub l_i64: Registry<usize, Line<i64>>,
|
||||
pub l_i32: Registry<usize, Line<i32>>,
|
||||
pub l_i16: Registry<usize, Line<i16>>,
|
||||
pub l_i8: Registry<usize, Line<i8>>,
|
||||
pub l_u64: Registry<usize, Line<u64>>,
|
||||
pub l_u32: Registry<usize, Line<u32>>,
|
||||
pub l_u16: Registry<usize, Line<u16>>,
|
||||
pub l_u8: Registry<usize, Line<u8>>,
|
||||
pub l_bool: Registry<usize, Line<bool>>,
|
||||
pub ref_shape: Slice<usize>,
|
||||
pub ref_strides: Slice<usize>,
|
||||
#[cube(comptime)]
|
||||
pub ref_line_size: LineSize,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl LocalArgs {
|
||||
/// Creates a new [LocalArgs] container.
|
||||
pub fn new(
|
||||
ref_shape: Slice<usize>,
|
||||
ref_strides: Slice<usize>,
|
||||
#[comptime] ref_line_size: LineSize,
|
||||
) -> LocalArgs {
|
||||
LocalArgs {
|
||||
l_f64: Registry::<usize, Line<f64>>::new(),
|
||||
l_f32: Registry::<usize, Line<f32>>::new(),
|
||||
l_f16: Registry::<usize, Line<f16>>::new(),
|
||||
l_bf16: Registry::<usize, Line<bf16>>::new(),
|
||||
l_i64: Registry::<usize, Line<i64>>::new(),
|
||||
l_i32: Registry::<usize, Line<i32>>::new(),
|
||||
l_i16: Registry::<usize, Line<i16>>::new(),
|
||||
l_i8: Registry::<usize, Line<i8>>::new(),
|
||||
l_u64: Registry::<usize, Line<u64>>::new(),
|
||||
l_u32: Registry::<usize, Line<u32>>::new(),
|
||||
l_u16: Registry::<usize, Line<u16>>::new(),
|
||||
l_u8: Registry::<usize, Line<u8>>::new(),
|
||||
l_bool: Registry::<usize, Line<bool>>::new(),
|
||||
ref_shape,
|
||||
ref_strides,
|
||||
ref_line_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// Unary [element wise operation](ElemwiseOp) arguments.
|
||||
pub struct UnaryFuseArgs {
|
||||
pub input: FuseArg,
|
||||
pub out: FuseArg,
|
||||
}
|
||||
|
||||
#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// Binary [element wise operation](ElemwiseOp) arguments.
|
||||
pub struct BinaryFuseArgs {
|
||||
pub lhs: FuseArg,
|
||||
pub rhs: FuseArg,
|
||||
pub out: FuseArg,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
|
||||
)]
|
||||
/// Precisions supported by [element wise operations](ElemwiseOp).
|
||||
///
|
||||
/// This is a custom type instead of [ElemType] so it can implement [CubeType]
|
||||
/// and restricts the supported types for fusion.
|
||||
pub enum FuseType {
|
||||
F64,
|
||||
F32,
|
||||
Flex32,
|
||||
F16,
|
||||
BF16,
|
||||
I64,
|
||||
I32,
|
||||
I16,
|
||||
I8,
|
||||
U64,
|
||||
U32,
|
||||
U16,
|
||||
U8,
|
||||
Bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// Configuration that encapsulates all comptime information necessary for element wise fusion.
|
||||
pub struct FuseBlockConfig {
|
||||
pub rank: usize,
|
||||
pub ref_layout: RefLayout,
|
||||
pub ops: Vec<FuseOp>,
|
||||
pub width: LineSize,
|
||||
}
|
||||
|
||||
impl FuseBlockConfig {
|
||||
pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {
|
||||
for op in self.ops.iter() {
|
||||
op.multi_block_variables(registers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FuseArg {
|
||||
pub fn multi_block_variable(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {
|
||||
match self {
|
||||
FuseArg::MultiBlockGlobal(arg, fuse_type)
|
||||
// TODO: we need to init the multi-block local, but at some point we could avoid
|
||||
// that for performance (easier for the underlying compiler).
|
||||
| FuseArg::MultiBlockLocal(arg, fuse_type) => {
|
||||
registers.push((arg.clone(), fuse_type.into_type()))
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl FuseOp {
|
||||
pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {
|
||||
match self {
|
||||
FuseOp::Add(binary_fuse_args)
|
||||
| FuseOp::Sub(binary_fuse_args)
|
||||
| FuseOp::Mul(binary_fuse_args)
|
||||
| FuseOp::Div(binary_fuse_args)
|
||||
| FuseOp::Powf(binary_fuse_args)
|
||||
| FuseOp::Equal(binary_fuse_args)
|
||||
| FuseOp::Lower(binary_fuse_args)
|
||||
| FuseOp::Greater(binary_fuse_args)
|
||||
| FuseOp::LowerEqual(binary_fuse_args)
|
||||
| FuseOp::Rem(binary_fuse_args)
|
||||
| FuseOp::GreaterEqual(binary_fuse_args) => {
|
||||
binary_fuse_args.lhs.multi_block_variable(registers);
|
||||
binary_fuse_args.rhs.multi_block_variable(registers);
|
||||
binary_fuse_args.out.multi_block_variable(registers);
|
||||
}
|
||||
FuseOp::Abs(unary_fuse_args)
|
||||
| FuseOp::Exp(unary_fuse_args)
|
||||
| FuseOp::Log(unary_fuse_args)
|
||||
| FuseOp::Log1p(unary_fuse_args)
|
||||
| FuseOp::Cos(unary_fuse_args)
|
||||
| FuseOp::Sin(unary_fuse_args)
|
||||
| FuseOp::Tanh(unary_fuse_args)
|
||||
| FuseOp::Erf(unary_fuse_args)
|
||||
| FuseOp::Sqrt(unary_fuse_args)
|
||||
| FuseOp::Recip(unary_fuse_args)
|
||||
| FuseOp::Assign(unary_fuse_args) => {
|
||||
unary_fuse_args.input.multi_block_variable(registers);
|
||||
unary_fuse_args.out.multi_block_variable(registers);
|
||||
}
|
||||
FuseOp::Clamp {
|
||||
input,
|
||||
min,
|
||||
max,
|
||||
out,
|
||||
} => {
|
||||
input.multi_block_variable(registers);
|
||||
min.multi_block_variable(registers);
|
||||
max.multi_block_variable(registers);
|
||||
out.multi_block_variable(registers);
|
||||
}
|
||||
FuseOp::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
} => {
|
||||
cond.multi_block_variable(registers);
|
||||
lhs.multi_block_variable(registers);
|
||||
rhs.multi_block_variable(registers);
|
||||
out.multi_block_variable(registers);
|
||||
}
|
||||
FuseOp::Gather {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim: _,
|
||||
} => {
|
||||
input.multi_block_variable(registers);
|
||||
indices.multi_block_variable(registers);
|
||||
output.multi_block_variable(registers);
|
||||
}
|
||||
FuseOp::Select {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim: _,
|
||||
} => {
|
||||
input.multi_block_variable(registers);
|
||||
indices.multi_block_variable(registers);
|
||||
output.multi_block_variable(registers);
|
||||
}
|
||||
FuseOp::Dequantize {
|
||||
values,
|
||||
params,
|
||||
output,
|
||||
scheme: _,
|
||||
} => {
|
||||
values.multi_block_variable(registers);
|
||||
params.multi_block_variable(registers);
|
||||
output.multi_block_variable(registers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Initializes block variables, both globals and locals.
|
||||
pub fn multi_block_variables_init(
|
||||
#[comptime] block: &FuseBlockConfig,
|
||||
variables: &mut MultiBlockVariables,
|
||||
) {
|
||||
let output = comptime! {
|
||||
let mut output = Vec::<(MultiBlockPos, StorageType)>::new();
|
||||
block.multi_block_variables(&mut output);
|
||||
output
|
||||
};
|
||||
|
||||
#[unroll]
|
||||
for i in 0..comptime!(output.len()) {
|
||||
let (key, dtype) = comptime!(output.get(i).unwrap().clone());
|
||||
set_polyfill::<NumericExpand<DYN_ELEM_ID>>(dtype);
|
||||
variables.init(key, block.width);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// A reference layout determines how a fuse execution will access elements in tensors.
|
||||
///
|
||||
/// It can either follow the same layout as a concrete tensor, or follow a virtual layout.
|
||||
pub enum RefLayout {
|
||||
Concrete(FuseArg),
|
||||
Virtual(VirtualLayout),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// A virtual layout is always contiguous and retrieves its shape from either a reshaped tensor or a
|
||||
/// tensor with swap dimensions.
|
||||
pub enum VirtualLayout {
|
||||
/// Virtual tensor with the provided shape id and contiguous strides.
|
||||
Reshaped {
|
||||
reshape_pos: usize,
|
||||
line_size: LineSize,
|
||||
},
|
||||
/// Virtual tensor with the same shape as the given input, but with swap dims and contiguous
|
||||
/// strides.
|
||||
SwapDims(FuseArg, (usize, usize)),
|
||||
/// Virtual tensor with the same shape as the given input, but with contiguous strides.
|
||||
Shape(FuseArg, usize),
|
||||
/// We don't have access to global metadata, they are passed as runtime values.
|
||||
Runtime { pos: usize },
|
||||
}
|
||||
|
||||
impl FuseArg {
|
||||
/// Adds layout information.
|
||||
///
|
||||
/// It's going to impact how the input or output is read and written to.
|
||||
pub fn add_layout_info(&mut self, layout: LayoutInfo) {
|
||||
match self {
|
||||
FuseArg::Input(_, _, old) => {
|
||||
*old = layout;
|
||||
}
|
||||
FuseArg::Output(_, _, old) => {
|
||||
*old = layout;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows a [FuseArg] to be used as a registry query for keys of its own type.
impl RegistryQuery<Self> for FuseArg {}
|
||||
|
||||
impl From<ElemType> for FuseType {
|
||||
fn from(value: ElemType) -> Self {
|
||||
match value {
|
||||
ElemType::Float(kind) => match kind {
|
||||
FloatKind::F16 => Self::F16,
|
||||
FloatKind::BF16 => Self::BF16,
|
||||
FloatKind::F32 => Self::F32,
|
||||
FloatKind::Flex32 => Self::Flex32,
|
||||
_ => panic!("Unsupported precision for fusion: {value}"),
|
||||
},
|
||||
ElemType::Int(kind) => match kind {
|
||||
IntKind::I64 => Self::I64,
|
||||
IntKind::I32 => Self::I32,
|
||||
IntKind::I16 => Self::I16,
|
||||
IntKind::I8 => Self::I8,
|
||||
},
|
||||
ElemType::UInt(kind) => match kind {
|
||||
UIntKind::U64 => Self::U64,
|
||||
UIntKind::U32 => Self::U32,
|
||||
UIntKind::U16 => Self::U16,
|
||||
UIntKind::U8 => Self::U8,
|
||||
},
|
||||
ElemType::Bool => Self::Bool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StorageType> for FuseType {
    /// Converts through the element type of the storage type.
    fn from(value: StorageType) -> Self {
        value.elem_type().into()
    }
}
|
||||
|
||||
impl FuseType {
|
||||
/// Converts the [fused element type](FuseType) into the [cubecl element type](ElemType).
|
||||
pub fn into_elem(self) -> ElemType {
|
||||
match self {
|
||||
FuseType::F32 => ElemType::Float(FloatKind::F32),
|
||||
FuseType::Flex32 => ElemType::Float(FloatKind::Flex32),
|
||||
FuseType::F16 => ElemType::Float(FloatKind::F16),
|
||||
FuseType::BF16 => ElemType::Float(FloatKind::BF16),
|
||||
FuseType::I64 => ElemType::Int(IntKind::I64),
|
||||
FuseType::I32 => ElemType::Int(IntKind::I32),
|
||||
FuseType::I16 => ElemType::Int(IntKind::I16),
|
||||
FuseType::I8 => ElemType::Int(IntKind::I8),
|
||||
FuseType::U64 => ElemType::UInt(UIntKind::U64),
|
||||
FuseType::U32 => ElemType::UInt(UIntKind::U32),
|
||||
FuseType::U16 => ElemType::UInt(UIntKind::U16),
|
||||
FuseType::U8 => ElemType::UInt(UIntKind::U8),
|
||||
FuseType::Bool => ElemType::Bool,
|
||||
FuseType::F64 => ElemType::Float(FloatKind::F64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the [fused element type](FuseType) into the [cubecl storage type](StorageType).
|
||||
pub fn into_type(self) -> StorageType {
|
||||
self.into_elem().into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DType> for FuseType {
|
||||
fn from(value: DType) -> Self {
|
||||
match value {
|
||||
DType::F32 => Self::F32,
|
||||
DType::Flex32 => Self::Flex32,
|
||||
DType::F16 => Self::F16,
|
||||
DType::BF16 => Self::BF16,
|
||||
DType::I64 => Self::I64,
|
||||
DType::I32 => Self::I32,
|
||||
DType::I16 => Self::I16,
|
||||
DType::I8 => Self::I8,
|
||||
DType::U64 => Self::U64,
|
||||
DType::U32 => Self::U32,
|
||||
DType::U16 => Self::U16,
|
||||
DType::U8 => Self::U8,
|
||||
DType::Bool => Self::Bool,
|
||||
DType::F64 => Self::F64,
|
||||
DType::QFloat(scheme) => match scheme.store {
|
||||
QuantStore::Native => match scheme.value {
|
||||
QuantValue::Q8F | QuantValue::Q8S => Self::I8,
|
||||
QuantValue::E4M3 | QuantValue::E5M2 => {
|
||||
unimplemented!("Unsupported precision for fusion")
|
||||
}
|
||||
QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S
|
||||
| QuantValue::E2M1 => {
|
||||
panic!("Can't store native sub-byte values")
|
||||
}
|
||||
},
|
||||
QuantStore::PackedU32(_) => Self::U32,
|
||||
QuantStore::PackedNative(_) => match scheme.value {
|
||||
QuantValue::E2M1 => unimplemented!("Unsupported precision for fusion"),
|
||||
other => panic!("{other:?} doesn't support native packing"),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for FuseArg {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
FuseArg::Input(pos, ..) => write!(f, "input({pos})"),
|
||||
FuseArg::Output(pos, ..) => write!(f, "output({pos})"),
|
||||
FuseArg::BlockLocal { pos, ty } => write!(f, "local({pos}, {ty:?})"),
|
||||
FuseArg::MultiBlockLocal(mbp, ..) => write!(f, "{mbp}"),
|
||||
FuseArg::MultiBlockGlobal(mbp, ..) => write!(f, "global_{mbp}"),
|
||||
FuseArg::Scalar(pos, ..) => write!(f, "scalar({pos})"),
|
||||
FuseArg::ScalarShape(pos) => write!(f, "scalar_shape({pos})"),
|
||||
FuseArg::Literal(val, ..) => write!(f, "literal_{val}"),
|
||||
FuseArg::InputReshaped { original, .. } => write!(f, "input_reshaped_{original}"),
|
||||
FuseArg::InputSwapDims { original, .. } => write!(f, "input_swap_dims_{original}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MultiBlockPos {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"block_local({}-{})",
|
||||
self.block_pos, self.block_local_pos
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,927 @@
|
||||
use super::{DYN_ELEM_ID, Q_PARAM_DYN_ELEM_ID, Q_STORE_DYN_ELEM_ID, io::*, ir::*};
|
||||
use burn_std::quantization::{QuantScheme, QuantStore, QuantValue};
|
||||
use cubecl::{
|
||||
ir::{ElemType, FloatKind, StorageType, UIntKind},
|
||||
prelude::*,
|
||||
};
|
||||
use cubek::quantization::{dequantize::dequantize_symmetric_packed_value_at, scheme::QuantMode};
|
||||
|
||||
#[cube]
|
||||
/// Fuse element-wise operations at the given write position.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `inputs`: Contains all readonly global kernel arguments.
|
||||
/// - `outputs`: Contains all readwrite global kernel arguments.
|
||||
/// - `locals`: Contains all local variables defined during kernel expansion.
|
||||
/// - `write_pos`: The logical position the values are written to.
|
||||
/// - `write_values`: The explicit values to write at the given position.
|
||||
/// - `write_args`: The arguments associated to the `writes_values`.
|
||||
/// - `config`: The current [fuse block configuration](FuseBlockConfig).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The function will start by writing `write_values`.
|
||||
pub fn fuse_on_write<E: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
write_values: Registry<FuseArg, Line<E>>,
|
||||
#[comptime] write_args: Vec<FuseArg>,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
comment!("Fuse on write begin");
|
||||
// Write the values given as arguments.
|
||||
#[unroll]
|
||||
for i in 0..write_args.len() {
|
||||
let arg = comptime![write_args.get(i).unwrap().clone()];
|
||||
let val = write_values.find(arg.clone());
|
||||
|
||||
write::<E>(inputs, outputs, locals, write_pos, val, arg, config);
|
||||
}
|
||||
|
||||
fuse(inputs, outputs, locals, write_pos, config);
|
||||
comment!("Fuse on write end");
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Fuse element-wise operations at the given read position.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `inputs`: Contains all readonly global kernel arguments.
|
||||
/// - `outputs`: Contains all readwrite global kernel arguments.
|
||||
/// - `locals`: Contains all local variables defined during kernel expansion.
|
||||
/// - `read_pos`: The logical position the values are read from.
|
||||
/// - `read_args`: The arguments associated to the `read_pos`.
|
||||
/// - `config`: The current [fuse block configuration](FuseBlockConfig).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - A sequence of values associated to the given `read_args`.
|
||||
pub fn fuse_on_read<E: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
read_pos: usize,
|
||||
#[comptime] read_args: Sequence<FuseArg>,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> Sequence<Line<E>> {
|
||||
comment!("Fuse on read begin");
|
||||
fuse(inputs, outputs, locals, read_pos, config);
|
||||
|
||||
let mut output = Sequence::new();
|
||||
|
||||
#[unroll]
|
||||
for i in 0..read_args.len() {
|
||||
let arg = comptime![read_args.index(i).clone()];
|
||||
let value = read::<E>(inputs, outputs, locals, read_pos, arg, config);
|
||||
|
||||
let value_line_size = value.line_size();
|
||||
let output_line_size = config.width;
|
||||
|
||||
// We currently don't support broadcasting __across__ blocks.
|
||||
if comptime!(value_line_size != output_line_size) {
|
||||
let mut tmp = Line::<E>::empty(config.width);
|
||||
comptime!(
|
||||
assert_eq!(value_line_size, 1, "The input line_size must be 1 or the same as the config width.");
|
||||
);
|
||||
|
||||
let val = value[0];
|
||||
|
||||
#[unroll]
|
||||
for i in 0..config.width {
|
||||
tmp[i] = val;
|
||||
}
|
||||
|
||||
output.push(tmp);
|
||||
} else {
|
||||
output.push(value);
|
||||
}
|
||||
}
|
||||
|
||||
comment!("Fuse on read end");
|
||||
output
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Initializes [LocalArgs] given the input and output [arguments](GlobalArgs) with the [FuseBlockConfig].
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The goal is to resolve and cache the reference shape and strides, as it is used in many
|
||||
/// different function during kernel expansion.
|
||||
pub fn init_locals(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) -> LocalArgs {
|
||||
comment!("Init locals begin");
|
||||
let mut ref_shape = Array::new(config.rank);
|
||||
let mut ref_strides = Array::new(config.rank);
|
||||
|
||||
let locals = match config.ref_layout.clone() {
|
||||
RefLayout::Concrete(arg) => match comptime![arg] {
|
||||
FuseArg::Input(index, ..) => {
|
||||
let layout = inputs.tensors.index(index);
|
||||
|
||||
#[unroll]
|
||||
for i in 0..config.rank {
|
||||
ref_shape[i] = layout.tensor.shape(i);
|
||||
ref_strides[i] = layout.tensor.stride(i);
|
||||
}
|
||||
|
||||
LocalArgs::new(
|
||||
ref_shape.to_slice(),
|
||||
ref_strides.to_slice(),
|
||||
layout.tensor.line_size(),
|
||||
)
|
||||
}
|
||||
FuseArg::Output(index, ..) => {
|
||||
let layout = outputs.tensors.index(index);
|
||||
|
||||
#[unroll]
|
||||
for i in 0..config.rank {
|
||||
ref_shape[i] = layout.tensor.shape(i);
|
||||
ref_strides[i] = layout.tensor.stride(i);
|
||||
}
|
||||
|
||||
LocalArgs::new(
|
||||
ref_shape.to_slice(),
|
||||
ref_strides.to_slice(),
|
||||
layout.tensor.line_size(),
|
||||
)
|
||||
}
|
||||
_ => comptime![panic!("Invalid concrete ref layout.")],
|
||||
},
|
||||
RefLayout::Virtual(layout) => match layout {
|
||||
VirtualLayout::SwapDims(original, dims) => {
|
||||
let layout = match original.clone() {
|
||||
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
|
||||
FuseArg::Output(pos, ..) => outputs.tensors.index(pos),
|
||||
_ => comptime![panic!("Unsupported")],
|
||||
};
|
||||
|
||||
let mut stride_curr = 1;
|
||||
|
||||
#[unroll]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
for i in 0..config.rank {
|
||||
let reverse = reverse_index(config.rank, i);
|
||||
let swap = comptime![swap_dims_transform(reverse, dims)];
|
||||
let shape = layout.tensor.shape(swap.clone());
|
||||
|
||||
ref_shape[reverse] = shape;
|
||||
ref_strides[reverse] = stride_curr;
|
||||
|
||||
stride_curr *= ref_shape[comptime![reverse]];
|
||||
}
|
||||
|
||||
LocalArgs::new(
|
||||
ref_shape.to_slice(),
|
||||
ref_strides.to_slice(),
|
||||
layout.tensor.line_size(),
|
||||
)
|
||||
}
|
||||
VirtualLayout::Reshaped {
|
||||
reshape_pos,
|
||||
line_size,
|
||||
} => {
|
||||
let mut stride_curr = 1;
|
||||
let start = reshape_pos * config.rank;
|
||||
|
||||
#[unroll]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
for i in 0..config.rank {
|
||||
let reverse = reverse_index(config.rank, i);
|
||||
let arg = comptime![FuseArg::ScalarShape(start + reverse)];
|
||||
let shape = read_scalar_shape(inputs, arg.clone());
|
||||
|
||||
ref_shape[comptime![reverse]] = shape;
|
||||
ref_strides[comptime![reverse]] = stride_curr;
|
||||
|
||||
stride_curr *= ref_shape[comptime![reverse]];
|
||||
}
|
||||
|
||||
LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), line_size)
|
||||
}
|
||||
VirtualLayout::Runtime { pos } => {
|
||||
let start_shape = (pos * 2) * config.rank;
|
||||
let start_strides = start_shape + config.rank;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..config.rank {
|
||||
let shape_index = start_shape + i;
|
||||
let strides_index = start_strides + i;
|
||||
|
||||
ref_shape[i] = *inputs.runtime_layouts.index(shape_index);
|
||||
ref_strides[i] = *inputs.runtime_layouts.index(strides_index);
|
||||
}
|
||||
|
||||
LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), config.width)
|
||||
}
|
||||
VirtualLayout::Shape(original, line_size) => {
|
||||
let layout = match original.clone() {
|
||||
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
|
||||
FuseArg::Output(pos, ..) => outputs.tensors.index(pos),
|
||||
_ => comptime![panic!("Unsupported")],
|
||||
};
|
||||
let mut stride_curr = 1;
|
||||
|
||||
#[unroll]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
for i in 0..config.rank {
|
||||
let reverse = reverse_index(config.rank, i);
|
||||
let shape = layout.tensor.shape(reverse);
|
||||
|
||||
ref_shape[comptime![reverse]] = shape;
|
||||
ref_strides[comptime![reverse]] = stride_curr;
|
||||
|
||||
stride_curr *= ref_shape[comptime![reverse]];
|
||||
}
|
||||
|
||||
LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), line_size)
|
||||
}
|
||||
},
|
||||
};
|
||||
comment!("Init locals end");
|
||||
locals
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Expands all [operations](FuseOp) registered in the [block config](FuseBlockConfig].
|
||||
fn fuse(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
pos: usize,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
#[unroll]
|
||||
for index in 0..config.ops.len() {
|
||||
let op = config.ops[index].clone();
|
||||
set_polyfill::<NumericExpand<DYN_ELEM_ID>>(op.cmp_type());
|
||||
|
||||
match op {
|
||||
FuseOp::Add(op) => {
|
||||
add::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Div(op) => {
|
||||
div::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Sub(op) => {
|
||||
sub::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Mul(op) => {
|
||||
mul::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Powf(op) => {
|
||||
powf::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Erf(op) => {
|
||||
erf::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Sqrt(op) => {
|
||||
sqrt::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Abs(op) => {
|
||||
abs::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Log(op) => {
|
||||
log::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Log1p(op) => {
|
||||
log1p::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Recip(op) => {
|
||||
recip::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Assign(op) => {
|
||||
assign::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Exp(op) => {
|
||||
exp::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Cos(op) => {
|
||||
cos::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Sin(op) => {
|
||||
sin::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Tanh(op) => {
|
||||
tanh::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Equal(op) => {
|
||||
equal::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Greater(op) => {
|
||||
greater::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::GreaterEqual(op) => greater_equal::<NumericExpand<DYN_ELEM_ID>>(
|
||||
inputs, outputs, locals, pos, op, config,
|
||||
),
|
||||
FuseOp::Lower(op) => {
|
||||
lower::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::LowerEqual(op) => {
|
||||
lower_equal::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
} => conditional_assign::<NumericExpand<DYN_ELEM_ID>>(
|
||||
inputs, outputs, locals, pos, cond, lhs, rhs, out, config,
|
||||
),
|
||||
FuseOp::Gather {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim,
|
||||
} => gather::<NumericExpand<DYN_ELEM_ID>>(
|
||||
inputs, outputs, locals, pos, dim, input, indices, output, config,
|
||||
),
|
||||
FuseOp::Select {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim,
|
||||
} => select_indices::<NumericExpand<DYN_ELEM_ID>>(
|
||||
inputs, outputs, locals, pos, dim, input, indices, output, config,
|
||||
),
|
||||
FuseOp::Dequantize {
|
||||
values,
|
||||
params,
|
||||
output,
|
||||
scheme,
|
||||
} => dequantize::<NumericExpand<DYN_ELEM_ID>>(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
pos,
|
||||
values,
|
||||
params,
|
||||
output,
|
||||
scheme.scheme,
|
||||
config,
|
||||
),
|
||||
FuseOp::Rem(op) => {
|
||||
rem::<NumericExpand<DYN_ELEM_ID>>(inputs, outputs, locals, pos, op, config)
|
||||
}
|
||||
FuseOp::Clamp {
|
||||
input,
|
||||
min,
|
||||
max,
|
||||
out,
|
||||
} => clamp::<NumericExpand<DYN_ELEM_ID>>(
|
||||
inputs, outputs, locals, pos, input, min, max, out, config,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a cube function named `$name` that reads both operands of a binary operation,
/// combines them with the infix operator `$operator`, and writes the result to `op.out`.
macro_rules! binary_op {
    ($name:ident, $operator:tt) => {
        #[cube]
        fn $name<C: Numeric>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: BinaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let left = read::<C>(inputs, outputs, &locals, write_pos, op.lhs, config);
            let right = read::<C>(inputs, outputs, &locals, write_pos, op.rhs, config);
            let combined = left $operator right;

            write::<C>(inputs, outputs, locals, write_pos, combined, op.out, config);
        }
    };
}
|
||||
|
||||
/// Generates a cube function named `$name`, generic over `C: $bound`, that reads both operands
/// of a binary operation, applies the two-argument function `$apply` to them, and writes the
/// result to `op.out`.
macro_rules! binary_func {
    ($name:ident, $apply:expr, $bound:tt) => {
        #[cube]
        fn $name<C: $bound>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: BinaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let left = read::<C>(inputs, outputs, &locals, write_pos, op.lhs, config);
            let right = read::<C>(inputs, outputs, &locals, write_pos, op.rhs, config);
            let combined = $apply(left, right);

            write::<C>(inputs, outputs, locals, write_pos, combined, op.out, config);
        }
    };
}
|
||||
|
||||
/// Generates a cube function named `$name` that reads both operands of a comparison, compares
/// them with the infix operator `$operator`, and writes the outcome to `op.out` as `bool`.
macro_rules! comparison_op {
    ($name:ident, $operator:tt) => {
        #[cube]
        fn $name<C: CubePrimitive + core::cmp::PartialOrd>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: BinaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let left = read::<C>(inputs, outputs, &locals, write_pos, op.lhs, config);
            let right = read::<C>(inputs, outputs, &locals, write_pos, op.rhs, config);
            let outcome = Line::new(left $operator right);

            write::<bool>(inputs, outputs, locals, write_pos, outcome, op.out, config);
        }
    };
}
|
||||
|
||||
/// Generates a cube function named `$name`, generic over `C: $bound`, that reads the operand
/// of a unary operation, applies the one-argument function `$apply` to it, and writes the
/// result to `op.out`.
macro_rules! unary_func {
    ($name:ident, $apply:expr, $bound:tt) => {
        #[cube]
        fn $name<C: $bound>(
            inputs: &GlobalArgs,
            outputs: &mut GlobalArgs,
            locals: &mut LocalArgs,
            write_pos: usize,
            #[comptime] op: UnaryFuseArgs,
            #[comptime] config: &FuseBlockConfig,
        ) {
            let operand = read::<C>(inputs, outputs, &locals, write_pos, op.input, config);
            let mapped = $apply(operand);

            write::<C>(inputs, outputs, locals, write_pos, mapped, op.out, config);
        }
    };
}
|
||||
|
||||
#[cube]
|
||||
fn assign<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
#[comptime] op: UnaryFuseArgs,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
let input = read::<C>(inputs, outputs, locals, write_pos, op.input, config);
|
||||
|
||||
write::<C>(inputs, outputs, locals, write_pos, input, op.out, config);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn gather<C: Numeric>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
#[comptime] dim: usize,
|
||||
#[comptime] input: FuseArg,
|
||||
#[comptime] indices: FuseArg,
|
||||
#[comptime] output: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
let line_size = locals.ref_line_size;
|
||||
|
||||
let pos_input = comptime! {
|
||||
match input {
|
||||
FuseArg::Input(pos, ..) => pos,
|
||||
_ => panic!("Input tensor isn't an input"),
|
||||
}
|
||||
};
|
||||
let pos_indices = comptime! {
|
||||
match indices {
|
||||
FuseArg::Input(pos, ..) => pos,
|
||||
_ => panic!("Indices tensor isn't an input"),
|
||||
}
|
||||
};
|
||||
|
||||
let stride_input_dim = global_stride(inputs, dim, pos_input);
|
||||
|
||||
let mut index = 0;
|
||||
let mut result = Line::empty(line_size);
|
||||
|
||||
if comptime![dim > 0] {
|
||||
let index_before = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
input.clone(),
|
||||
comptime![Some((0, dim))],
|
||||
config,
|
||||
);
|
||||
index += index_before;
|
||||
}
|
||||
|
||||
if comptime![dim + 1 < config.rank] {
|
||||
let index_after = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
input,
|
||||
comptime![Some((dim + 1, config.rank))],
|
||||
config,
|
||||
);
|
||||
index += index_after;
|
||||
}
|
||||
|
||||
let index_offset = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
indices,
|
||||
comptime![Some((0, config.rank))],
|
||||
config,
|
||||
);
|
||||
|
||||
if comptime![dim == config.rank - 1] {
|
||||
// Per-element indexing (along the dimension)
|
||||
#[unroll]
|
||||
for i in 0..line_size {
|
||||
let offset = read_input::<u32>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_indices,
|
||||
index_offset + i,
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
let input = read_input::<C>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_input,
|
||||
index + (offset[0] as usize * stride_input_dim),
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
result[i] = input[0];
|
||||
}
|
||||
} else {
|
||||
// Shared index for whole line
|
||||
let stride_input_line = global_stride(inputs, config.rank - 1, pos_input);
|
||||
|
||||
let offset = read_input::<u32>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_indices,
|
||||
index_offset,
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
index += offset[0] as usize * stride_input_dim;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..line_size {
|
||||
let input = read_input::<C>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_input,
|
||||
index + i * stride_input_line,
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
result[i] = input[0];
|
||||
}
|
||||
}
|
||||
|
||||
write::<C>(inputs, outputs, locals, write_pos, result, output, config);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn select_indices<C: Numeric>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
#[comptime] dim: usize,
|
||||
#[comptime] input: FuseArg,
|
||||
#[comptime] indices: FuseArg,
|
||||
#[comptime] output: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
let (line_size_ref, stride_dim_ref, shape_dim_ref) = (
|
||||
locals.ref_line_size,
|
||||
locals.ref_strides[dim],
|
||||
locals.ref_shape[dim],
|
||||
);
|
||||
|
||||
let pos_input = comptime! {
|
||||
match input {
|
||||
FuseArg::Input(pos, ..) => pos,
|
||||
_ => panic!("Input tensor isn't an input"),
|
||||
}
|
||||
};
|
||||
let pos_indices = match indices {
|
||||
FuseArg::Input(pos, ..) => pos,
|
||||
_ => panic!("Indices tensor isn't an input"),
|
||||
};
|
||||
|
||||
let stride_input_dim = global_stride(inputs, dim, pos_input);
|
||||
|
||||
let mut index = 0;
|
||||
let mut result = Line::empty(line_size_ref);
|
||||
|
||||
if comptime![dim != config.rank - 1] {
|
||||
// In this scenario the select is actually broadcasted along the axis we're working on.
|
||||
//
|
||||
// Therefore the same indices are used to fetch multiple entries in the input tensor.
|
||||
|
||||
if comptime![dim > 0] {
|
||||
let index_before = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
input.clone(),
|
||||
comptime![Some((0, dim))],
|
||||
config,
|
||||
);
|
||||
index += index_before;
|
||||
}
|
||||
|
||||
if comptime![dim + 1 < config.rank] {
|
||||
let index_after = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
input.clone(),
|
||||
comptime![Some((dim + 1, config.rank))],
|
||||
config,
|
||||
);
|
||||
index += index_after;
|
||||
}
|
||||
|
||||
let stride_input_line = global_stride(inputs, comptime![config.rank - 1], pos_input);
|
||||
let write_pos_input = write_pos * line_size_ref;
|
||||
let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref;
|
||||
let offset_dim = read_input::<u32>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_indices,
|
||||
coordinate_dim,
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
index += offset_dim[0] as usize * stride_input_dim;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..line_size_ref {
|
||||
let input = read_input::<C>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_input,
|
||||
index + i * stride_input_line,
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
result[i] = input[0];
|
||||
}
|
||||
} else {
|
||||
// In this scenario the select is actually performed on the last dimension we're working on.
|
||||
//
|
||||
// Therefore we need to fetch multiple indices that correspond to different entries in the
|
||||
// input tensor.
|
||||
|
||||
if comptime![dim > 0] {
|
||||
let index_before = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
input.clone(),
|
||||
comptime![Some((0, dim))],
|
||||
config,
|
||||
);
|
||||
index += index_before;
|
||||
}
|
||||
|
||||
if comptime![dim + 1 < config.rank] {
|
||||
let index_after = global_offset(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
write_pos,
|
||||
input,
|
||||
comptime![Some((dim + 1, config.rank))],
|
||||
config,
|
||||
);
|
||||
index += index_after;
|
||||
}
|
||||
|
||||
let write_pos_indices = write_pos * line_size_ref;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..line_size_ref {
|
||||
let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref;
|
||||
let offset_dim = read_input::<u32>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_indices,
|
||||
coordinate_dim,
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
let input = read_input::<C>(
|
||||
inputs,
|
||||
locals,
|
||||
pos_input,
|
||||
index + (offset_dim[0] as usize * stride_input_dim),
|
||||
LayoutInfo::IsRef,
|
||||
config,
|
||||
None,
|
||||
);
|
||||
result[i] = input[0];
|
||||
}
|
||||
}
|
||||
|
||||
write::<C>(inputs, outputs, locals, write_pos, result, output, config);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn conditional_assign<C: CubePrimitive>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
#[comptime] cond: FuseArg,
|
||||
#[comptime] lhs: FuseArg,
|
||||
#[comptime] rhs: FuseArg,
|
||||
#[comptime] out: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
let cond = read::<bool>(inputs, outputs, locals, write_pos, cond, config);
|
||||
let lhs = read::<C>(inputs, outputs, locals, write_pos, lhs, config);
|
||||
let rhs = read::<C>(inputs, outputs, locals, write_pos, rhs, config);
|
||||
let result = select_many(cond, lhs, rhs);
|
||||
|
||||
write::<C>(inputs, outputs, locals, write_pos, result, out, config);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn clamp<C: Numeric>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
#[comptime] input: FuseArg,
|
||||
#[comptime] min: FuseArg,
|
||||
#[comptime] max: FuseArg,
|
||||
#[comptime] out: FuseArg,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
let input = read::<C>(inputs, outputs, locals, write_pos, input, config);
|
||||
let min = read::<C>(inputs, outputs, locals, write_pos, min, config);
|
||||
let max = read::<C>(inputs, outputs, locals, write_pos, max, config);
|
||||
let result = cubecl::prelude::clamp(input, min, max);
|
||||
|
||||
write::<C>(inputs, outputs, locals, write_pos, result, out, config);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(clippy::explicit_counter_loop)]
|
||||
fn dequantize<C: Float>(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
write_pos: usize,
|
||||
#[comptime] input: FuseArg,
|
||||
#[comptime] scales: FuseArg,
|
||||
#[comptime] output: FuseArg,
|
||||
#[comptime] scheme: QuantScheme,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
comptime!(assert_eq!(
|
||||
scheme.mode,
|
||||
QuantMode::Symmetric,
|
||||
"Only symmetric quantization mode is supported."
|
||||
));
|
||||
|
||||
set_polyfill::<NumericExpand<Q_STORE_DYN_ELEM_ID>>(comptime![match scheme.store {
|
||||
QuantStore::Native => match scheme.value {
|
||||
QuantValue::Q8F | QuantValue::Q8S => StorageType::Scalar(ElemType::UInt(UIntKind::U8)),
|
||||
QuantValue::E4M3 => StorageType::Scalar(ElemType::Float(FloatKind::E4M3)),
|
||||
QuantValue::E5M2 => StorageType::Scalar(ElemType::Float(FloatKind::E5M2)),
|
||||
QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S
|
||||
| QuantValue::E2M1 => unreachable!("Can't store native sub-byte values"),
|
||||
},
|
||||
QuantStore::PackedU32(_) => ElemType::UInt(UIntKind::U32).into(),
|
||||
QuantStore::PackedNative(_) => match scheme.value {
|
||||
QuantValue::E2M1 => StorageType::Packed(ElemType::Float(FloatKind::E4M3), 2),
|
||||
other => panic!("{other:?} doesn't support native packing"),
|
||||
},
|
||||
}]);
|
||||
set_polyfill::<NumericExpand<Q_PARAM_DYN_ELEM_ID>>(comptime![match scheme.param {
|
||||
cubecl::quant::scheme::QuantParam::F32 =>
|
||||
StorageType::Scalar(ElemType::Float(FloatKind::F32)),
|
||||
cubecl::quant::scheme::QuantParam::F16 =>
|
||||
StorageType::Scalar(ElemType::Float(FloatKind::F16)),
|
||||
cubecl::quant::scheme::QuantParam::BF16 =>
|
||||
StorageType::Scalar(ElemType::Float(FloatKind::BF16)),
|
||||
cubecl::quant::scheme::QuantParam::UE8M0 =>
|
||||
StorageType::Scalar(ElemType::Float(FloatKind::UE8M0)),
|
||||
cubecl::quant::scheme::QuantParam::UE4M3 =>
|
||||
StorageType::Scalar(ElemType::Float(FloatKind::E4M3)),
|
||||
}]);
|
||||
|
||||
let tensor_pos = comptime!(match input {
|
||||
FuseArg::Input(pos, _, _) => pos,
|
||||
_ => panic!("Not supported"),
|
||||
});
|
||||
let pos = comptime!(match scales {
|
||||
FuseArg::Input(pos, ..) => pos,
|
||||
_ => unreachable!(""),
|
||||
});
|
||||
let input = read_quantized::<NumericExpand<Q_STORE_DYN_ELEM_ID>>(
|
||||
inputs, locals, write_pos, input, config, scheme,
|
||||
);
|
||||
|
||||
let line_size = input.line_size();
|
||||
let num_quants = scheme.num_quants();
|
||||
|
||||
let scales = input_as_scales_view::<NumericExpand<Q_PARAM_DYN_ELEM_ID>>(
|
||||
inputs,
|
||||
pos,
|
||||
tensor_pos,
|
||||
scheme.level,
|
||||
config,
|
||||
);
|
||||
let result = dequantize_symmetric_packed_value_at::<
|
||||
C,
|
||||
ElemExpand<Q_PARAM_DYN_ELEM_ID>,
|
||||
ElemExpand<Q_STORE_DYN_ELEM_ID>,
|
||||
>(write_pos * num_quants, input, &scales, scheme);
|
||||
|
||||
let line_size_result = comptime!(num_quants * line_size);
|
||||
|
||||
let line = if comptime!(line_size == 1) {
|
||||
result[0]
|
||||
} else {
|
||||
let mut line = Line::empty(line_size_result);
|
||||
|
||||
#[unroll]
|
||||
for i in 0..line_size {
|
||||
let value = result[i];
|
||||
|
||||
#[unroll]
|
||||
for j in 0..num_quants {
|
||||
let index = i * num_quants + j;
|
||||
line[index] = value[j];
|
||||
}
|
||||
}
|
||||
|
||||
line
|
||||
};
|
||||
|
||||
write::<C>(inputs, outputs, locals, write_pos, line, output, config);
|
||||
}
|
||||
|
||||
binary_op!(add, +);
|
||||
binary_op!(mul, *);
|
||||
binary_op!(div, /);
|
||||
binary_op!(sub, -);
|
||||
|
||||
comparison_op!(equal, ==);
|
||||
comparison_op!(greater, >);
|
||||
comparison_op!(greater_equal, >=);
|
||||
comparison_op!(lower, <);
|
||||
comparison_op!(lower_equal, <=);
|
||||
|
||||
binary_func!(powf, Line::<C>::powf, Float);
|
||||
binary_func!(rem, Line::<C>::rem, Float);
|
||||
|
||||
unary_func!(exp, Line::<C>::exp, Float);
|
||||
unary_func!(log, Line::<C>::ln, Float);
|
||||
unary_func!(log1p, Line::<C>::log1p, Float);
|
||||
unary_func!(sqrt, Line::<C>::sqrt, Float);
|
||||
unary_func!(cos, Line::<C>::cos, Float);
|
||||
unary_func!(sin, Line::<C>::sin, Float);
|
||||
unary_func!(tanh, Line::<C>::tanh, Float);
|
||||
unary_func!(erf, Line::<C>::erf, Float);
|
||||
unary_func!(recip, Line::<C>::recip, Float);
|
||||
unary_func!(abs, Line::<C>::abs, Numeric);
|
||||
@@ -0,0 +1,8 @@
|
||||
pub(crate) mod io;
|
||||
pub(crate) mod ir;
|
||||
pub(crate) mod kernel;
|
||||
pub(crate) mod tensor;
|
||||
pub(crate) mod view;
|
||||
|
||||
mod base;
|
||||
pub(crate) use base::*;
|
||||
@@ -0,0 +1,90 @@
|
||||
use super::DYN_ELEM_ID;
|
||||
use cubecl::{
|
||||
ir::{ElemType, Type},
|
||||
prelude::*,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::hash::Hash;
|
||||
|
||||
/// Represents a global tensor with the given [element type](ElemType).
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// The `tensor` field type [Line<NumericExpand<DYN_ELEM_ID>>] must be set using polyfill before
|
||||
/// use.
|
||||
#[derive(CubeType, Clone)]
|
||||
pub struct GlobalTensor {
|
||||
/// The global tensor type.
|
||||
pub tensor: Tensor<Line<NumericExpand<DYN_ELEM_ID>>>,
|
||||
/// The element type of the tensor.
|
||||
#[cube(comptime)]
|
||||
pub elem: ElemType,
|
||||
/// Whether the current tensor is logically broadcasted.
|
||||
#[cube(comptime)]
|
||||
pub broadcasted: bool,
|
||||
}
|
||||
|
||||
// Everything below is to implement [LaunchArg].
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct GlobalTensorCompilationArg {
|
||||
tensor: TensorCompilationArg,
|
||||
elem: ElemType,
|
||||
broadcasted: bool,
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub struct GlobalTensorArg<'a, R: Runtime> {
|
||||
pub tensor: <Tensor<Line<NumericExpand<DYN_ELEM_ID>>> as LaunchArg>::RuntimeArg<'a, R>,
|
||||
pub elem: ElemType,
|
||||
pub broadcasted: bool,
|
||||
pub address_type: AddressType,
|
||||
}
|
||||
|
||||
// Opts `GlobalTensorCompilationArg` into `CompilationArg` without overriding any trait item,
// so it can be used as `LaunchArg::CompilationArg` for `GlobalTensor`.
impl CompilationArg for GlobalTensorCompilationArg {}
|
||||
|
||||
impl LaunchArg for GlobalTensor {
|
||||
type RuntimeArg<'a, R: Runtime> = GlobalTensorArg<'a, R>;
|
||||
type CompilationArg = GlobalTensorCompilationArg;
|
||||
|
||||
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
|
||||
let tensor = <Tensor<Line<NumericExpand<DYN_ELEM_ID>>> as LaunchArg>::compilation_arg(
|
||||
&runtime_arg.tensor,
|
||||
);
|
||||
GlobalTensorCompilationArg {
|
||||
tensor,
|
||||
elem: runtime_arg.elem,
|
||||
broadcasted: runtime_arg.broadcasted,
|
||||
}
|
||||
}
|
||||
|
||||
fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> GlobalTensorExpand {
|
||||
let tensor = builder.input_tensor(Type::scalar(arg.elem).line(arg.tensor.line_size));
|
||||
|
||||
GlobalTensorExpand {
|
||||
tensor: tensor.into(),
|
||||
elem: arg.elem,
|
||||
broadcasted: arg.broadcasted,
|
||||
}
|
||||
}
|
||||
fn expand_output(
|
||||
arg: &Self::CompilationArg,
|
||||
builder: &mut KernelBuilder,
|
||||
) -> GlobalTensorExpand {
|
||||
let tensor = match arg.tensor.inplace {
|
||||
Some(id) => builder.inplace_output(id),
|
||||
None => builder.output_tensor(Type::scalar(arg.elem).line(arg.tensor.line_size)),
|
||||
};
|
||||
GlobalTensorExpand {
|
||||
tensor: tensor.into(),
|
||||
elem: arg.elem,
|
||||
broadcasted: arg.broadcasted,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> ArgSettings<R> for GlobalTensorArg<'_, R> {
|
||||
fn register(&self, launcher: &mut KernelLauncher<R>) {
|
||||
launcher.register_tensor(&self.tensor)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
use super::{
|
||||
DYN_ELEM_ID,
|
||||
io::{
|
||||
Transform, global_buffer_len, global_line_size, input_as_slice, read_input,
|
||||
read_input_window, ref_buffer_len, ref_len,
|
||||
},
|
||||
ir::{FuseArg, FuseBlockConfig, GlobalArgs, LayoutInfo, LocalArgs},
|
||||
kernel::fuse_on_write,
|
||||
};
|
||||
use cubecl::{
|
||||
CubeType,
|
||||
io::read_masked,
|
||||
ir::StorageType,
|
||||
prelude::{barrier::BarrierExpand, *},
|
||||
std::tensor::{
|
||||
ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,
|
||||
layout::Coords1d,
|
||||
},
|
||||
};
|
||||
|
||||
#[allow(dead_code, reason = "only used in expand")]
|
||||
#[derive(CubeType)]
|
||||
pub struct GlobalInput {
|
||||
inputs: GlobalArgs,
|
||||
locals: LocalArgs,
|
||||
#[cube(comptime)]
|
||||
pos: usize,
|
||||
#[cube(comptime)]
|
||||
ty: StorageType,
|
||||
#[cube(comptime)]
|
||||
layout: LayoutInfo,
|
||||
#[cube(comptime)]
|
||||
config: FuseBlockConfig,
|
||||
#[cube(comptime)]
|
||||
transform: Option<Transform>,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl GlobalInput {
|
||||
pub fn new(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
#[comptime] arg: FuseArg,
|
||||
#[comptime] config: FuseBlockConfig,
|
||||
#[comptime] transform: Option<Transform>,
|
||||
) -> GlobalInput {
|
||||
let (pos, ty, layout) = comptime![match arg {
|
||||
FuseArg::Input(pos, prec, layout) => (pos, prec.into_type(), layout),
|
||||
_ => unreachable!("Must be concrete input"),
|
||||
}];
|
||||
|
||||
GlobalInput {
|
||||
inputs: inputs.clone(),
|
||||
locals: locals.clone(),
|
||||
pos,
|
||||
ty,
|
||||
layout,
|
||||
config,
|
||||
transform,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: CubePrimitive> ViewOperations<E, Coords1d> for GlobalInput {}
|
||||
impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
) -> <E as CubeType>::ExpandType {
|
||||
ViewOperationsExpand::<E, Coords1d>::__expand_read_unchecked_method(self, scope, pos)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_checked_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
) -> <E as CubeType>::ExpandType {
|
||||
let zero = E::__expand_cast_from(scope, 0.into());
|
||||
ViewOperationsExpand::<E, Coords1d>::__expand_read_masked_method(self, scope, pos, zero)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_masked_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
value: <E as CubeType>::ExpandType,
|
||||
) -> <E as CubeType>::ExpandType {
|
||||
let in_bounds = ViewOperationsExpand::<E, Coords1d>::__expand_is_in_bounds_method(
|
||||
self,
|
||||
scope,
|
||||
pos.clone(),
|
||||
);
|
||||
scope.register_type::<NumericExpand<DYN_ELEM_ID>>(self.ty);
|
||||
let slice = input_as_slice::expand(scope, self.inputs.clone(), self.pos);
|
||||
read_masked::expand::<E>(scope, in_bounds, slice, pos, value)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_unchecked_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
) -> <E as CubeType>::ExpandType {
|
||||
let value = read_input::expand::<E>(
|
||||
scope,
|
||||
self.inputs.clone(),
|
||||
self.locals.clone(),
|
||||
self.pos,
|
||||
pos,
|
||||
self.layout,
|
||||
self.config.clone(),
|
||||
self.transform.clone(),
|
||||
);
|
||||
E::__expand_cast_from(scope, value)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_to_linear_slice_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
end: ExpandElementTyped<usize>,
|
||||
) -> SliceExpand<E, ReadOnly> {
|
||||
scope.register_type::<NumericExpand<DYN_ELEM_ID>>(self.ty);
|
||||
let end = add::expand(scope, end.clone(), 1.into());
|
||||
read_input_window::expand(scope, self.inputs.clone(), self.pos, pos, end)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_tensor_map_load_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_barrier: BarrierExpand,
|
||||
_shared_memory: SliceExpand<E, ReadWrite>,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
) {
|
||||
panic!("Not a tensor map")
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
|
||||
global_buffer_len::expand(scope, self.inputs.clone(), self.pos)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_is_in_bounds_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
) -> ExpandElementTyped<bool> {
|
||||
let buffer_len = global_buffer_len::expand(scope, self.inputs.clone(), self.pos);
|
||||
lt::expand(scope, pos, buffer_len)
|
||||
}
|
||||
}
|
||||
|
||||
impl Lined for GlobalInput {}
|
||||
impl LinedExpand for GlobalInputExpand {
|
||||
fn line_size(&self) -> LineSize {
|
||||
let mut temp_scope = Scope::root(false);
|
||||
global_line_size::expand(&mut temp_scope, self.inputs.clone(), self.pos)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code, reason = "only used in expand")]
|
||||
#[derive(CubeType)]
|
||||
pub struct FusedOutput {
|
||||
inputs: GlobalArgs,
|
||||
outputs: GlobalArgs,
|
||||
locals: LocalArgs,
|
||||
arg: FuseArg,
|
||||
#[cube(comptime)]
|
||||
config: FuseBlockConfig,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl FusedOutput {
|
||||
pub fn new(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
arg: FuseArg,
|
||||
#[comptime] config: FuseBlockConfig,
|
||||
) -> Self {
|
||||
FusedOutput {
|
||||
inputs: inputs.clone(),
|
||||
outputs: outputs.clone(),
|
||||
locals: locals.clone(),
|
||||
arg,
|
||||
config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: CubePrimitive> ViewOperations<Line<E>, Coords1d> for FusedOutput {}
|
||||
impl<E: CubePrimitive> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputExpand {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
) -> <Line<E> as CubeType>::ExpandType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_checked_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
) -> <Line<E> as CubeType>::ExpandType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_masked_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
_value: <Line<E> as CubeType>::ExpandType,
|
||||
) -> <Line<E> as CubeType>::ExpandType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_read_unchecked_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
) -> <Line<E> as CubeType>::ExpandType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_to_linear_slice_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
_size: ExpandElementTyped<usize>,
|
||||
) -> SliceExpand<Line<E>, ReadOnly> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_tensor_map_load_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_barrier: BarrierExpand,
|
||||
_shared_memory: SliceExpand<Line<E>, ReadWrite>,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
) {
|
||||
panic!("Not a tensor map")
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
|
||||
ref_len::expand(
|
||||
scope,
|
||||
self.inputs.clone(),
|
||||
self.outputs.clone(),
|
||||
self.locals.clone(),
|
||||
self.config.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_is_in_bounds_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
) -> ExpandElementTyped<bool> {
|
||||
let buffer_len = ref_buffer_len::expand(
|
||||
scope,
|
||||
self.inputs.clone(),
|
||||
self.outputs.clone(),
|
||||
self.locals.clone(),
|
||||
self.config.clone(),
|
||||
);
|
||||
lt::expand(scope, pos, buffer_len)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: CubePrimitive> ViewOperationsMut<Line<E>, Coords1d> for FusedOutput {}
|
||||
impl<E: CubePrimitive> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutputExpand {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_write_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
value: <Line<E> as CubeType>::ExpandType,
|
||||
) {
|
||||
let values = Registry::<FuseArg, Line<E>>::__expand_new(scope);
|
||||
let mut args = comptime![Vec::<FuseArg>::new()];
|
||||
|
||||
values
|
||||
.clone()
|
||||
.__expand_insert_method(scope, comptime![self.arg.clone()], value);
|
||||
comptime![args.push(self.arg.clone())];
|
||||
|
||||
fuse_on_write::expand(
|
||||
scope,
|
||||
self.inputs.clone(),
|
||||
self.outputs.clone(),
|
||||
self.locals.clone(),
|
||||
pos,
|
||||
values,
|
||||
args,
|
||||
self.config.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_write_checked_method(
|
||||
&self,
|
||||
scope: &mut Scope,
|
||||
pos: ExpandElementTyped<usize>,
|
||||
value: <Line<E> as CubeType>::ExpandType,
|
||||
) {
|
||||
let in_bounds = ViewOperationsExpand::<Line<E>, Coords1d>::__expand_is_in_bounds_method(
|
||||
self,
|
||||
scope,
|
||||
pos.clone(),
|
||||
);
|
||||
if_expand(scope, in_bounds.into(), |scope| {
|
||||
self.__expand_write_method(scope, pos, value);
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_to_linear_slice_mut_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
_size: ExpandElementTyped<usize>,
|
||||
) -> SliceExpand<Line<E>, ReadWrite> {
|
||||
todo!("Not yet supported")
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn __expand_tensor_map_store_method(
|
||||
&self,
|
||||
_scope: &mut Scope,
|
||||
_shared_memory: SliceExpand<Line<E>, ReadOnly>,
|
||||
_pos: ExpandElementTyped<usize>,
|
||||
) {
|
||||
panic!("Not a tensor map")
|
||||
}
|
||||
}
|
||||
|
||||
impl Lined for FusedOutput {}
|
||||
impl LinedExpand for FusedOutputExpand {
|
||||
fn line_size(&self) -> LineSize {
|
||||
self.locals.ref_line_size
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,769 @@
|
||||
use super::{
|
||||
codegen::ir::{BinaryFuseArgs, FuseArg, FuseOp, FuseType, UnaryFuseArgs},
|
||||
settings::FuseSettings,
|
||||
trace::{FuseTrace, TraceFuser, block::QuantInput},
|
||||
};
|
||||
use crate::engine::codegen::ir::QuantSchemeFuse;
|
||||
use burn_fusion::{FuserProperties, FuserStatus, OperationFuser};
|
||||
use burn_ir::{
|
||||
BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr,
|
||||
TensorIr, UnaryOpIr,
|
||||
};
|
||||
use burn_std::{DType, Shape};
|
||||
use cubecl::ir::ElemType;
|
||||
|
||||
/// The base operation fuser that can be used to fuse [all supported fuse operations](FuseOp).
|
||||
///
|
||||
///
|
||||
/// This fuser doesn't create a ready-to-execute kernel, but rather generates a
|
||||
/// [trace](FuseTrace) that be used with a [runner](super::trace::TraceRunner).
|
||||
///
|
||||
/// Since this fuser supports fusing multiple blocks, you can fuse any compute-bound operations
|
||||
/// with the combination of fuse-on-read and fuse-on-write strategy.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// It is responsible to translate [OperationIr] into [FuseOp] and it uses the [TraceFuser]
|
||||
/// to actually fuse the [FuseOp] when possible.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct TraceOperationFuser {
|
||||
fuser: TryTraceFuser,
|
||||
pub(crate) settings: FuseSettings,
|
||||
pub(crate) current_output_shape: Shape,
|
||||
status: FuserStatus,
|
||||
pub(crate) num_ops: usize,
|
||||
pub(crate) num_views: usize,
|
||||
pub(crate) max_bindings: u32,
|
||||
}
|
||||
|
||||
impl TraceOperationFuser {
|
||||
/// Checks if the [operation](OperationIr) can be fused with the current fuser.
|
||||
pub(crate) fn can_fuse(&self, op: &OperationIr) -> bool {
|
||||
let len_previous = self.len();
|
||||
let mut fuser_cloned = self.clone();
|
||||
|
||||
fuser_cloned.fuse(op);
|
||||
let len_after = fuser_cloned.len();
|
||||
|
||||
len_after > len_previous
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationFuser<FuseTrace> for TraceOperationFuser {
|
||||
fn fuse(&mut self, op: &OperationIr) {
|
||||
if let FuserStatus::Closed = self.status {
|
||||
return;
|
||||
}
|
||||
|
||||
match op {
|
||||
OperationIr::Drop(tensor) => {
|
||||
if self.num_ops == 0 {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
|
||||
self.fuser.fuser.fuse_dropped(tensor);
|
||||
}
|
||||
OperationIr::BaseFloat(ops) => {
|
||||
if !self.fuse_base(ops) {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationIr::BaseInt(ops) => {
|
||||
if !self.fuse_base(ops) {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationIr::Float(_dtype, ops) => {
|
||||
if !self.fuse_float(ops) {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationIr::NumericFloat(_dtype, ops) => {
|
||||
if !self.fuse_numeric(ops) {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationIr::NumericInt(_dtype, ops) => {
|
||||
if !self.fuse_numeric(ops) {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationIr::BaseBool(ops) => {
|
||||
if !self.fuse_base(ops) {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.status = FuserStatus::Closed;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
self.status = FuserStatus::Open;
|
||||
self.num_ops += 1;
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> FuseTrace {
|
||||
self.fuser.finish(self.current_output_shape.clone())
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.num_ops
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.num_ops = 0;
|
||||
self.num_views = 0;
|
||||
self.status = FuserStatus::Open;
|
||||
self.fuser = TryTraceFuser::new(
|
||||
self.max_bindings,
|
||||
self.fuser.fuser.bool_precision,
|
||||
self.settings,
|
||||
);
|
||||
self.current_output_shape = Shape::new([]);
|
||||
}
|
||||
|
||||
fn status(&self) -> FuserStatus {
|
||||
self.status
|
||||
}
|
||||
|
||||
fn properties(&self) -> FuserProperties {
|
||||
let ready = self.num_ops > 0;
|
||||
|
||||
FuserProperties {
|
||||
ready,
|
||||
score: self.num_ops as u64,
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<FuseTrace>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl TraceOperationFuser {
|
||||
/// Creates a new fuser.
|
||||
pub fn new(max_bindings: u32, bool_precision: FuseType, settings: FuseSettings) -> Self {
|
||||
Self {
|
||||
fuser: TryTraceFuser::new(max_bindings, bool_precision, settings),
|
||||
settings,
|
||||
num_ops: 0,
|
||||
num_views: 0,
|
||||
max_bindings,
|
||||
current_output_shape: Shape::new([]),
|
||||
status: FuserStatus::Open,
|
||||
}
|
||||
}
|
||||
|
||||
/// Closes the fuser.
|
||||
pub fn close(&mut self) {
|
||||
self.status = FuserStatus::Closed;
|
||||
}
|
||||
|
||||
/// Declares an input tensor argument where the kernel is responsible to load.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - The argument that maps to the tensor to be used during kernel expansion.
|
||||
pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
|
||||
self.fuser.fuser.input_unhandled(tensor)
|
||||
}
|
||||
|
||||
/// Declares an input quantized tensor argument where the kernel is responsible to load.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// None if it's not possible to fuse a quantized tensor. Otherwise:
|
||||
///
|
||||
/// - The argument that maps to the tensor values to be used during kernel expansion.
|
||||
/// - The argument that maps to the tensor params to be used during kernel expansion.
|
||||
pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> {
|
||||
self.fuser.fuser.input_quantized_unhandled(tensor)
|
||||
}
|
||||
|
||||
/// Declares an output tensor argument where the kernel is responsible to write values.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Normally you don't have to declare outputs explicitly before they are going to be
|
||||
/// fused based on the operations [fused](Self::fuse).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - The argument that maps to the tensor to be used during kernel expansion.
|
||||
pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
|
||||
if self.current_output_shape.is_empty() {
|
||||
self.current_output_shape = tensor.shape.clone();
|
||||
} else if self.current_output_shape.iter().sum::<usize>() < tensor.shape.iter().sum() {
|
||||
// The larguest shape win.
|
||||
self.current_output_shape = tensor.shape.clone();
|
||||
}
|
||||
|
||||
self.fuser.fuser.output_unhandled(tensor)
|
||||
}
|
||||
|
||||
/// Closes the previous block and declares a new one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - arguments: Tensors that are logical outputs of the current block and inputs of the following blocks.
|
||||
/// - settings: [FuseSettings] to be used by the next block.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// None if it's impossible to create a next block with the given arguments. Otherwise, the
|
||||
/// corresponding [arguments](Arg) to the given tensors are returned.
|
||||
pub fn next_block<const N: usize>(
|
||||
&mut self,
|
||||
arguments: [&TensorIr; N],
|
||||
settings: FuseSettings,
|
||||
global: bool,
|
||||
) -> [FuseArg; N] {
|
||||
let block_pos = self.fuser.fuser.num_previous_blocks();
|
||||
let current_output_shape =
|
||||
core::mem::replace(&mut self.current_output_shape, Shape::new([]));
|
||||
|
||||
self.fuser.fuser.next_block(current_output_shape, settings);
|
||||
|
||||
self.settings = settings;
|
||||
self.status = FuserStatus::Open;
|
||||
|
||||
arguments.map(|arg| self.fuser.fuser.block_local_input(arg, block_pos, global))
|
||||
}
|
||||
|
||||
/// Tag the [tensor](TensorIr) as received from a previous block.
|
||||
///
|
||||
/// This will avoid reading the input again and instead use le local version when possible.
|
||||
pub fn block_local_input(&mut self, tensor: &TensorIr, block_pos: usize, global: bool) {
|
||||
self.fuser
|
||||
.fuser
|
||||
.block_local_input(tensor, block_pos, global);
|
||||
}
|
||||
|
||||
fn fuse_base(&mut self, ops: &BaseOperationIr) -> bool {
|
||||
match ops {
|
||||
BaseOperationIr::Equal(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
BaseOperationIr::EqualElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
BaseOperationIr::Cast(desc) => {
|
||||
self.fuse_unary_op(&desc.input, &desc.out, |input, out| {
|
||||
FuseOp::Assign(UnaryFuseArgs { input, out })
|
||||
})
|
||||
}
|
||||
BaseOperationIr::SwapDims(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.fuser.fuse(|fuser| {
|
||||
fuser.input_swap_dims(&desc.input, &desc.out, (desc.dim1, desc.dim2))?;
|
||||
|
||||
Some(())
|
||||
}) {
|
||||
self.num_views += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
BaseOperationIr::Reshape(desc) => {
|
||||
if desc.input.shape == desc.out.shape {
|
||||
return self.fuse_unary_op(&desc.input, &desc.out, |input, out| {
|
||||
FuseOp::Assign(UnaryFuseArgs { input, out })
|
||||
});
|
||||
}
|
||||
|
||||
if desc.input.shape.rank() > desc.out.shape.rank() {
|
||||
// Not yet supported.
|
||||
return false;
|
||||
}
|
||||
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.fuser.fuse(|fuser| {
|
||||
fuser.input_reshaped(&desc.input, &desc.out)?;
|
||||
Some(())
|
||||
}) {
|
||||
self.num_views += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
BaseOperationIr::Ones(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let elem: ElemType = desc.out.dtype.into();
|
||||
let precision = elem.into();
|
||||
let input = FuseArg::Literal(1, precision);
|
||||
|
||||
self.fuser.fuse(|fuser| {
|
||||
let out = fuser.output(&desc.out)?;
|
||||
|
||||
fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
BaseOperationIr::Zeros(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let elem: ElemType = desc.out.dtype.into();
|
||||
let precision = elem.into();
|
||||
let input = FuseArg::Literal(0, precision);
|
||||
|
||||
self.fuser.fuse(|fuser| {
|
||||
let out = fuser.output(&desc.out)?;
|
||||
|
||||
fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
BaseOperationIr::Gather(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let input = build.input_indexed(&desc.tensor)?;
|
||||
let indices = build.input_indexed(&desc.indices)?;
|
||||
let output = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(FuseOp::Gather {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim: desc.dim,
|
||||
});
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
BaseOperationIr::Select(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let input = build.input_indexed(&desc.tensor)?;
|
||||
let indices = build.input_indexed(&desc.indices)?;
|
||||
let output = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(FuseOp::Select {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim: desc.dim,
|
||||
});
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
BaseOperationIr::MaskWhere(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let cond = build.input(&desc.mask)?;
|
||||
let rhs = build.input(&desc.tensor)?;
|
||||
let lhs = build.input(&desc.value)?;
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(FuseOp::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
});
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
BaseOperationIr::MaskFill(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let cond = build.input(&desc.mask)?;
|
||||
let lhs = build.scalar(&desc.value, desc.out.dtype);
|
||||
let rhs = build.input(&desc.tensor)?;
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(FuseOp::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
});
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn fuse_float(&mut self, ops: &FloatOperationIr) -> bool {
|
||||
match ops {
|
||||
FloatOperationIr::Exp(desc) => {
|
||||
self.fuse_unary_ops(desc, |input, out| FuseOp::Exp(UnaryFuseArgs { input, out }))
|
||||
}
|
||||
FloatOperationIr::Log(desc) => {
|
||||
self.fuse_unary_ops(desc, |input, out| FuseOp::Log(UnaryFuseArgs { input, out }))
|
||||
}
|
||||
FloatOperationIr::Log1p(desc) => self.fuse_unary_ops(desc, |input, out| {
|
||||
FuseOp::Log1p(UnaryFuseArgs { input, out })
|
||||
}),
|
||||
FloatOperationIr::Cos(desc) => {
|
||||
self.fuse_unary_ops(desc, |input, out| FuseOp::Cos(UnaryFuseArgs { input, out }))
|
||||
}
|
||||
FloatOperationIr::Sin(desc) => {
|
||||
self.fuse_unary_ops(desc, |input, out| FuseOp::Sin(UnaryFuseArgs { input, out }))
|
||||
}
|
||||
FloatOperationIr::PowfScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
FloatOperationIr::Tanh(desc) => self.fuse_unary_ops(desc, |input, out| {
|
||||
FuseOp::Tanh(UnaryFuseArgs { input, out })
|
||||
}),
|
||||
FloatOperationIr::Erf(desc) => {
|
||||
self.fuse_unary_ops(desc, |input, out| FuseOp::Erf(UnaryFuseArgs { input, out }))
|
||||
}
|
||||
FloatOperationIr::Sqrt(desc) => self.fuse_unary_ops(desc, |input, out| {
|
||||
FuseOp::Sqrt(UnaryFuseArgs { input, out })
|
||||
}),
|
||||
FloatOperationIr::Recip(desc) => self.fuse_unary_ops(desc, |input, out| {
|
||||
FuseOp::Recip(UnaryFuseArgs { input, out })
|
||||
}),
|
||||
FloatOperationIr::Dequantize(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let qinput = build.input_quantized(&desc.input)?;
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
match qinput {
|
||||
QuantInput::AlreadyDequantized { local } => {
|
||||
build.fuse_operation(FuseOp::Assign(UnaryFuseArgs {
|
||||
input: local,
|
||||
out,
|
||||
}));
|
||||
}
|
||||
QuantInput::Quantized { values, params } => {
|
||||
build.fuse_operation(FuseOp::Dequantize {
|
||||
values,
|
||||
params,
|
||||
output: out,
|
||||
scheme: match desc.input.dtype {
|
||||
DType::QFloat(scheme) => QuantSchemeFuse { scheme },
|
||||
_ => unreachable!("Should be a quant tensor."),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn fuse_numeric(&mut self, op: &NumericOperationIr) -> bool {
|
||||
match op {
|
||||
NumericOperationIr::Add(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Add(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::AddScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Add(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Sub(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::SubScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Mul(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::MulScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Div(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Div(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::DivScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Div(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Abs(desc) => {
|
||||
self.fuse_unary_ops(desc, |input, out| FuseOp::Abs(UnaryFuseArgs { input, out }))
|
||||
}
|
||||
NumericOperationIr::Lower(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::LowerElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Greater(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::GreaterElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::LowerEqual(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::LowerEqualElem(desc) => self
|
||||
.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::GreaterEqual(desc) => self
|
||||
.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::GreaterEqualElem(desc) => self
|
||||
.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Full(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let input = build.scalar(&desc.value, desc.out.dtype);
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
NumericOperationIr::Rem(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::RemScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Powf(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {
|
||||
FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out })
|
||||
}),
|
||||
NumericOperationIr::Clamp(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let input = build.input(&desc.tensor)?;
|
||||
let min = build.scalar(&desc.min, desc.out.dtype);
|
||||
let max = build.scalar(&desc.max, desc.out.dtype);
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(FuseOp::Clamp {
|
||||
input,
|
||||
min,
|
||||
max,
|
||||
out,
|
||||
});
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn fuse_binary_ops<Func>(&mut self, desc: &BinaryOpIr, func: Func) -> bool
|
||||
where
|
||||
Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let lhs = build.input(&desc.lhs)?;
|
||||
let rhs = build.input(&desc.rhs)?;
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(func(lhs, rhs, out));
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
|
||||
fn fuse_unary_ops<Func>(&mut self, desc: &UnaryOpIr, func: Func) -> bool
|
||||
where
|
||||
Func: Fn(FuseArg, FuseArg) -> FuseOp,
|
||||
{
|
||||
self.fuse_unary_op(&desc.input, &desc.out, func)
|
||||
}
|
||||
|
||||
fn fuse_unary_op<Func>(&mut self, input: &TensorIr, out: &TensorIr, func: Func) -> bool
|
||||
where
|
||||
Func: Fn(FuseArg, FuseArg) -> FuseOp,
|
||||
{
|
||||
if !self.output_is_compatible(out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let input = build.input(input)?;
|
||||
let out = build.output(out)?;
|
||||
build.fuse_operation(func(input, out));
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
|
||||
fn fuse_scalar_ops<Func>(&mut self, desc: &ScalarOpIr, func: Func) -> bool
|
||||
where
|
||||
Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser.fuse(|build| {
|
||||
let elem = desc.lhs.dtype;
|
||||
let lhs = build.input(&desc.lhs)?;
|
||||
let rhs = build.scalar(&desc.rhs, elem);
|
||||
let out = build.output(&desc.out)?;
|
||||
|
||||
build.fuse_operation(func(lhs, rhs, out));
|
||||
|
||||
Some(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Decides whether an operation writing `out` can join the operations fused so far.
///
/// The first output seen becomes the reference shape. Afterwards the rank must match and
/// every differing dimension must be explained by broadcasting, which is only considered
/// when `settings.broadcast` is set. The reference shape itself is replaced only when
/// `settings.output_shape_updates` is set and the updated shape equals `out.shape` exactly.
fn output_is_compatible(&mut self, out: &TensorIr) -> bool {
    // Nothing recorded yet: adopt this output's shape as the reference.
    if self.current_output_shape.is_empty() {
        self.current_output_shape.clone_from(&out.shape);
        return true;
    }

    // Rank should be equal.
    let rank = self.current_output_shape.len();
    if out.shape.num_dims() != rank {
        return false;
    }

    let mut candidate = self.current_output_shape.clone();
    let mut needs_update = false;

    #[allow(clippy::needless_range_loop)]
    for axis in 0..rank {
        let current = self.current_output_shape[axis];
        let incoming = out.shape[axis];

        if current == incoming {
            continue;
        }

        // Any mismatch requires broadcasting to be enabled.
        if !self.settings.broadcast {
            return false;
        }

        // Broadcasted on the new dim: the reference shape stays as is.
        if incoming == 0 {
            continue;
        }

        // Broadcasted on the current dim: the reference output shape must be updated,
        // which is only permitted when shape updates are enabled.
        let can_widen = current == 0 && self.settings.output_shape_updates;
        if !can_widen {
            return false;
        }

        candidate[axis] = incoming;
        needs_update = true;
    }

    if needs_update {
        // For now forced to have exact shape.
        if !(candidate == out.shape) {
            return false;
        }

        self.current_output_shape.clone_from_slice(&out.shape);
    }

    true
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Builder wrapper to limit the number of bindings in generated kernels.
|
||||
struct TryTraceFuser {
|
||||
fuser: TraceFuser,
|
||||
max_bindings: u32,
|
||||
max_ops: u32,
|
||||
added_ops: bool,
|
||||
}
|
||||
|
||||
impl TryTraceFuser {
|
||||
fn new(max_bindings: u32, bool_precision: FuseType, settings: FuseSettings) -> Self {
|
||||
Self {
|
||||
fuser: TraceFuser::new(bool_precision, settings),
|
||||
max_bindings,
|
||||
// A good default, avoid errors with for loops over only memory
|
||||
// bound operations.
|
||||
max_ops: 64,
|
||||
added_ops: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn fuse(&mut self, add_ops: impl FnOnce(&mut TraceFuser) -> Option<()>) -> bool {
|
||||
if self.fuser.num_ops_fused() > self.max_ops {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Always allow the first operation to be added.
|
||||
if !self.added_ops {
|
||||
self.added_ops = true;
|
||||
|
||||
if add_ops(&mut self.fuser).is_none() {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut cloned = self.fuser.clone();
|
||||
if add_ops(&mut cloned).is_none() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if cloned.estimate_bindings() > self.max_bindings {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.fuser = cloned;
|
||||
true
|
||||
}
|
||||
|
||||
fn finish(&mut self, shape: Shape) -> FuseTrace {
|
||||
self.fuser.finish(shape)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::{
|
||||
launch::{
|
||||
HandleInput, HandleOutput, LaunchPlan, executor::LaunchPlanExecutor,
|
||||
input::InputPlanner, output::OutputPlanner, runner::TraceRunner,
|
||||
vectorization::VectorizationPlanner,
|
||||
},
|
||||
trace::{FuseTrace, TraceError, TuneOutput},
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use cubecl::{CubeElement, Runtime, client::ComputeClient};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// The launcher is responsible to launch a fused kernel using the [TraceRunner] and a [FuseTrace].
|
||||
pub struct FuseTraceLauncher<'a, R: Runtime, Runner: TraceRunner<R>> {
|
||||
trace: &'a FuseTrace,
|
||||
runner: &'a Runner,
|
||||
_runtime: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime, Runner: TraceRunner<R>> FuseTraceLauncher<'a, R, Runner> {
|
||||
/// Creates a new launcher.
|
||||
pub fn new(trace: &'a FuseTrace, runner: &'a Runner) -> Self {
|
||||
Self {
|
||||
trace,
|
||||
runner,
|
||||
_runtime: PhantomData,
|
||||
}
|
||||
}
|
||||
/// Launches the fuse kernel on the given device modifying the context.
|
||||
pub fn launch<BT: CubeElement>(
|
||||
&self,
|
||||
client: &ComputeClient<R>,
|
||||
device: &R::Device,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
) -> Result<TuneOutput<R>, TraceError<Runner::Error>> {
|
||||
let mut plan = LaunchPlan::new(&self.trace.blocks);
|
||||
|
||||
InputPlanner::new(&self.trace.resources, &self.trace.blocks).run(context, &mut plan);
|
||||
|
||||
OutputPlanner::new(&self.trace.resources, &self.trace.blocks)
|
||||
.run::<BT>(client, device, context, &mut plan);
|
||||
|
||||
VectorizationPlanner::new(&self.trace.resources, &self.trace.blocks).run(
|
||||
client,
|
||||
self.runner,
|
||||
context,
|
||||
&mut plan,
|
||||
);
|
||||
|
||||
match LaunchPlanExecutor::new(&self.trace.resources, &self.trace.blocks).execute::<_, BT>(
|
||||
client,
|
||||
self.runner,
|
||||
context,
|
||||
plan,
|
||||
) {
|
||||
Err(err) => {
|
||||
self.rollback(context, err.handles_input, err.handles_output);
|
||||
Err(err.error)
|
||||
}
|
||||
Ok(val) => Ok(val),
|
||||
}
|
||||
}
|
||||
|
||||
fn rollback(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
handle_inputs: Vec<HandleInput<R>>,
|
||||
handle_outputs: Vec<HandleOutput<R>>,
|
||||
) {
|
||||
for input in handle_inputs {
|
||||
match input {
|
||||
HandleInput::Normal(input) => {
|
||||
context
|
||||
.handles
|
||||
.register_handle(input.global_ir.id, input.handle_rollback());
|
||||
}
|
||||
HandleInput::QuantValues(input) => {
|
||||
context
|
||||
.handles
|
||||
.register_handle(input.global_ir.id, input.handle);
|
||||
}
|
||||
HandleInput::QuantParams(_) => {
|
||||
// The scales are part of the quant data handle.
|
||||
}
|
||||
};
|
||||
}
|
||||
for output in handle_outputs {
|
||||
if let HandleOutput::Owned {
|
||||
global_id, handle, ..
|
||||
} = output
|
||||
{
|
||||
context.handles.register_handle(global_id, handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
use super::{HandleInput, HandleOutput, LaunchPlan, ReferenceSelection};
|
||||
use crate::engine::launch::runner::TraceRunner;
|
||||
use crate::engine::trace::{FuseResources, TensorView, TraceError, TuneOutput, block::FuseBlock};
|
||||
use crate::{
|
||||
CubeFusionHandle, elem_dtype,
|
||||
engine::{
|
||||
codegen::ir::{
|
||||
FuseBlockConfig, FuseOp, FuseType, GlobalArgsLaunch, RefLayout, VirtualLayout,
|
||||
},
|
||||
codegen::tensor::GlobalTensorArg,
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::{Context, ScalarId};
|
||||
use burn_ir::ScalarIr;
|
||||
use burn_std::DType;
|
||||
use cubecl::{
|
||||
CubeElement, Runtime,
|
||||
client::ComputeClient,
|
||||
ir::AddressType,
|
||||
prelude::{InputScalar, ScalarArg, TensorArg},
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context).
|
||||
pub struct LaunchPlanExecutor<'a, R: Runtime> {
|
||||
resources: &'a FuseResources,
|
||||
blocks: &'a Vec<FuseBlock>,
|
||||
_r: PhantomData<R>,
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub struct ExecutionError<R: Runtime, Runner: TraceRunner<R>> {
|
||||
pub error: TraceError<Runner::Error>,
|
||||
pub handles_input: Vec<HandleInput<R>>,
|
||||
pub handles_output: Vec<HandleOutput<R>>,
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {
|
||||
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
|
||||
Self {
|
||||
resources,
|
||||
blocks,
|
||||
_r: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute<Runner: TraceRunner<R>, BT: CubeElement>(
|
||||
self,
|
||||
client: &ComputeClient<R>,
|
||||
runner: &Runner,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
plan: LaunchPlan<'a, R>,
|
||||
) -> Result<TuneOutput<R>, ExecutionError<R, Runner>> {
|
||||
let mut num_writes = 0;
|
||||
for b in plan.blocks.iter() {
|
||||
for writes in b.writes.values() {
|
||||
num_writes += writes.len();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
let mut tune_output = TuneOutput::Checked {
|
||||
handles: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "autotune-checks"))]
|
||||
let mut tune_output = TuneOutput::UnChecked(PhantomData);
|
||||
|
||||
if num_writes == 0 {
|
||||
// Nothing to write, can skip execution.
|
||||
return Ok(tune_output);
|
||||
}
|
||||
|
||||
let mut inputs = GlobalArgsLaunch::default();
|
||||
let mut outputs = GlobalArgsLaunch::default();
|
||||
|
||||
register_inputs(&plan.handle_inputs, &mut inputs);
|
||||
register_scalars(
|
||||
self.resources.scalars.iter(),
|
||||
self.resources.views.iter(),
|
||||
context,
|
||||
&mut inputs,
|
||||
);
|
||||
register_outputs::<BT, R>(&plan.handle_outputs, &mut outputs, &mut tune_output);
|
||||
|
||||
for layout in plan.runtime_layouts {
|
||||
for s in layout.shape.iter() {
|
||||
inputs.runtime_layouts.push(ScalarArg::new(*s));
|
||||
}
|
||||
for s in layout.strides.iter() {
|
||||
inputs.runtime_layouts.push(ScalarArg::new(*s));
|
||||
}
|
||||
}
|
||||
|
||||
let mut configs = Vec::with_capacity(plan.blocks.len());
|
||||
|
||||
for (block_plan, block) in plan.blocks.into_iter().zip(self.blocks) {
|
||||
let reference = match block_plan.reference {
|
||||
ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout),
|
||||
ReferenceSelection::VirtualShape { original, .. } => {
|
||||
RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width))
|
||||
}
|
||||
ReferenceSelection::SwapDims { original, dims } => {
|
||||
RefLayout::Virtual(VirtualLayout::SwapDims(original, dims))
|
||||
}
|
||||
ReferenceSelection::Reshaped { reshape_pos } => {
|
||||
RefLayout::Virtual(VirtualLayout::Reshaped {
|
||||
reshape_pos,
|
||||
line_size: block_plan.width,
|
||||
})
|
||||
}
|
||||
ReferenceSelection::Runtime { pos } => {
|
||||
RefLayout::Virtual(VirtualLayout::Runtime { pos })
|
||||
}
|
||||
ReferenceSelection::Searching => {
|
||||
return Err(ExecutionError::new(
|
||||
TraceError::ReferenceNotFound,
|
||||
plan.handle_inputs,
|
||||
plan.handle_outputs,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut ops = Vec::<FuseOp>::new();
|
||||
|
||||
for read_ops in block_plan.reads.into_values() {
|
||||
for op in read_ops {
|
||||
ops.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
for op in block.ops.iter() {
|
||||
ops.push(op.clone());
|
||||
}
|
||||
|
||||
for opsw in block_plan.writes.into_values() {
|
||||
for op in opsw {
|
||||
ops.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
let config = FuseBlockConfig {
|
||||
rank: plan.rank,
|
||||
ref_layout: reference,
|
||||
ops,
|
||||
width: block_plan.width,
|
||||
};
|
||||
configs.push(config);
|
||||
}
|
||||
|
||||
Runner::run(runner, client, inputs, outputs, &configs).map_err(|err| {
|
||||
ExecutionError::new(
|
||||
TraceError::RunnerError(err),
|
||||
plan.handle_inputs,
|
||||
plan.handle_outputs,
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(tune_output)
|
||||
}
|
||||
}
|
||||
|
||||
fn register_inputs<'h, R: Runtime>(
|
||||
handle_inputs: &'h [HandleInput<R>],
|
||||
inputs: &mut GlobalArgsLaunch<'h, R>,
|
||||
) {
|
||||
for hi in handle_inputs.iter() {
|
||||
match hi {
|
||||
HandleInput::Normal(hi) => {
|
||||
let arg = hi.handle.as_tensor_arg(&hi.global_ir.shape, hi.line_size);
|
||||
inputs.tensors.push(GlobalTensorArg::new(
|
||||
arg,
|
||||
hi.precision.into_elem(),
|
||||
hi.broadcated,
|
||||
hi.handle.required_address_type(),
|
||||
));
|
||||
}
|
||||
HandleInput::QuantValues(hi) => {
|
||||
let arg = hi.handle.as_tensor_arg(&hi.global_ir.shape, hi.line_size);
|
||||
inputs.tensors.push(GlobalTensorArg::new(
|
||||
arg,
|
||||
hi.precision.into_elem(),
|
||||
false,
|
||||
hi.handle.required_address_type(),
|
||||
));
|
||||
}
|
||||
HandleInput::QuantParams(hi) => {
|
||||
let arg = hi.handle.as_tensor_arg(&hi.shape, 1);
|
||||
inputs.tensors.push(GlobalTensorArg::new(
|
||||
arg,
|
||||
hi.precision.into_elem(),
|
||||
false,
|
||||
hi.handle.required_address_type(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_outputs<'s, BT: CubeElement, R: Runtime>(
|
||||
handle_outputs: &'s [HandleOutput<R>],
|
||||
outputs: &mut GlobalArgsLaunch<'s, R>,
|
||||
#[allow(unused_variables)] tune_output: &mut TuneOutput<R>,
|
||||
) {
|
||||
for item in handle_outputs.iter() {
|
||||
match item {
|
||||
HandleOutput::Alias {
|
||||
input_pos,
|
||||
precision,
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
debug_info,
|
||||
} => {
|
||||
outputs.tensors.push(GlobalTensorArg::new(
|
||||
TensorArg::alias(*input_pos),
|
||||
precision.into_elem(),
|
||||
false,
|
||||
AddressType::default(),
|
||||
));
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
if let TuneOutput::Checked { handles, .. } = tune_output {
|
||||
handles.insert(
|
||||
debug_info.relative_id,
|
||||
(debug_info.global_shape.clone(), debug_info.handle.clone()),
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleOutput::Owned {
|
||||
precision,
|
||||
handle,
|
||||
global_shape,
|
||||
vectorization: line_size,
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
relative_id,
|
||||
..
|
||||
} => {
|
||||
let arg = handle.as_tensor_arg(global_shape, *line_size);
|
||||
|
||||
let elem = match precision {
|
||||
FuseType::Bool => match elem_dtype::<BT>() {
|
||||
DType::U32 => FuseType::U32.into_elem(),
|
||||
DType::U8 => FuseType::U8.into_elem(),
|
||||
_ => todo!(),
|
||||
},
|
||||
_ => precision.into_elem(),
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
if let TuneOutput::Checked { handles, .. } = tune_output {
|
||||
handles.insert(*relative_id, (global_shape.clone(), handle.clone()));
|
||||
}
|
||||
|
||||
outputs.tensors.push(GlobalTensorArg::new(
|
||||
arg,
|
||||
elem,
|
||||
false,
|
||||
handle.required_address_type(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_scalars<'h, R: Runtime>(
|
||||
scalars: impl Iterator<Item = &'h (FuseType, u64)>,
|
||||
views: impl DoubleEndedIterator<Item = &'h TensorView>,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
inputs: &mut GlobalArgsLaunch<'h, R>,
|
||||
) {
|
||||
for (precision, id) in scalars {
|
||||
let dtype = precision.into_type();
|
||||
match context.scalars.get(&ScalarId { value: *id }) {
|
||||
Some(scalar) => match scalar {
|
||||
ScalarIr::Float(val) => inputs.scalars.push(InputScalar::new(*val, dtype)),
|
||||
ScalarIr::Int(val) => inputs.scalars.push(InputScalar::new(*val, dtype)),
|
||||
ScalarIr::UInt(val) => inputs.scalars.push(InputScalar::new(*val, dtype)),
|
||||
ScalarIr::Bool(val) => inputs.scalars.push(InputScalar::new(*val as u8, dtype)),
|
||||
},
|
||||
None => panic!("Scalar ID not found"),
|
||||
}
|
||||
}
|
||||
|
||||
for relative in views {
|
||||
if let TensorView::Reshape { reshaped, .. } = relative {
|
||||
let global = context.tensors.get(reshaped).unwrap();
|
||||
|
||||
for shape in global.shape.iter() {
|
||||
inputs.reshapes.push(ScalarArg::new(*shape));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
use super::{BlockPlan, HandleInput, InputReference};
|
||||
use super::{LaunchPlan, NormalHandleInput, PotentialInplace};
|
||||
use crate::CubeFusionHandle;
|
||||
use crate::engine::launch::{QuantParamsHandleInput, QuantValuesHandleInput};
|
||||
use crate::engine::trace::block::FuseBlock;
|
||||
use crate::engine::trace::{FuseResources, RegisterTensor, TensorView};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::{TensorIr, TensorStatus};
|
||||
use burn_std::quantization::params_shape;
|
||||
use cubecl::Runtime;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Fetch and register [input handles](HandleInput). Also identifies potential inputs that
|
||||
/// can be used inplace and/or as the [reference layout](super::super::ir::RefLayout).
|
||||
pub struct InputPlanner<'a, R: Runtime> {
|
||||
resources: &'a FuseResources,
|
||||
blocks: &'a Vec<FuseBlock>,
|
||||
_r: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> InputPlanner<'a, R> {
|
||||
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
|
||||
Self {
|
||||
resources,
|
||||
blocks,
|
||||
_r: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(self, context: &mut Context<'_, CubeFusionHandle<R>>, plan: &mut LaunchPlan<'a, R>) {
|
||||
for (pos, input) in self.resources.inputs.iter().enumerate() {
|
||||
match input {
|
||||
RegisterTensor::Normal(tensor_relative, precision) => {
|
||||
let mut tensor_global =
|
||||
context.tensors.get(&tensor_relative.id).unwrap().clone();
|
||||
let handle = context
|
||||
.handles
|
||||
.get_handle(&tensor_global.id, &TensorStatus::ReadOnly);
|
||||
|
||||
if let TensorStatus::ReadWrite = tensor_relative.status {
|
||||
plan.cleared.push(tensor_global.id);
|
||||
}
|
||||
|
||||
let mut new_strides = handle.strides.clone();
|
||||
|
||||
self.analyze(plan, pos, tensor_relative, &handle);
|
||||
|
||||
if tensor_global.shape.rank() < plan.rank {
|
||||
let num_elem: usize = tensor_global.shape.iter().product();
|
||||
for _ in 0..(plan.rank - tensor_global.shape.rank()) {
|
||||
tensor_global.shape.insert(0, 1);
|
||||
new_strides.insert(0, num_elem);
|
||||
}
|
||||
}
|
||||
|
||||
plan.handle_inputs
|
||||
.push(HandleInput::Normal(NormalHandleInput::new(
|
||||
tensor_global,
|
||||
tensor_relative,
|
||||
*precision,
|
||||
handle,
|
||||
new_strides,
|
||||
)));
|
||||
}
|
||||
RegisterTensor::QuantValues(tensor_relative) => {
|
||||
let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone();
|
||||
let handle = context
|
||||
.handles
|
||||
.get_handle(&tensor_global.id, &TensorStatus::ReadOnly);
|
||||
|
||||
let scheme = match tensor_relative.dtype {
|
||||
burn_std::DType::QFloat(scheme) => scheme,
|
||||
_ => unreachable!("Can't have quant data without QFloat"),
|
||||
};
|
||||
let params = handle.params(scheme).unwrap();
|
||||
let precision = tensor_relative.dtype.into();
|
||||
let precision_scales = params.dtype.into();
|
||||
|
||||
let global_shape = tensor_global.shape.clone();
|
||||
let shape_params = params_shape(&global_shape, scheme.level);
|
||||
plan.handle_inputs
|
||||
.push(HandleInput::QuantValues(QuantValuesHandleInput {
|
||||
relative_id: tensor_relative.id,
|
||||
global_ir: tensor_global,
|
||||
precision,
|
||||
handle,
|
||||
line_size: 1,
|
||||
}));
|
||||
|
||||
plan.handle_inputs
|
||||
.push(HandleInput::QuantParams(QuantParamsHandleInput {
|
||||
precision: precision_scales,
|
||||
handle: params,
|
||||
shape: shape_params,
|
||||
}));
|
||||
}
|
||||
RegisterTensor::QuantParams(_) => {
|
||||
// It is registered at the same time as quant data.
|
||||
// The order is important and the index in the vector as well, so that's why we
|
||||
// have QuantParams.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn analyze(
|
||||
&self,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
pos: usize,
|
||||
tensor_relative: &'a TensorIr,
|
||||
handle: &CubeFusionHandle<R>,
|
||||
) {
|
||||
if !self
|
||||
.resources
|
||||
.inputs_unhandled
|
||||
.contains(&tensor_relative.id)
|
||||
{
|
||||
let mut is_a_view = false;
|
||||
// For each view we try to see if it's not possible to set it as a reference input.
|
||||
for view in self.resources.views.iter() {
|
||||
for (block_plan, block) in plan.blocks.iter_mut().zip(self.blocks) {
|
||||
is_a_view = is_a_view
|
||||
|| Self::analyze_view(pos, tensor_relative, block, block_plan, view);
|
||||
}
|
||||
}
|
||||
|
||||
if !is_a_view {
|
||||
self.analyze_normal(plan, pos, tensor_relative, handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyzes if the given tensor can be used inplace in one of the block.
|
||||
fn analyze_normal(
|
||||
&self,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
pos: usize,
|
||||
tensor_relative: &'a TensorIr,
|
||||
handle: &CubeFusionHandle<R>,
|
||||
) {
|
||||
enum BlockInplaceSelection {
|
||||
Notinit,
|
||||
/// The block reads the input, and therefore can use it for inplace.
|
||||
Selected(usize),
|
||||
/// The same input is used in multiple blocks.
|
||||
Unavailable,
|
||||
}
|
||||
|
||||
let mut block_inplace_selection = BlockInplaceSelection::Notinit;
|
||||
|
||||
for (idx, block) in plan.blocks.iter().enumerate() {
|
||||
if block.reads.contains_key(&tensor_relative.id) {
|
||||
match block_inplace_selection {
|
||||
BlockInplaceSelection::Notinit => {
|
||||
block_inplace_selection = BlockInplaceSelection::Selected(idx);
|
||||
}
|
||||
BlockInplaceSelection::Selected(_) => {
|
||||
block_inplace_selection = BlockInplaceSelection::Unavailable;
|
||||
}
|
||||
BlockInplaceSelection::Unavailable => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let BlockInplaceSelection::Selected(idx) = block_inplace_selection {
|
||||
if self.blocks[idx].shape_ref != tensor_relative.shape {
|
||||
return;
|
||||
}
|
||||
|
||||
let block_plan = &mut plan.blocks[idx];
|
||||
if tensor_relative.status == TensorStatus::ReadWrite {
|
||||
if self.blocks[idx].settings.inplace && handle.handle.can_mut() {
|
||||
block_plan.potential_inplaces.push(PotentialInplace {
|
||||
input_pos: pos,
|
||||
tensor_relative,
|
||||
strides: handle.strides.clone(),
|
||||
});
|
||||
}
|
||||
// Inplace tensors are normally really good as the reference layout, since
|
||||
// it's normally better to be based on writes rather than on reads.
|
||||
block_plan.potential_reference_input =
|
||||
Some(InputReference::Normal { input_pos: pos });
|
||||
} else {
|
||||
block_plan.potential_reference_input =
|
||||
Some(InputReference::Normal { input_pos: pos });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyzes if the given tensor is also the view provided, and check if it can be used as the reference layout
|
||||
/// for the given block.
|
||||
fn analyze_view(
|
||||
pos: usize,
|
||||
tensor_relative: &'a TensorIr,
|
||||
block: &FuseBlock,
|
||||
block_plan: &mut BlockPlan<'a>,
|
||||
view: &TensorView,
|
||||
) -> bool {
|
||||
match view {
|
||||
TensorView::Reshape {
|
||||
reshaped,
|
||||
original,
|
||||
reshape_pos,
|
||||
shape_relative,
|
||||
} => {
|
||||
if original == &tensor_relative.id || reshaped == &tensor_relative.id {
|
||||
if block_plan.potential_reference_input.is_none()
|
||||
&& shape_relative == &block.shape_ref
|
||||
{
|
||||
block_plan.potential_reference_input = Some(InputReference::Reshaped {
|
||||
reshape_pos: *reshape_pos,
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
TensorView::SwapDims {
|
||||
swapped,
|
||||
original,
|
||||
dims,
|
||||
..
|
||||
} => {
|
||||
if swapped == &tensor_relative.id {
|
||||
return true;
|
||||
}
|
||||
|
||||
if original == &tensor_relative.id {
|
||||
let shape = tensor_relative
|
||||
.shape
|
||||
.clone()
|
||||
.swapped(dims.0, dims.1)
|
||||
.unwrap();
|
||||
|
||||
if block_plan.potential_reference_input.is_none() && shape == block.shape_ref {
|
||||
block_plan.potential_reference_input = Some(InputReference::SwapDims {
|
||||
original_pos: pos,
|
||||
dims: *dims,
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
pub(crate) mod executor;
|
||||
pub(crate) mod input;
|
||||
pub(crate) mod output;
|
||||
pub(crate) mod runner;
|
||||
pub(crate) mod vectorization;
|
||||
|
||||
pub(crate) mod plan;
|
||||
pub use plan::*;
|
||||
|
||||
mod base;
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,696 @@
|
||||
use super::{
|
||||
super::codegen::ir::FuseType, BlockPlan, HandleOutput, InputReference, LaunchPlan,
|
||||
NormalHandleInput, ReferenceSelection,
|
||||
};
|
||||
use crate::{
|
||||
CubeFusionHandle, elem_dtype,
|
||||
engine::{
|
||||
codegen::ir::{FuseArg, FuseOp, LayoutInfo},
|
||||
launch::HandleInput,
|
||||
settings::RefLayoutSetting,
|
||||
trace::{FuseResources, RegisterTensor, RuntimeLayout, TensorView, block::FuseBlock},
|
||||
},
|
||||
strides_dyn_rank,
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use burn_std::{DType, Shape};
|
||||
use burn_std::{
|
||||
Strides,
|
||||
tensor::{ReshapeAction, contiguous_strides, is_contiguous, reshape_action},
|
||||
};
|
||||
use cubecl::{CubeElement, Runtime, client::ComputeClient, ir::StorageType};
|
||||
|
||||
/// Create or reuse handles for the outputs.
|
||||
///
|
||||
/// It is also responsible to select the reference tensor.
|
||||
pub struct OutputPlanner<'a, R: Runtime> {
|
||||
resources: &'a FuseResources,
|
||||
outputs_sorted: Vec<OutputSorted<'a>>,
|
||||
handles: Vec<Option<HandleOutput<R>>>,
|
||||
globals: Vec<Option<TensorIr>>,
|
||||
blocks: &'a Vec<FuseBlock>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct OutputSorted<'a> {
|
||||
pos_original: usize,
|
||||
precision: FuseType,
|
||||
tensor_relative: &'a TensorIr,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum OutputKind {
|
||||
Normal,
|
||||
Inplace {
|
||||
/// The position in the potential inplace vector
|
||||
input_pos: usize,
|
||||
},
|
||||
Transform(TensorView),
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> OutputPlanner<'a, R> {
|
||||
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
|
||||
let mut outputs_sorted: Vec<_> = resources
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(pos, entry)| match entry {
|
||||
RegisterTensor::Normal(ir, p) => Some((pos, ir, p)),
|
||||
RegisterTensor::QuantValues(_) => None,
|
||||
RegisterTensor::QuantParams(_) => None,
|
||||
})
|
||||
.map(|(pos, tensor, precision)| OutputSorted {
|
||||
pos_original: pos,
|
||||
precision: *precision,
|
||||
tensor_relative: tensor,
|
||||
})
|
||||
.collect();
|
||||
|
||||
outputs_sorted.sort_by(|a, b| {
|
||||
let a_val: usize = a.tensor_relative.shape.iter().sum();
|
||||
let b_val: usize = b.tensor_relative.shape.iter().sum();
|
||||
|
||||
b_val.cmp(&a_val)
|
||||
});
|
||||
|
||||
let mut handles = Vec::with_capacity(resources.outputs.len());
|
||||
let mut globals = Vec::with_capacity(resources.outputs.len());
|
||||
|
||||
for _ in 0..resources.outputs.len() {
|
||||
handles.push(None);
|
||||
globals.push(None);
|
||||
}
|
||||
|
||||
Self {
|
||||
resources,
|
||||
outputs_sorted,
|
||||
handles,
|
||||
globals,
|
||||
blocks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run<BT: CubeElement>(
|
||||
mut self,
|
||||
client: &ComputeClient<R>,
|
||||
device: &R::Device,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
) {
|
||||
// So that we can borrow self during the iteration.
|
||||
let mut outputs = Vec::new();
|
||||
core::mem::swap(&mut outputs, &mut self.outputs_sorted);
|
||||
|
||||
for output in outputs.into_iter() {
|
||||
let tensor_global = context
|
||||
.tensors
|
||||
.get(&output.tensor_relative.id)
|
||||
.unwrap()
|
||||
.clone();
|
||||
let strides = strides_dyn_rank(&tensor_global.shape);
|
||||
let (kind, block_idx) = self.output_kind(plan, &tensor_global, &output, &strides);
|
||||
|
||||
match kind {
|
||||
OutputKind::Inplace { input_pos } => {
|
||||
self.inplace_output(context, plan, output, tensor_global, input_pos, block_idx);
|
||||
}
|
||||
OutputKind::Normal => {
|
||||
self.normal_output::<BT>(
|
||||
client,
|
||||
device,
|
||||
context,
|
||||
plan,
|
||||
output,
|
||||
tensor_global,
|
||||
strides,
|
||||
block_idx,
|
||||
);
|
||||
}
|
||||
OutputKind::Transform(TensorView::Reshape { original, .. }) => {
|
||||
self.reshaped_output::<BT>(
|
||||
client,
|
||||
device,
|
||||
context,
|
||||
plan,
|
||||
output,
|
||||
tensor_global,
|
||||
strides,
|
||||
original,
|
||||
block_idx,
|
||||
);
|
||||
}
|
||||
OutputKind::Transform(TensorView::SwapDims { original, dims, .. }) => {
|
||||
self.swapped_dims_output::<BT>(
|
||||
client,
|
||||
device,
|
||||
context,
|
||||
plan,
|
||||
output,
|
||||
tensor_global,
|
||||
original,
|
||||
dims,
|
||||
block_idx,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) {
|
||||
plan.handle_outputs.push(handle.unwrap());
|
||||
plan.global_outputs.push(global.unwrap());
|
||||
}
|
||||
|
||||
for i in 0..plan.blocks.len() {
|
||||
if !plan.blocks[i].reference.is_found() {
|
||||
match self.blocks[i].settings.ref_layout {
|
||||
RefLayoutSetting::SameAsBlock { block_pos } => {
|
||||
plan.blocks[i].reference =
|
||||
plan.blocks[block_pos as usize].reference.clone();
|
||||
}
|
||||
_ => {
|
||||
let new_runtime = Self::select_reference_from_inputs(
|
||||
&self.blocks[i],
|
||||
&mut plan.blocks[i],
|
||||
&plan.handle_inputs,
|
||||
);
|
||||
|
||||
if let Some(shape) = new_runtime {
|
||||
let pos = plan.runtime_layouts.len();
|
||||
let mut shape_global = shape.clone();
|
||||
for (i, s) in shape.iter().enumerate() {
|
||||
shape_global[i] = *context.shapes_relative2global.get(s).unwrap();
|
||||
}
|
||||
|
||||
let strides = strides_dyn_rank(&shape_global);
|
||||
|
||||
plan.blocks[i].reference = ReferenceSelection::Runtime { pos };
|
||||
plan.runtime_layouts.push(RuntimeLayout {
|
||||
shape: shape_global,
|
||||
strides,
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
} else {
|
||||
Self::add_layout_info_inputs(&mut plan.blocks[i], &plan.handle_inputs);
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure dropped are correctly executed.
|
||||
for id in self.resources.dropped.iter() {
|
||||
if let Some(tensor_global) = context.tensors.get(id) {
|
||||
context.handles.remove_handle(tensor_global.id);
|
||||
}
|
||||
}
|
||||
for id in plan.cleared.drain(..) {
|
||||
context.handles.remove_handle(id);
|
||||
}
|
||||
}
|
||||
|
||||
fn select_reference_from_inputs(
|
||||
block: &FuseBlock,
|
||||
block_plan: &mut BlockPlan<'_>,
|
||||
handle_inputs: &[HandleInput<R>],
|
||||
) -> Option<Shape> {
|
||||
if let Some(input_ref) = block_plan.potential_reference_input.take() {
|
||||
match input_ref {
|
||||
InputReference::Normal { input_pos } => {
|
||||
let reference = handle_inputs
|
||||
.get(input_pos)
|
||||
.unwrap()
|
||||
.as_normal()
|
||||
.expect("Quant can't be used as inplace");
|
||||
|
||||
let set_ref_as_concrete = |block: &mut BlockPlan<'_>| {
|
||||
block.reference = ReferenceSelection::Concrete {
|
||||
layout: FuseArg::Input(
|
||||
input_pos,
|
||||
reference.precision,
|
||||
LayoutInfo::IsRef,
|
||||
),
|
||||
shape: reference.global_ir.shape.clone(),
|
||||
strides: reference.handle.strides.clone(),
|
||||
};
|
||||
};
|
||||
|
||||
let set_ref_as_virtual = |block: &mut BlockPlan<'_>| {
|
||||
block.reference = ReferenceSelection::VirtualShape {
|
||||
original: FuseArg::Input(
|
||||
input_pos,
|
||||
reference.precision,
|
||||
LayoutInfo::Unknown,
|
||||
),
|
||||
shape: reference.global_ir.shape.clone(),
|
||||
strides: contiguous_strides(&reference.global_ir.shape),
|
||||
};
|
||||
};
|
||||
|
||||
match block.settings.ref_layout {
|
||||
RefLayoutSetting::Any => set_ref_as_concrete(block_plan),
|
||||
RefLayoutSetting::SameAsBlock { .. } => {
|
||||
// Skip set ref.
|
||||
}
|
||||
RefLayoutSetting::OnlyContiguous => {
|
||||
if is_contiguous(&reference.global_ir.shape, &reference.handle.strides)
|
||||
{
|
||||
set_ref_as_concrete(block_plan)
|
||||
} else {
|
||||
set_ref_as_virtual(block_plan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self::add_layout_info_inputs(block_plan, handle_inputs);
|
||||
}
|
||||
InputReference::SwapDims { original_pos, dims } => {
|
||||
let reference = handle_inputs
|
||||
.get(original_pos)
|
||||
.unwrap()
|
||||
.as_normal()
|
||||
.expect("Quant can't be used in swap dims operation");
|
||||
block_plan.reference = ReferenceSelection::SwapDims {
|
||||
original: FuseArg::Input(
|
||||
original_pos,
|
||||
reference.precision,
|
||||
LayoutInfo::Unknown,
|
||||
),
|
||||
dims,
|
||||
};
|
||||
}
|
||||
InputReference::Reshaped { reshape_pos } => {
|
||||
block_plan.reference = ReferenceSelection::Reshaped { reshape_pos };
|
||||
}
|
||||
};
|
||||
None
|
||||
} else {
|
||||
Some(block.shape_ref.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn add_layout_info_inputs(block: &mut BlockPlan<'_>, handle_inputs: &[HandleInput<R>]) {
|
||||
for hi in handle_inputs.iter().filter_map(|h| match h {
|
||||
HandleInput::Normal(input) => Some(input),
|
||||
_ => None,
|
||||
}) {
|
||||
let (strides, shape) = match &block.reference {
|
||||
ReferenceSelection::Concrete { strides, shape, .. }
|
||||
| ReferenceSelection::VirtualShape { strides, shape, .. } => (strides, shape),
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if strides == &hi.handle.strides
|
||||
&& shape == &hi.global_ir.shape
|
||||
&& let Some(ops) = block.reads.get_mut(&hi.relative_id)
|
||||
{
|
||||
for op in ops.iter_mut() {
|
||||
if let FuseOp::Assign(op) = op {
|
||||
op.input.add_layout_info(LayoutInfo::SameAsRef);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_kind(
|
||||
&self,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
tensor_global: &TensorIr,
|
||||
output: &OutputSorted,
|
||||
strides: &[usize],
|
||||
) -> (OutputKind, usize) {
|
||||
let mut block_idx = None;
|
||||
for (i, block) in plan.blocks.iter().enumerate() {
|
||||
if block.writes.contains_key(&output.tensor_relative.id) {
|
||||
block_idx = Some(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let block_idx = block_idx.unwrap();
|
||||
|
||||
if let Some(transform) = self.resources.views.iter().find(|v| match v {
|
||||
TensorView::Reshape { reshaped, .. } => reshaped == &output.tensor_relative.id,
|
||||
TensorView::SwapDims { swapped, .. } => swapped == &output.tensor_relative.id,
|
||||
}) {
|
||||
return (OutputKind::Transform(transform.clone()), block_idx);
|
||||
}
|
||||
|
||||
let block = &plan.blocks[block_idx];
|
||||
let kind = block
|
||||
.potential_inplaces
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_pos, pi)| {
|
||||
pi.tensor_relative.dtype == tensor_global.dtype
|
||||
&& pi.tensor_relative.shape == output.tensor_relative.shape
|
||||
&& &*pi.strides == strides
|
||||
&& block.reference.compatible_strides_for_inplace(strides)
|
||||
})
|
||||
.map(|(pos, _)| OutputKind::Inplace { input_pos: pos })
|
||||
.unwrap_or(OutputKind::Normal);
|
||||
|
||||
(kind, block_idx)
|
||||
}
|
||||
|
||||
fn inplace_output(
|
||||
&mut self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
output: OutputSorted,
|
||||
tensor_global: TensorIr,
|
||||
input_index: usize,
|
||||
block_idx: usize,
|
||||
) {
|
||||
let block = &mut plan.blocks[block_idx];
|
||||
let potential_inplace = block.potential_inplaces.remove(input_index);
|
||||
let handle_input = match plan.handle_inputs.get(potential_inplace.input_pos).unwrap() {
|
||||
HandleInput::Normal(handle) => handle,
|
||||
_ => {
|
||||
unreachable!("Quant tensor handle can't be used inplace yet.")
|
||||
}
|
||||
};
|
||||
|
||||
if !block.reference.is_found()
|
||||
&& !matches!(
|
||||
self.blocks[block_idx].settings.ref_layout,
|
||||
RefLayoutSetting::SameAsBlock { .. }
|
||||
)
|
||||
{
|
||||
let index_input = self
|
||||
.resources
|
||||
.inputs
|
||||
.get_index(potential_inplace.tensor_relative.id)
|
||||
.unwrap();
|
||||
|
||||
block.reference = ReferenceSelection::Concrete {
|
||||
layout: FuseArg::Input(index_input, output.precision, LayoutInfo::IsRef),
|
||||
shape: tensor_global.shape.clone(),
|
||||
strides: handle_input.handle.strides.clone(),
|
||||
};
|
||||
|
||||
if let Some(ops) = block.reads.get_mut(&handle_input.relative_id) {
|
||||
for op in ops.iter_mut() {
|
||||
if let FuseOp::Assign(op) = op {
|
||||
op.input.add_layout_info(LayoutInfo::IsRef);
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {
|
||||
for op in ops {
|
||||
if let FuseOp::Assign(op) = op {
|
||||
op.out.add_layout_info(LayoutInfo::IsRef);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
} else {
|
||||
// Already validated, necessary for correctness.
|
||||
if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {
|
||||
for op in ops {
|
||||
if let FuseOp::Assign(op) = op {
|
||||
op.out.add_layout_info(LayoutInfo::SameAsRef);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
context
|
||||
.handles
|
||||
.register_handle(tensor_global.id, handle_input.handle.clone());
|
||||
|
||||
self.handles[output.pos_original] = Some(HandleOutput::Alias {
|
||||
input_pos: potential_inplace.input_pos,
|
||||
precision: output.precision,
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
debug_info: super::HandleOutputAliasDebugInfo {
|
||||
relative_id: output.tensor_relative.id,
|
||||
handle: handle_input.handle.clone(),
|
||||
global_shape: tensor_global.shape.dims.clone(),
|
||||
},
|
||||
});
|
||||
self.globals[output.pos_original] = Some(tensor_global);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn normal_output<BT: CubeElement>(
|
||||
&mut self,
|
||||
client: &ComputeClient<R>,
|
||||
device: &R::Device,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
output: OutputSorted,
|
||||
tensor_global: TensorIr,
|
||||
strides: Strides,
|
||||
block_idx: usize,
|
||||
) {
|
||||
let block = &mut plan.blocks[block_idx];
|
||||
|
||||
if !block.reference.is_found()
|
||||
&& self.blocks[block_idx].shape_ref == output.tensor_relative.shape
|
||||
&& !matches!(
|
||||
self.blocks[block_idx].settings.ref_layout,
|
||||
RefLayoutSetting::SameAsBlock { .. }
|
||||
)
|
||||
{
|
||||
block.reference = ReferenceSelection::Concrete {
|
||||
layout: FuseArg::Output(output.pos_original, output.precision, LayoutInfo::IsRef),
|
||||
shape: tensor_global.shape.clone(),
|
||||
strides: strides.clone(),
|
||||
};
|
||||
|
||||
// Sometimes outputs that are manually handled don't have any write registered.
|
||||
if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {
|
||||
for op in ops {
|
||||
if let FuseOp::Assign(op) = op {
|
||||
op.out.add_layout_info(LayoutInfo::IsRef);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
} else if let ReferenceSelection::Concrete {
|
||||
shape: ref_shape,
|
||||
strides: ref_strides,
|
||||
..
|
||||
} = &block.reference
|
||||
&& ref_strides == &strides
|
||||
&& ref_shape == &tensor_global.shape
|
||||
&& let Some(ops) = block.writes.get_mut(&output.tensor_relative.id)
|
||||
{
|
||||
for op in ops {
|
||||
if let FuseOp::Assign(op) = op {
|
||||
op.out.add_layout_info(LayoutInfo::SameAsRef);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// We encode bool tensors as `B`.
|
||||
let dtype = match tensor_global.dtype {
|
||||
DType::Bool => elem_dtype::<BT>(),
|
||||
_ => tensor_global.dtype,
|
||||
};
|
||||
let size = tensor_global.shape.iter().product::<usize>() * StorageType::from(dtype).size();
|
||||
|
||||
let handle = CubeFusionHandle {
|
||||
client: client.clone(),
|
||||
handle: client.empty(size),
|
||||
device: device.clone(),
|
||||
strides,
|
||||
dtype,
|
||||
qparams: None,
|
||||
};
|
||||
|
||||
plan.rank = usize::max(tensor_global.shape.rank(), plan.rank);
|
||||
context
|
||||
.handles
|
||||
.register_handle(tensor_global.id, handle.clone());
|
||||
|
||||
self.handles[output.pos_original] = Some(HandleOutput::Owned {
|
||||
precision: output.precision,
|
||||
handle,
|
||||
global_shape: tensor_global.shape.clone(),
|
||||
global_id: tensor_global.id,
|
||||
relative_id: output.tensor_relative.id,
|
||||
vectorization: 1,
|
||||
});
|
||||
self.globals[output.pos_original] = Some(tensor_global);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn reshaped_output<BT: CubeElement>(
|
||||
&mut self,
|
||||
client: &ComputeClient<R>,
|
||||
device: &R::Device,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
output: OutputSorted,
|
||||
tensor_global: TensorIr,
|
||||
strides: Strides,
|
||||
original: TensorId,
|
||||
block_idx: usize,
|
||||
) {
|
||||
let block = &mut plan.blocks[block_idx];
|
||||
|
||||
let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original);
|
||||
|
||||
// We encode bool tensors as `B`.
|
||||
let dtype = match tensor_global.dtype {
|
||||
DType::Bool => elem_dtype::<BT>(),
|
||||
_ => tensor_global.dtype,
|
||||
};
|
||||
|
||||
let action = reshape_action(
|
||||
&original_handle.global_ir.shape,
|
||||
&original_handle.handle.strides,
|
||||
&tensor_global.shape,
|
||||
);
|
||||
|
||||
let update = match action {
|
||||
ReshapeAction::UpdateStrides { strides } => Some(strides),
|
||||
ReshapeAction::NoChange => Some(original_handle.handle.strides.clone()),
|
||||
ReshapeAction::Recompute => None,
|
||||
};
|
||||
|
||||
match update {
|
||||
Some(strides) => {
|
||||
// We modify the metadata instead.
|
||||
remove_concrete_write(block, output.tensor_relative.id, output.pos_original);
|
||||
|
||||
let handle = CubeFusionHandle {
|
||||
client: client.clone(),
|
||||
handle: original_handle.handle.handle.clone(),
|
||||
device: device.clone(),
|
||||
strides,
|
||||
dtype,
|
||||
qparams: original_handle.handle.qparams.clone(),
|
||||
};
|
||||
context
|
||||
.handles
|
||||
.register_handle(tensor_global.id, handle.clone());
|
||||
|
||||
// IT will never be access, just a way to keep the original position working.
|
||||
self.handles[output.pos_original] = Some(HandleOutput::Alias {
|
||||
input_pos: pos_input,
|
||||
precision: output.precision,
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
debug_info: super::HandleOutputAliasDebugInfo {
|
||||
relative_id: output.tensor_relative.id,
|
||||
handle: handle.clone(),
|
||||
global_shape: tensor_global.shape.dims.clone(),
|
||||
},
|
||||
});
|
||||
self.globals[output.pos_original] = Some(tensor_global);
|
||||
}
|
||||
None => {
|
||||
self.normal_output::<BT>(
|
||||
client,
|
||||
device,
|
||||
context,
|
||||
plan,
|
||||
output,
|
||||
tensor_global,
|
||||
strides,
|
||||
block_idx,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn swapped_dims_output<BT: CubeElement>(
|
||||
&mut self,
|
||||
client: &ComputeClient<R>,
|
||||
device: &R::Device,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
output: OutputSorted,
|
||||
tensor_global: TensorIr,
|
||||
original: TensorId,
|
||||
dims: (usize, usize),
|
||||
block_idx: usize,
|
||||
) {
|
||||
let block = &mut plan.blocks[block_idx];
|
||||
let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original);
|
||||
|
||||
// We encode bool tensors as `B`.
|
||||
let dtype = match tensor_global.dtype {
|
||||
DType::Bool => elem_dtype::<BT>(),
|
||||
_ => tensor_global.dtype,
|
||||
};
|
||||
|
||||
// TODO: Check if we can also remove the read, if we have a dead partial graph.
|
||||
//
|
||||
// We modify the metadata instead.
|
||||
remove_concrete_write(block, output.tensor_relative.id, output.pos_original);
|
||||
|
||||
let strides = original_handle.handle.strides.clone();
|
||||
|
||||
let mut handle = CubeFusionHandle {
|
||||
client: client.clone(),
|
||||
handle: original_handle.handle.handle.clone(),
|
||||
device: device.clone(),
|
||||
strides,
|
||||
dtype,
|
||||
qparams: original_handle.handle.qparams.clone(),
|
||||
};
|
||||
handle.strides.swap(dims.0, dims.1);
|
||||
|
||||
context
|
||||
.handles
|
||||
.register_handle(tensor_global.id, handle.clone());
|
||||
|
||||
// IT will never be access, just a way to keep the original position working.
|
||||
self.handles[output.pos_original] = Some(HandleOutput::Alias {
|
||||
input_pos: pos_input,
|
||||
precision: output.precision,
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
debug_info: super::HandleOutputAliasDebugInfo {
|
||||
relative_id: output.tensor_relative.id,
|
||||
handle: handle.clone(),
|
||||
global_shape: tensor_global.shape.dims.clone(),
|
||||
},
|
||||
});
|
||||
self.globals[output.pos_original] = Some(tensor_global);
|
||||
}
|
||||
|
||||
fn find_child_input(
|
||||
handle_inputs: &[HandleInput<R>],
|
||||
original: TensorId,
|
||||
) -> (usize, &NormalHandleInput<R>) {
|
||||
handle_inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find_map(|(pi, handle)| match handle {
|
||||
HandleInput::Normal(handle) => match handle.relative_id == original {
|
||||
true => Some((pi, handle)),
|
||||
false => None,
|
||||
},
|
||||
_ => None, // Quant tensor can't be reshaped.
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_concrete_write(block: &mut BlockPlan, id: TensorId, output_pos: usize) {
|
||||
let ops = block.writes.remove(&id);
|
||||
|
||||
if let Some(ops) = ops {
|
||||
let mut keep = Vec::with_capacity(ops.len());
|
||||
|
||||
for op in ops {
|
||||
if let FuseOp::Assign(args) = &op {
|
||||
if let FuseArg::Output(pos, ..) = args.out {
|
||||
if pos != output_pos {
|
||||
keep.push(op);
|
||||
}
|
||||
} else {
|
||||
keep.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
block.writes.insert(id, keep);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::{
|
||||
codegen::ir::{FuseArg, FuseOp, FuseType},
|
||||
launch::vectorization::Vect,
|
||||
trace::{RuntimeLayout, block::FuseBlock},
|
||||
},
|
||||
};
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use burn_std::{Shape, Strides};
|
||||
use cubecl::{Runtime, ir::LineSize};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// The `LaunchPlan` is responsible for aggregating all runtime information required
|
||||
/// to dispatch a fused kernel.
|
||||
///
|
||||
/// It maps abstract IR tensors to memory handles, manages vectorization
|
||||
/// strategies, and tracks layout transformations.
|
||||
#[derive(Debug)]
|
||||
pub struct LaunchPlan<'a, R: Runtime> {
|
||||
/// The IR representation of tensors that are results of the fusion.
|
||||
pub global_outputs: Vec<TensorIr>,
|
||||
/// Memory handles and metadata for all input tensors.
|
||||
pub handle_inputs: Vec<HandleInput<R>>,
|
||||
/// Memory handles and metadata for all output tensors, including aliased inputs.
|
||||
pub handle_outputs: Vec<HandleOutput<R>>,
|
||||
/// The rank across all tensors in the plan.
|
||||
///
|
||||
/// Smaller tensors are unsqueezed during launch.
|
||||
pub rank: usize,
|
||||
/// Detailed planning for each individual computation block within the fusion.
|
||||
pub blocks: Vec<BlockPlan<'a>>,
|
||||
/// Mapping of tensor IDs to their specific vectorization factors.
|
||||
pub vectorizations: BTreeMap<TensorId, Vect>,
|
||||
/// Tensors that can be cleared or deallocated after this plan executes.
|
||||
pub cleared: Vec<TensorId>,
|
||||
/// Metadata for shapes and strides passed from the host when they cannot be
|
||||
/// inferred from input tensors (e.g., complex deep fusions).
|
||||
pub runtime_layouts: Vec<RuntimeLayout>,
|
||||
}
|
||||
|
||||
/// Information regarding the execution of a specific block of operations within a fusion.
|
||||
#[derive(Debug)]
|
||||
pub struct BlockPlan<'a> {
|
||||
/// List of inputs that are candidates for in-place memory reuse within this block.
|
||||
pub potential_inplaces: Vec<PotentialInplace<'a>>,
|
||||
/// The input tensor chosen to define the iteration space, if any.
|
||||
pub potential_reference_input: Option<InputReference>,
|
||||
/// How the master layout is determined for this block.
|
||||
pub reference: ReferenceSelection,
|
||||
/// Mapping of tensor IDs to the read operations performed on them.
|
||||
pub reads: BTreeMap<TensorId, Vec<FuseOp>>,
|
||||
/// Mapping of tensor IDs to the write operations performed on them.
|
||||
pub writes: BTreeMap<TensorId, Vec<FuseOp>>,
|
||||
/// The width for the operations in this block.
|
||||
pub width: LineSize,
|
||||
}
|
||||
|
||||
/// Describes which input is used as the reference for a block's layout.
#[derive(Debug)]
pub enum InputReference {
    /// A regular input, identified by its position.
    Normal { input_pos: usize },
    /// An input on which two dimensions are swapped.
    SwapDims {
        /// Position of the input before the swap.
        original_pos: usize,
        /// The pair of dimensions being swapped.
        dims: (usize, usize),
    },
    /// An input that has been reshaped.
    Reshaped { reshape_pos: usize },
}
|
||||
|
||||
/// Strategies for selecting the reference layout of a fused block.
|
||||
///
|
||||
/// The reference layout determines how global indices are mapped to tensor coordinates.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ReferenceSelection {
|
||||
/// The engine is still calculating the optimal reference.
|
||||
Searching,
|
||||
/// Layout from a normal tensor.
|
||||
Concrete {
|
||||
layout: FuseArg,
|
||||
shape: Shape,
|
||||
strides: Strides,
|
||||
},
|
||||
/// Layout from a swapped dim tensor.
|
||||
SwapDims {
|
||||
original: FuseArg,
|
||||
dims: (usize, usize),
|
||||
},
|
||||
/// Layout from a reshaped tensor.
|
||||
Reshaped { reshape_pos: usize },
|
||||
/// Layout that has the shape of an input, but not its strides.
|
||||
VirtualShape {
|
||||
original: FuseArg,
|
||||
shape: Shape,
|
||||
strides: Strides,
|
||||
},
|
||||
/// The layout is provided dynamically by the host at runtime.
|
||||
Runtime { pos: usize },
|
||||
}
|
||||
|
||||
impl<R: Runtime> LaunchPlan<'_, R> {
|
||||
/// Creates a new `LaunchPlan` from a slice of fusion blocks.
|
||||
///
|
||||
/// Initializes blocks with default "Searching" references and calculates
|
||||
/// the initial max rank.
|
||||
pub fn new(fuse_blocks: &[FuseBlock]) -> Self {
|
||||
let mut rank = 0;
|
||||
let mut blocks = Vec::with_capacity(fuse_blocks.len());
|
||||
|
||||
for b in fuse_blocks.iter() {
|
||||
rank = usize::max(b.shape_ref.len(), rank);
|
||||
let block = BlockPlan {
|
||||
reference: ReferenceSelection::Searching,
|
||||
reads: b.reads.clone(),
|
||||
writes: b.writes.clone(),
|
||||
width: 0,
|
||||
potential_inplaces: Vec::new(),
|
||||
potential_reference_input: None,
|
||||
};
|
||||
blocks.push(block);
|
||||
}
|
||||
|
||||
LaunchPlan {
|
||||
global_outputs: Vec::new(),
|
||||
handle_inputs: Vec::new(),
|
||||
handle_outputs: Vec::new(),
|
||||
rank,
|
||||
blocks,
|
||||
vectorizations: Default::default(),
|
||||
cleared: Default::default(),
|
||||
runtime_layouts: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra information attached to aliased outputs, only compiled when the
/// `autotune-checks` feature is enabled.
#[cfg(feature = "autotune-checks")]
#[derive(Debug)]
pub struct HandleOutputAliasDebugInfo<R: Runtime> {
    /// Handle registered for the aliased output.
    pub handle: CubeFusionHandle<R>,
    /// Relative id of the output tensor.
    pub relative_id: TensorId,
    /// Global shape of the output tensor.
    pub global_shape: Shape,
}
|
||||
|
||||
/// Represents the output of a fused kernel execution.
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum HandleOutput<R: Runtime> {
|
||||
/// An output that reuses the memory of an input tensor (In-place).
|
||||
Alias {
|
||||
/// Index of the input handle being aliased.
|
||||
input_pos: usize,
|
||||
/// Data type precision.
|
||||
precision: FuseType,
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
debug_info: HandleOutputAliasDebugInfo<R>,
|
||||
},
|
||||
/// An output that requires a newly allocated memory buffer.
|
||||
Owned {
|
||||
global_id: TensorId,
|
||||
relative_id: TensorId,
|
||||
precision: FuseType,
|
||||
handle: CubeFusionHandle<R>,
|
||||
global_shape: Shape,
|
||||
vectorization: LineSize,
|
||||
},
|
||||
}
|
||||
|
||||
/// A standard input handle with associated layout and vectorization metadata.
|
||||
#[derive(Debug)]
|
||||
pub struct NormalHandleInput<R: Runtime> {
|
||||
pub relative_id: TensorId,
|
||||
pub global_ir: TensorIr,
|
||||
pub precision: FuseType,
|
||||
pub handle: CubeFusionHandle<R>,
|
||||
pub line_size: LineSize,
|
||||
pub broadcated: bool,
|
||||
/// Stores the original strides of the handle for restoration during plan rollback.
|
||||
pub orig_strides: Strides,
|
||||
}
|
||||
|
||||
/// An input handle containing values for a quantized tensor.
|
||||
#[derive(Debug)]
|
||||
pub struct QuantValuesHandleInput<R: Runtime> {
|
||||
pub relative_id: TensorId,
|
||||
pub global_ir: TensorIr,
|
||||
pub precision: FuseType,
|
||||
pub handle: CubeFusionHandle<R>,
|
||||
pub line_size: LineSize,
|
||||
}
|
||||
|
||||
/// An input handle containing parameters (scales/offsets) for quantization.
|
||||
#[derive(Debug)]
|
||||
pub struct QuantParamsHandleInput<R: Runtime> {
|
||||
pub precision: FuseType,
|
||||
pub handle: CubeFusionHandle<R>,
|
||||
pub shape: Shape,
|
||||
}
|
||||
|
||||
/// Different types of inputs that can be passed to a fused kernel.
|
||||
#[derive(Debug)]
|
||||
pub enum HandleInput<R: Runtime> {
|
||||
Normal(NormalHandleInput<R>),
|
||||
QuantValues(QuantValuesHandleInput<R>),
|
||||
QuantParams(QuantParamsHandleInput<R>),
|
||||
}
|
||||
|
||||
impl<R: Runtime> HandleInput<R> {
|
||||
/// Returns a reference to the inner `NormalHandleInput` if the variant matches.
|
||||
pub fn as_normal(&self) -> Option<&NormalHandleInput<R>> {
|
||||
match self {
|
||||
HandleInput::Normal(normal) => Some(normal),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> NormalHandleInput<R> {
|
||||
/// Creates a new `NormalHandleInput` tracking original strides.
|
||||
pub fn new(
|
||||
tensor_global: TensorIr,
|
||||
tensor_relative: &TensorIr,
|
||||
precision: FuseType,
|
||||
mut handle: CubeFusionHandle<R>,
|
||||
mut strides: Strides,
|
||||
) -> Self {
|
||||
// Swap current handle strides with provided strides to track the original state for rollback.
|
||||
core::mem::swap(&mut handle.strides, &mut strides);
|
||||
Self {
|
||||
precision,
|
||||
handle,
|
||||
relative_id: tensor_relative.id,
|
||||
global_ir: tensor_global,
|
||||
line_size: 1,
|
||||
broadcated: false,
|
||||
orig_strides: strides,
|
||||
}
|
||||
}
|
||||
|
||||
/// Restores the handle's original strides and returns the handle.
|
||||
///
|
||||
/// Used when a plan is invalidated or needs to be rolled back.
|
||||
pub fn handle_rollback(mut self) -> CubeFusionHandle<R> {
|
||||
core::mem::swap(&mut self.handle.strides, &mut self.orig_strides);
|
||||
self.handle
|
||||
}
|
||||
}
|
||||
|
||||
/// A candidate for in-place optimization.
|
||||
#[derive(Debug)]
|
||||
pub struct PotentialInplace<'a> {
|
||||
/// Position of the input handle in the `handle_inputs` vector.
|
||||
pub input_pos: usize,
|
||||
/// Reference to the IR of the relative tensor.
|
||||
pub tensor_relative: &'a TensorIr,
|
||||
/// Current strides of the potential in-place candidate.
|
||||
pub strides: Strides,
|
||||
}
|
||||
|
||||
impl ReferenceSelection {
|
||||
pub fn is_found(&self) -> bool {
|
||||
!matches!(self, Self::Searching)
|
||||
}
|
||||
|
||||
pub fn compatible_strides_for_inplace(&self, strides_inplace: &[usize]) -> bool {
|
||||
match self {
|
||||
ReferenceSelection::Concrete { strides, .. } => &**strides == strides_inplace,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
use super::super::codegen::ir::{FuseBlockConfig, GlobalArgsLaunch};
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::launch::{
|
||||
LaunchPlan,
|
||||
vectorization::{Vect, vectorization_default},
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use cubecl::prelude::*;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
|
||||
/// A trace runner is responsible for determining the vectorization factor as well as launching
|
||||
/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch)
|
||||
/// with provided [fuse block configs](FuseBlockConfig).
|
||||
pub trait TraceRunner<R: Runtime>: Vectorization<R> {
|
||||
/// The error that might happen while running the trace.
|
||||
type Error;
|
||||
|
||||
/// Run the trace with the given inputs and outputs.
|
||||
///
|
||||
/// There is one [fuse config](FuseBlockConfig) for each [block](super::block::FuseBlock) registered
|
||||
/// in the [optimization builder](burn_fusion::OptimizationBuilder).
|
||||
fn run<'a>(
|
||||
&'a self,
|
||||
client: &'a ComputeClient<R>,
|
||||
inputs: GlobalArgsLaunch<'a, R>,
|
||||
outputs: GlobalArgsLaunch<'a, R>,
|
||||
configs: &'a [FuseBlockConfig],
|
||||
) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub enum VectorizationHandle<'a, R: Runtime> {
|
||||
NormalInput(&'a CubeFusionHandle<R>, &'a TensorIr),
|
||||
QuantValues(&'a CubeFusionHandle<R>, &'a TensorIr),
|
||||
QuantParams,
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> VectorizationHandle<'a, R> {
|
||||
/// Returns if the current vectorization handle is from the given tensor id.
|
||||
pub fn is_from_tensor(&self, id: TensorId) -> bool {
|
||||
match self {
|
||||
VectorizationHandle::NormalInput(_, tensor_ir) => tensor_ir.id == id,
|
||||
VectorizationHandle::QuantValues(_, tensor_ir) => tensor_ir.id == id,
|
||||
VectorizationHandle::QuantParams => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct VectorizationAxis {
|
||||
axis: HashMap<TensorId, usize>,
|
||||
}
|
||||
|
||||
impl VectorizationAxis {
|
||||
pub fn get<F: FnOnce() -> usize>(&self, id: TensorId, default: F) -> usize {
|
||||
self.axis.get(&id).copied().unwrap_or_else(default)
|
||||
}
|
||||
pub fn insert(&mut self, id: TensorId, axis: usize) {
|
||||
self.axis.insert(id, axis);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Vectorization<R: Runtime> {
|
||||
/// Returns the vectorization options.
|
||||
fn axis(&self, _plan: &LaunchPlan<'_, R>) -> VectorizationAxis {
|
||||
VectorizationAxis::default()
|
||||
}
|
||||
/// The vectorization factor for all inputs and outputs.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vectorization<'a>(
|
||||
&self,
|
||||
_context: &Context<'_, CubeFusionHandle<R>>,
|
||||
vectorizations: &mut BTreeMap<TensorId, Vect>,
|
||||
inputs: impl Iterator<Item = VectorizationHandle<'a, R>>,
|
||||
outputs: impl Iterator<Item = &'a TensorIr>,
|
||||
reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,
|
||||
swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,
|
||||
line_sizes: &[LineSize],
|
||||
max: LineSize,
|
||||
axis: VectorizationAxis,
|
||||
) {
|
||||
vectorization_default(
|
||||
vectorizations,
|
||||
inputs,
|
||||
outputs,
|
||||
reshaped,
|
||||
swapped,
|
||||
line_sizes,
|
||||
&Default::default(),
|
||||
max,
|
||||
&axis,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,439 @@
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::launch::runner::{VectorizationAxis, VectorizationHandle},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use cubecl::{Runtime, ir::LineSize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vect {
|
||||
Broadcasted,
|
||||
Aligned(LineSize),
|
||||
}
|
||||
|
||||
impl Vect {
|
||||
pub fn line_size(&self) -> LineSize {
|
||||
match self {
|
||||
Vect::Broadcasted => 1,
|
||||
Vect::Aligned(val) => *val,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_broadcast(&self) -> bool {
|
||||
matches!(self, Vect::Broadcasted)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct LineSizeOverrides {
|
||||
state: Option<BTreeMap<TensorId, Vec<LineSize>>>,
|
||||
default: Option<Vec<LineSize>>,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl LineSizeOverrides {
|
||||
pub fn overrides(&mut self, tensor_id: &TensorId, line_sizes: Vec<LineSize>) {
|
||||
let map = match &mut self.state {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
self.state = Some(BTreeMap::new());
|
||||
self.state.as_mut().unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
map.insert(*tensor_id, line_sizes);
|
||||
}
|
||||
pub fn overrides_default(&mut self, line_sizes: Vec<LineSize>) {
|
||||
self.default = Some(line_sizes);
|
||||
}
|
||||
|
||||
pub fn mapping<R: Runtime>(&self, context: &Context<'_, CubeFusionHandle<R>>) -> Self {
|
||||
match &self.state {
|
||||
Some(state) => {
|
||||
let mut state_new = BTreeMap::new();
|
||||
|
||||
for (k, v) in state.iter() {
|
||||
let global = context.tensors.get(k).unwrap();
|
||||
state_new.insert(global.id, v.clone());
|
||||
}
|
||||
|
||||
Self {
|
||||
state: Some(state_new),
|
||||
default: self.default.clone(),
|
||||
}
|
||||
}
|
||||
None => Self {
|
||||
state: None,
|
||||
default: self.default.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec<LineSize>> {
|
||||
let map = match &self.state {
|
||||
Some(val) => val,
|
||||
None => match &self.default {
|
||||
Some(val) => return Some(val),
|
||||
None => return None,
|
||||
},
|
||||
};
|
||||
|
||||
match map.get(tensor_id) {
|
||||
Some(val) => Some(val),
|
||||
None => match &self.default {
|
||||
Some(val) => Some(val),
|
||||
None => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn vectorization_default<'a, R: Runtime>(
|
||||
vectorizations: &mut BTreeMap<TensorId, Vect>,
|
||||
inputs: impl Iterator<Item = VectorizationHandle<'a, R>>,
|
||||
outputs: impl Iterator<Item = &'a TensorIr>,
|
||||
reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,
|
||||
swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,
|
||||
line_sizes: &[LineSize],
|
||||
overrides: &LineSizeOverrides,
|
||||
max: LineSize,
|
||||
axis: &VectorizationAxis,
|
||||
) {
|
||||
let swapped: Vec<_> = swapped.collect();
|
||||
|
||||
for input in inputs {
|
||||
if let Some((s, o, mr, dims)) = swapped
|
||||
.iter()
|
||||
.find(|(_s, o, _mr, _dims)| input.is_from_tensor(o.id))
|
||||
{
|
||||
let (handle, id) = match input {
|
||||
VectorizationHandle::NormalInput(handle, tensor_ir) => (handle, &tensor_ir.id),
|
||||
VectorizationHandle::QuantValues(..) => panic!("Can't be swapped"),
|
||||
VectorizationHandle::QuantParams => panic!("Can't be swapped"),
|
||||
};
|
||||
let val = vectorization_swapped(
|
||||
handle,
|
||||
s,
|
||||
o,
|
||||
*mr,
|
||||
dims,
|
||||
max,
|
||||
axis,
|
||||
line_sizes,
|
||||
overrides.tensor(id),
|
||||
);
|
||||
multi_reads_vectorization_update(vectorizations, o.id, val);
|
||||
} else {
|
||||
match input {
|
||||
VectorizationHandle::NormalInput(handle, tensor_ir) => {
|
||||
let val = vectorization_input(
|
||||
handle,
|
||||
tensor_ir,
|
||||
axis,
|
||||
line_sizes,
|
||||
overrides.tensor(&tensor_ir.id),
|
||||
);
|
||||
vectorizations.insert(tensor_ir.id, val);
|
||||
}
|
||||
VectorizationHandle::QuantValues(handle, tensor_ir) => {
|
||||
let val = vectorization_input(
|
||||
handle,
|
||||
tensor_ir,
|
||||
axis,
|
||||
line_sizes,
|
||||
overrides.tensor(&tensor_ir.id),
|
||||
);
|
||||
let num_quants = match tensor_ir.dtype {
|
||||
burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants(),
|
||||
_ => panic!(""),
|
||||
};
|
||||
let val = match val {
|
||||
Vect::Broadcasted => Vect::Aligned(1),
|
||||
Vect::Aligned(val) => Vect::Aligned(val.div_ceil(num_quants)),
|
||||
};
|
||||
vectorizations.insert(tensor_ir.id, val);
|
||||
}
|
||||
VectorizationHandle::QuantParams => {
|
||||
// Doesn't have vectorization for now.
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
for (reshaped, original, multi_reads) in reshaped {
|
||||
let val = vectorization_reshape(
|
||||
reshaped,
|
||||
original,
|
||||
multi_reads,
|
||||
axis,
|
||||
line_sizes,
|
||||
max,
|
||||
overrides.tensor(&original.id),
|
||||
);
|
||||
multi_reads_vectorization_update(vectorizations, original.id, val);
|
||||
}
|
||||
|
||||
for tensor in outputs {
|
||||
let val = vectorization_output(tensor, axis, line_sizes, max, overrides.tensor(&tensor.id));
|
||||
vectorizations.insert(tensor.id, val);
|
||||
}
|
||||
}
|
||||
|
||||
fn multi_reads_vectorization_update(
|
||||
vectorizations: &mut BTreeMap<TensorId, Vect>,
|
||||
original: TensorId,
|
||||
vect: Vect,
|
||||
) {
|
||||
if let Some(ori_vect) = vectorizations.get(&original).cloned() {
|
||||
match ori_vect {
|
||||
Vect::Broadcasted => {
|
||||
// keep the original as is.
|
||||
}
|
||||
Vect::Aligned(ori) => match vect {
|
||||
Vect::Broadcasted => {
|
||||
vectorizations.insert(original, Vect::Aligned(1));
|
||||
}
|
||||
Vect::Aligned(new) => {
|
||||
let val = if new != ori { 1 } else { new };
|
||||
vectorizations.insert(original, Vect::Aligned(val));
|
||||
}
|
||||
},
|
||||
};
|
||||
} else {
|
||||
vectorizations.insert(original, vect);
|
||||
}
|
||||
}
|
||||
|
||||
// The default version uses the last dimension as vectorization axis and assumes a
|
||||
// perpendicular contiguous line.
|
||||
fn vectorization_input<R: Runtime>(
|
||||
handle: &CubeFusionHandle<R>,
|
||||
desc: &TensorIr,
|
||||
axis: &VectorizationAxis,
|
||||
line_sizes: &[LineSize],
|
||||
overrides: Option<&Vec<LineSize>>,
|
||||
) -> Vect {
|
||||
let axis = axis.get(desc.id, || handle.strides.len() - 1);
|
||||
let shape_axis = desc.shape[axis];
|
||||
|
||||
if shape_axis == 1 {
|
||||
return Vect::Broadcasted;
|
||||
}
|
||||
|
||||
// Last dimension strides should be 1, otherwise vecX won't be contiguous.
|
||||
if handle.strides[axis] != 1 {
|
||||
return Vect::Aligned(1);
|
||||
}
|
||||
|
||||
let inner = |s: LineSize| {
|
||||
// The last dimension should be a multiple of the vector size or broadcated.
|
||||
if shape_axis.is_multiple_of(s) {
|
||||
return Some(Vect::Aligned(s));
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
match overrides {
|
||||
Some(vals) => {
|
||||
for s in vals {
|
||||
if let Some(val) = inner(*s) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for s in line_sizes {
|
||||
if let Some(val) = inner(*s) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Vect::Aligned(1)
|
||||
}
|
||||
|
||||
fn vectorization_output(
|
||||
desc: &TensorIr,
|
||||
axis: &VectorizationAxis,
|
||||
line_sizes: &[LineSize],
|
||||
max: LineSize,
|
||||
overrides: Option<&Vec<LineSize>>,
|
||||
) -> Vect {
|
||||
let axis = axis.get(desc.id, || desc.shape.rank() - 1);
|
||||
|
||||
let inner = |s: LineSize| {
|
||||
// The dimension should be a multiple of the vector size.
|
||||
if desc.shape[axis].is_multiple_of(s) && s <= max {
|
||||
return Some(Vect::Aligned(s));
|
||||
}
|
||||
|
||||
None
|
||||
};
|
||||
match overrides {
|
||||
Some(val) => {
|
||||
for s in val {
|
||||
if let Some(val) = inner(*s) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for s in line_sizes {
|
||||
if let Some(val) = inner(*s) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Vect::Aligned(1)
|
||||
}
|
||||
|
||||
fn vectorization_reshape(
|
||||
reshaped: &TensorIr,
|
||||
original: &TensorIr,
|
||||
multi_reads: bool,
|
||||
axis: &VectorizationAxis,
|
||||
line_sizes: &[LineSize],
|
||||
max: LineSize,
|
||||
overrides: Option<&Vec<LineSize>>,
|
||||
) -> Vect {
|
||||
let axis = axis.get(reshaped.id, || reshaped.shape.rank() - 1);
|
||||
let reshape_shape_axis = reshaped.shape[axis];
|
||||
|
||||
if !multi_reads && reshape_shape_axis == 1 {
|
||||
return Vect::Broadcasted;
|
||||
}
|
||||
|
||||
// If the axis is not the last dim, didn't think of it, return Aligned(1) to be sure.
|
||||
if axis != reshaped.shape.rank() - 1 {
|
||||
return Vect::Aligned(1);
|
||||
}
|
||||
|
||||
let original_shape_axis = original.shape[original.shape.rank() - 1];
|
||||
|
||||
if original_shape_axis != reshape_shape_axis {
|
||||
return Vect::Aligned(1);
|
||||
}
|
||||
|
||||
let inner = |s: LineSize| {
|
||||
if !multi_reads {
|
||||
// The last dimension should be a multiple of the vector size or broadcated.
|
||||
if reshape_shape_axis.is_multiple_of(s) && s <= max {
|
||||
Some(Vect::Aligned(s))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
// Since the original tensor must share the same vectorization factor as the
|
||||
// reshaped tensor, they must have compatible shapes when both are access
|
||||
// independently.
|
||||
if reshape_shape_axis.is_multiple_of(s)
|
||||
&& original_shape_axis.is_multiple_of(s)
|
||||
&& s <= max
|
||||
{
|
||||
Some(Vect::Aligned(s))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match overrides {
|
||||
Some(val) => {
|
||||
for i in val {
|
||||
if let Some(vect) = inner(*i) {
|
||||
return vect;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for s in line_sizes {
|
||||
if let Some(vect) = inner(*s) {
|
||||
return vect;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Vect::Aligned(1)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vectorization_swapped<R: Runtime>(
|
||||
handle: &CubeFusionHandle<R>,
|
||||
swapped: &TensorIr,
|
||||
original: &TensorIr,
|
||||
multi_reads: bool,
|
||||
dims: &(usize, usize),
|
||||
max: LineSize,
|
||||
axis: &VectorizationAxis,
|
||||
line_sizes: &[LineSize],
|
||||
overrides: Option<&Vec<LineSize>>,
|
||||
) -> Vect {
|
||||
let axis = axis.get(swapped.id, || swapped.shape.rank() - 1);
|
||||
|
||||
let swapped_axis = swapped.shape[axis];
|
||||
let shape_axis = original.shape[axis];
|
||||
|
||||
let axis_index = axis;
|
||||
let dim_index = if dims.0 == axis_index {
|
||||
dims.1
|
||||
} else if dims.1 == axis_index {
|
||||
dims.0
|
||||
} else {
|
||||
axis_index
|
||||
};
|
||||
|
||||
// Last dimension strides should be 1, otherwise vecX won't be contiguous.
|
||||
if multi_reads {
|
||||
if handle.strides[axis_index] != 1 {
|
||||
return Vect::Aligned(1);
|
||||
}
|
||||
if handle.strides[dim_index] != 1 {
|
||||
return Vect::Aligned(1);
|
||||
}
|
||||
} else if handle.strides[dim_index] != 1 {
|
||||
return Vect::Aligned(1);
|
||||
}
|
||||
|
||||
if !multi_reads && swapped_axis == 1 {
|
||||
return Vect::Broadcasted;
|
||||
}
|
||||
|
||||
let inner = |s: LineSize| {
|
||||
// The last dimension should be a multiple of the vector size or broadcated.
|
||||
if multi_reads {
|
||||
if swapped_axis.is_multiple_of(s) && s <= max {
|
||||
return Some(Vect::Aligned(s));
|
||||
}
|
||||
} else if swapped_axis.is_multiple_of(s) && shape_axis.is_multiple_of(s) && s <= max {
|
||||
return Some(Vect::Aligned(s));
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
match overrides {
|
||||
Some(val) => {
|
||||
for s in val {
|
||||
if let Some(val) = inner(*s) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for s in line_sizes {
|
||||
if let Some(val) = inner(*s) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Vect::Aligned(1)
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod planner;
|
||||
|
||||
pub use base::*;
|
||||
pub use planner::*;
|
||||
@@ -0,0 +1,438 @@
|
||||
use super::{
|
||||
super::{BlockPlan, HandleOutput, LaunchPlan},
|
||||
Vect,
|
||||
};
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::{
|
||||
launch::{
|
||||
HandleInput,
|
||||
runner::{Vectorization, VectorizationHandle},
|
||||
},
|
||||
settings::VectorizationSetting,
|
||||
trace::{FuseResources, TensorView, block::FuseBlock},
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::TensorId;
|
||||
use cubecl::{
|
||||
Runtime,
|
||||
client::ComputeClient,
|
||||
ir::{ElemType, StorageType, UIntKind},
|
||||
};
|
||||
use cubecl::{
|
||||
ir::LineSize,
|
||||
quant::scheme::{QuantScheme, QuantStore, QuantValue},
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Select the best vectorization factor for each tensor handle.
|
||||
pub struct VectorizationPlanner<'a, R: Runtime> {
|
||||
resources: &'a FuseResources,
|
||||
blocks: &'a Vec<FuseBlock>,
|
||||
_r: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
|
||||
pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {
|
||||
Self {
|
||||
resources,
|
||||
blocks,
|
||||
_r: PhantomData,
|
||||
}
|
||||
}
|
||||
pub fn run<Runner: Vectorization<R>>(
|
||||
self,
|
||||
client: &ComputeClient<R>,
|
||||
runner: &Runner,
|
||||
context: &Context<'_, CubeFusionHandle<R>>,
|
||||
plan: &mut LaunchPlan<'a, R>,
|
||||
) {
|
||||
let has_multiple_read = |tensor: &TensorId| {
|
||||
let mut read_count = 0;
|
||||
for block in plan.blocks.iter() {
|
||||
read_count += block.reads.get(tensor).map(|a| a.len()).unwrap_or(0);
|
||||
}
|
||||
read_count > 1
|
||||
};
|
||||
let tensors_reshaped = self.resources.views.iter().filter_map(|view| match view {
|
||||
TensorView::Reshape {
|
||||
reshaped, original, ..
|
||||
} => Some((
|
||||
context.tensors.get(reshaped).unwrap(),
|
||||
context.tensors.get(original).unwrap(),
|
||||
has_multiple_read(original),
|
||||
)),
|
||||
TensorView::SwapDims { .. } => None,
|
||||
});
|
||||
let tensors_swapped = self.resources.views.iter().filter_map(|view| match view {
|
||||
TensorView::SwapDims {
|
||||
swapped,
|
||||
original,
|
||||
dims,
|
||||
..
|
||||
} => Some((
|
||||
context.tensors.get(swapped).unwrap(),
|
||||
context.tensors.get(original).unwrap(),
|
||||
has_multiple_read(original),
|
||||
dims,
|
||||
)),
|
||||
TensorView::Reshape { .. } => None,
|
||||
});
|
||||
|
||||
let mut ref_elem = (ElemType::UInt(UIntKind::U64).into(), 8);
|
||||
let mut quants_line_sizes: Option<Vec<LineSize>> = None;
|
||||
|
||||
for input in plan.handle_inputs.iter() {
|
||||
let elem: StorageType = match input {
|
||||
HandleInput::Normal(h) => h.global_ir.dtype.into(),
|
||||
HandleInput::QuantValues(handle) => match handle.global_ir.dtype {
|
||||
burn_std::DType::QFloat(scheme) => {
|
||||
line_sizes_quants(client, &mut quants_line_sizes, scheme);
|
||||
continue;
|
||||
}
|
||||
_ => panic!("Unable to retrieve the scheme for quantized values."),
|
||||
},
|
||||
HandleInput::QuantParams(..) => continue,
|
||||
};
|
||||
let elem_size = elem.size();
|
||||
|
||||
if ref_elem.1 >= elem_size {
|
||||
ref_elem = (elem, elem_size);
|
||||
}
|
||||
}
|
||||
for r in plan.global_outputs.iter() {
|
||||
let elem: StorageType = r.dtype.into();
|
||||
let elem_size = elem.size();
|
||||
|
||||
if ref_elem.1 >= elem_size {
|
||||
ref_elem = (elem, elem_size);
|
||||
}
|
||||
}
|
||||
|
||||
let filtered = plan
|
||||
.handle_inputs
|
||||
.iter()
|
||||
.map(|item| {
|
||||
item.as_normal()
|
||||
// Filter out indexed resources.
|
||||
.map(|item| !self.resources.indexed.contains_key(&item.relative_id))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let line_sizes = match quants_line_sizes {
|
||||
// Quantization normally triggers higher vectorization than anything else, no need to
|
||||
// compare to ref elem.
|
||||
Some(line_sizes) => line_sizes,
|
||||
None => client
|
||||
.io_optimized_line_sizes(ref_elem.0.size())
|
||||
.collect::<Vec<_>>(),
|
||||
};
|
||||
|
||||
let vectorization_axis = runner.axis(plan);
|
||||
|
||||
runner.vectorization(
|
||||
context,
|
||||
&mut plan.vectorizations,
|
||||
plan.handle_inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, item)| {
|
||||
if filtered[i] {
|
||||
Some(match item {
|
||||
HandleInput::Normal(h) => {
|
||||
VectorizationHandle::NormalInput(&h.handle, &h.global_ir)
|
||||
}
|
||||
HandleInput::QuantValues(h) => {
|
||||
VectorizationHandle::QuantValues(&h.handle, &h.global_ir)
|
||||
}
|
||||
HandleInput::QuantParams(_) => VectorizationHandle::QuantParams,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
plan.global_outputs.iter(),
|
||||
tensors_reshaped,
|
||||
tensors_swapped,
|
||||
&line_sizes,
|
||||
u8::MAX as usize,
|
||||
vectorization_axis,
|
||||
);
|
||||
|
||||
for tensor in self.resources.indexed.keys() {
|
||||
let global = context.tensors.get(tensor).unwrap();
|
||||
plan.vectorizations.insert(global.id, Vect::Aligned(1));
|
||||
}
|
||||
|
||||
let mut block_vectorization = Vec::with_capacity(self.blocks.len());
|
||||
for _ in 0..self.blocks.len() {
|
||||
block_vectorization.push(Vec::new());
|
||||
}
|
||||
|
||||
for (input_pos, handle) in plan.handle_inputs.iter_mut().enumerate() {
|
||||
let (global_ir, relative_id) = match handle {
|
||||
HandleInput::Normal(h) => (&h.global_ir, &h.relative_id),
|
||||
HandleInput::QuantValues(h) => (&h.global_ir, &h.relative_id),
|
||||
HandleInput::QuantParams(_) => continue,
|
||||
};
|
||||
let (vect, br) = match plan.vectorizations.get(&global_ir.id) {
|
||||
Some(v) => (v.line_size(), v.is_broadcast()),
|
||||
None => panic!("No vectorization factor found for {:?}", global_ir.id),
|
||||
};
|
||||
|
||||
for (block_pos, block_plan) in plan.blocks.iter().enumerate() {
|
||||
if block_plan.reads.contains_key(relative_id) {
|
||||
block_vectorization[block_pos].push(BlockVectorization {
|
||||
action: VectorizationAction::Input(input_pos),
|
||||
potential: vect,
|
||||
broadcasted: br,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (output_pos, handle) in plan.handle_outputs.iter().enumerate() {
|
||||
if let HandleOutput::Owned {
|
||||
global_id,
|
||||
relative_id,
|
||||
..
|
||||
} = handle
|
||||
{
|
||||
for (block_pos, block_plan) in plan.blocks.iter().enumerate() {
|
||||
if block_plan.writes.contains_key(relative_id) {
|
||||
let vectorization = plan.vectorizations.get(global_id).unwrap().line_size();
|
||||
block_vectorization[block_pos].push(BlockVectorization {
|
||||
action: VectorizationAction::Output(output_pos),
|
||||
potential: vectorization,
|
||||
broadcasted: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut previous_widths = Vec::with_capacity(block_vectorization.len());
|
||||
|
||||
// Unhandled inputs might not get included in any fused blocks for now.
|
||||
//
|
||||
// So we ensure they are vectorized by setting their vectorization before we set the
|
||||
// vectorizations in blocks.
|
||||
//
|
||||
// Unhandled Outputs are correctly vectorized, so this is only necessary for inputs.
|
||||
for input in self.resources.inputs_unhandled.iter() {
|
||||
let pos = self
|
||||
.resources
|
||||
.inputs
|
||||
.get_index(*input)
|
||||
.unwrap_or_else(|| self.resources.inputs.get_index_quant(*input).unwrap());
|
||||
let input_global = context.tensors.get(input).unwrap();
|
||||
|
||||
match plan.vectorizations.get(&input_global.id).unwrap() {
|
||||
Vect::Aligned(vect) => {
|
||||
let handle = &mut plan.handle_inputs[pos];
|
||||
match handle {
|
||||
HandleInput::Normal(handle) => {
|
||||
handle.line_size = *vect;
|
||||
}
|
||||
HandleInput::QuantValues(handle) => {
|
||||
handle.line_size = *vect;
|
||||
}
|
||||
HandleInput::QuantParams(_) => {}
|
||||
}
|
||||
}
|
||||
Vect::Broadcasted => {}
|
||||
}
|
||||
}
|
||||
|
||||
for ((tmp, block_plan), block) in block_vectorization
|
||||
.into_iter()
|
||||
.zip(plan.blocks.iter_mut())
|
||||
.zip(self.blocks)
|
||||
{
|
||||
match block.settings.vectorization {
|
||||
VectorizationSetting::Activated => {
|
||||
apply_vectorization_block(
|
||||
tmp,
|
||||
&mut plan.handle_inputs,
|
||||
&mut plan.handle_outputs,
|
||||
block_plan,
|
||||
u8::MAX as usize,
|
||||
);
|
||||
}
|
||||
VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos } => {
|
||||
apply_vectorization_block(
|
||||
tmp,
|
||||
&mut plan.handle_inputs,
|
||||
&mut plan.handle_outputs,
|
||||
block_plan,
|
||||
previous_widths[block_pos],
|
||||
);
|
||||
if block_plan.width == 0 {
|
||||
block_plan.width = previous_widths[block_pos];
|
||||
}
|
||||
}
|
||||
VectorizationSetting::EqualThanPreviousBlock { block_pos } => {
|
||||
apply_vectorization_block(
|
||||
tmp,
|
||||
&mut plan.handle_inputs,
|
||||
&mut plan.handle_outputs,
|
||||
block_plan,
|
||||
previous_widths[block_pos],
|
||||
);
|
||||
// Enforces the width.
|
||||
block_plan.width = previous_widths[block_pos];
|
||||
}
|
||||
VectorizationSetting::Deactivated => {
|
||||
apply_vectorization_block(
|
||||
tmp,
|
||||
&mut plan.handle_inputs,
|
||||
&mut plan.handle_outputs,
|
||||
block_plan,
|
||||
1,
|
||||
);
|
||||
block_plan.width = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// When only virtual inputs/outputs are present for a block, we need to set a width.
|
||||
if block_plan.width == 0 {
|
||||
if let Some(w) = previous_widths.last() {
|
||||
block_plan.width = *w;
|
||||
} else {
|
||||
block_plan.width = 1;
|
||||
}
|
||||
}
|
||||
|
||||
previous_widths.push(block_plan.width);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Target of a pending vectorization update inside a block.
#[derive(Debug)]
enum VectorizationAction {
    /// Position of the handle in the plan's input handles.
    Input(usize),
    /// Position of the handle in the plan's output handles.
    Output(usize),
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BlockVectorization {
|
||||
action: VectorizationAction,
|
||||
potential: LineSize,
|
||||
broadcasted: bool,
|
||||
}
|
||||
|
||||
fn apply_vectorization_block<R: Runtime>(
|
||||
block_vectorization: Vec<BlockVectorization>,
|
||||
inputs: &mut [HandleInput<R>],
|
||||
outputs: &mut [HandleOutput<R>],
|
||||
block_plan: &mut BlockPlan,
|
||||
max: LineSize,
|
||||
) {
|
||||
for item in block_vectorization {
|
||||
match item.action {
|
||||
VectorizationAction::Input(pos) => {
|
||||
let (vect, br) = if item.potential <= max {
|
||||
(item.potential, item.broadcasted)
|
||||
} else {
|
||||
(1, false)
|
||||
};
|
||||
|
||||
match &mut inputs[pos] {
|
||||
HandleInput::Normal(input) => {
|
||||
input.line_size = vect;
|
||||
input.broadcated = br;
|
||||
}
|
||||
HandleInput::QuantValues(input) => {
|
||||
input.line_size = vect;
|
||||
}
|
||||
HandleInput::QuantParams(_) => {
|
||||
// Not vectorized
|
||||
}
|
||||
}
|
||||
|
||||
if block_plan.width < vect {
|
||||
block_plan.width = vect;
|
||||
}
|
||||
}
|
||||
VectorizationAction::Output(pos) => {
|
||||
if let HandleOutput::Owned { vectorization, .. } = &mut outputs[pos] {
|
||||
let vect = if item.potential <= max {
|
||||
item.potential
|
||||
} else {
|
||||
1
|
||||
};
|
||||
*vectorization = vect;
|
||||
|
||||
if block_plan.width < vect {
|
||||
block_plan.width = vect;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn line_sizes_quants<R: Runtime>(
|
||||
client: &ComputeClient<R>,
|
||||
quants_line_sizes: &mut Option<Vec<LineSize>>,
|
||||
scheme: QuantScheme,
|
||||
) {
|
||||
match scheme.store {
|
||||
QuantStore::Native => match scheme.value {
|
||||
// Type sizes are the same so just treat fp8/fp4x2 as i8
|
||||
QuantValue::Q8F
|
||||
| QuantValue::Q8S
|
||||
| QuantValue::E4M3
|
||||
| QuantValue::E5M2
|
||||
| QuantValue::E2M1 => {
|
||||
let line_sizes = client
|
||||
.io_optimized_line_sizes(size_of::<i8>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match &quants_line_sizes {
|
||||
Some(sizes) => {
|
||||
if sizes[0] < line_sizes[0] {
|
||||
*quants_line_sizes = Some(line_sizes);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
*quants_line_sizes = Some(line_sizes);
|
||||
}
|
||||
}
|
||||
}
|
||||
QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
|
||||
unreachable!("Can't store native sub-byte values")
|
||||
}
|
||||
},
|
||||
QuantStore::PackedU32(_) => {
|
||||
let mut line_sizes = client
|
||||
.io_optimized_line_sizes(size_of::<u32>())
|
||||
.collect::<Vec<_>>();
|
||||
for val in line_sizes.iter_mut() {
|
||||
*val *= scheme.num_quants();
|
||||
}
|
||||
|
||||
match &quants_line_sizes {
|
||||
Some(sizes) => {
|
||||
if sizes[0] < line_sizes[0] {
|
||||
let mut min = *line_sizes.last().unwrap();
|
||||
|
||||
while min > 1 {
|
||||
min /= 2;
|
||||
line_sizes.push(min);
|
||||
}
|
||||
*quants_line_sizes = Some(line_sizes);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
*quants_line_sizes = Some(line_sizes);
|
||||
}
|
||||
}
|
||||
}
|
||||
QuantStore::PackedNative(_) => {
|
||||
panic!("Not yet supported")
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
pub(crate) mod codegen;
|
||||
pub(crate) mod fuser;
|
||||
pub(crate) mod launch;
|
||||
pub(crate) mod settings;
|
||||
|
||||
pub mod trace;
|
||||
@@ -0,0 +1,59 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Controls which operations can be fused.
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
pub struct FuseSettings {
|
||||
/// Enables broadcasting of shapes.
|
||||
pub broadcast: bool,
|
||||
/// Enables output shape updates.
|
||||
///
|
||||
/// When broadcast is enabled, the output shape can become bigger after a fusion,
|
||||
/// therefore an update is needed.
|
||||
pub output_shape_updates: bool,
|
||||
/// Enables the reuse of input buffers.
|
||||
pub inplace: bool,
|
||||
/// Whether vectorization is enabled.
|
||||
pub vectorization: VectorizationSetting,
|
||||
/// How [reference layout](super::ir::RefLayout) selection is done.
|
||||
pub ref_layout: RefLayoutSetting,
|
||||
}
|
||||
|
||||
impl Default for FuseSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
broadcast: true,
|
||||
output_shape_updates: true,
|
||||
inplace: true,
|
||||
vectorization: VectorizationSetting::Activated,
|
||||
ref_layout: RefLayoutSetting::Any,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
/// How vectorization is handled during fusion.
|
||||
pub enum VectorizationSetting {
|
||||
/// The biggest line_size possible will be used.
|
||||
Activated,
|
||||
/// Equivalent to using line_size of one.
|
||||
Deactivated,
|
||||
/// This is a good setting when a block processes values calculated from a previous block.
|
||||
SmallerOrEqualThanPreviousBlock { block_pos: usize },
|
||||
/// This is a good setting when a block processes values calculated from a previous block.
|
||||
EqualThanPreviousBlock { block_pos: usize },
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
/// Influence how the [reference layout](super::ir::RefLayout) selection is done.
|
||||
pub enum RefLayoutSetting {
|
||||
/// Any reference layout is allowed.
|
||||
Any,
|
||||
/// Only contiguous reference layout is allowed.
|
||||
///
|
||||
/// Note that forcing a contiguous reference layout might reduce the opportunity of inplace
|
||||
/// fusion.
|
||||
OnlyContiguous,
|
||||
SameAsBlock {
|
||||
block_pos: u32,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
use crate::engine::{
|
||||
codegen::ir::{FuseArg, FuseType},
|
||||
trace::block::FuseBlock,
|
||||
};
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use burn_std::{Shape, Strides};
|
||||
use cubecl::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
use crate::CubeFusionHandle;
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
use burn_backend::TensorData;
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
/// A trace contains all [blocks](FuseBlock) and the [resources](FuseResources) used by the
|
||||
/// kernel.
|
||||
pub struct FuseTrace {
|
||||
pub blocks: Vec<FuseBlock>,
|
||||
pub resources: FuseResources,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for FuseTrace {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "FuseTrace")?;
|
||||
for b in self.blocks.iter() {
|
||||
writeln!(f, " - Block shape={:?}", b.shape_ref)?;
|
||||
for (tensor, ops) in b.reads.iter() {
|
||||
for op in ops.iter() {
|
||||
writeln!(f, " - {op} <== {tensor}")?;
|
||||
}
|
||||
}
|
||||
for op in b.ops.iter() {
|
||||
writeln!(f, " - {op}")?;
|
||||
}
|
||||
for (tensor, ops) in b.writes.iter() {
|
||||
for op in ops.iter() {
|
||||
writeln!(f, " - {op} <== {tensor}")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub enum TuneOutput<R: Runtime> {
|
||||
UnChecked(PhantomData<R>),
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
Checked {
|
||||
handles: HashMap<TensorId, (Vec<usize>, CubeFusionHandle<R>)>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<R: Runtime> TuneOutput<R> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn merge(self, other: Self) -> Self {
|
||||
let mut result = self;
|
||||
|
||||
match &mut result {
|
||||
TuneOutput::UnChecked(..) => {}
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
TuneOutput::Checked { handles } => match other {
|
||||
TuneOutput::UnChecked(..) => {}
|
||||
TuneOutput::Checked { handles: o } => {
|
||||
for (k, v) in o.into_iter() {
|
||||
handles.insert(k, v);
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> cubecl::tune::AutotuneOutput for TuneOutput<R> {
    /// Asserts that the handles of `other` hold the same data as the handles of `self`.
    ///
    /// Only runs when both outputs are checked. Every handle of `self` that also exists in
    /// `other` is read back (after being made contiguous when needed) and compared: with a
    /// permissive tolerance for float types, exactly otherwise.
    ///
    /// # Panics
    ///
    /// Panics when data differ, or when `self` has handles but none of them could be compared.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_backend::Tolerance;
        use burn_std::DType;
        use burn_std::is_contiguous;
        use cubecl::std::tensor::into_contiguous_ref;

        let (
            TuneOutput::Checked {
                handles: handles_ref,
            },
            TuneOutput::Checked { handles },
        ) = (self, &other)
        else {
            return;
        };

        // Returns a buffer with a contiguous layout for the given tensor.
        let contiguous = |tensor: &CubeFusionHandle<R>, shape: &Vec<usize>| {
            if is_contiguous(&shape, &tensor.strides) {
                tensor.handle.clone()
            } else {
                into_contiguous_ref::<R>(
                    &tensor.client,
                    &tensor.as_handle_ref(&shape),
                    tensor.dtype.into(),
                )
                .unwrap()
                .handle
            }
        };

        let mut num_checked = 0;

        for (id, (shape, handle)) in handles_ref.iter() {
            let Some((shape_other, other)) = handles.get(id) else {
                // Debug info for the tests.
                println!("No tensor found for {id:?}=>{shape:?}");
                continue;
            };

            // Both layouts are checked against the shape of the reference handle.
            let current_handle = contiguous(handle, shape);
            let other_handle = contiguous(other, shape);

            let data_ref = handle.client.read_one(current_handle);
            let data_other = other.client.read_one(other_handle);
            let data_ref = TensorData::from_bytes(data_ref, shape.clone(), handle.dtype);
            let data_other = TensorData::from_bytes(data_other, shape_other.clone(), handle.dtype);

            match handle.dtype {
                DType::F64 => {
                    data_ref.assert_approx_eq::<f64>(&data_other, Tolerance::permissive())
                }
                DType::F32 => {
                    data_ref.assert_approx_eq::<f32>(&data_other, Tolerance::permissive())
                }
                DType::F16 => {
                    data_ref.assert_approx_eq::<half::f16>(&data_other, Tolerance::permissive())
                }
                DType::BF16 => {
                    data_ref.assert_approx_eq::<half::bf16>(&data_other, Tolerance::permissive())
                }
                _ => data_ref.assert_eq(&data_other, true),
            }
            num_checked += 1;
        }

        // At least one check is needed when there is an output.
        //
        // Some optimizations might write more outputs than needed, so it might be fine if
        // the number of handles is different, but at least one is required.
        //
        // An optimization might not create outputs if its dead code detection is triggered,
        // therefore avoiding useless computation.
        if !handles_ref.is_empty() {
            assert!(num_checked >= 1);
        }
    }
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
|
||||
/// Declare all resources used by the kernel, and potentially multiple [blocks](FuseBlock).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Each block can't contain their own resources, since they are shared between blocks. The
|
||||
/// vectorization factor of one input tensor must be the same for all blocks.
|
||||
pub struct FuseResources {
|
||||
pub outputs: RegisteredTensors,
|
||||
pub inputs: RegisteredTensors,
|
||||
pub scalars: Vec<(FuseType, u64)>,
|
||||
// TODO: Making put a map of global registers.
|
||||
pub views: Vec<TensorView>,
|
||||
pub indexed: BTreeMap<TensorId, FuseArg>,
|
||||
pub inputs_unhandled: Vec<TensorId>,
|
||||
pub outputs_unhandled: Vec<FuseArg>,
|
||||
pub num_reshaped: usize,
|
||||
/// Necessary to remove some entries from the context.
|
||||
pub dropped: HashSet<TensorId>,
|
||||
/// We know during fusion that we have to have those buffers has global.
|
||||
/// The pos here can be interpreted as GLOBAL pos where the output pos are locals.
|
||||
pub buffers: RegisteredTensors,
|
||||
/// Global registers available everywhere.
|
||||
///
|
||||
/// TODO: Not all registers should be globals.
|
||||
pub registers: BTreeMap<TensorId, FuseArg>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct RuntimeLayout {
|
||||
pub shape: Shape,
|
||||
pub strides: Strides,
|
||||
}
|
||||
|
||||
impl Default for RuntimeLayout {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
shape: Shape::new([]),
|
||||
strides: Strides::new(&[]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error that can happen when running a trace.
#[derive(Debug)]
pub enum TraceError<Err> {
    /// No reference could be found.
    // NOTE(review): presumably the reference layout of a block; confirm against the launch code.
    ReferenceNotFound,
    /// Error reported by the runner.
    RunnerError(Err),
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub enum TensorView {
|
||||
Reshape {
|
||||
reshaped: TensorId,
|
||||
original: TensorId,
|
||||
reshape_pos: usize,
|
||||
shape_relative: Shape,
|
||||
},
|
||||
SwapDims {
|
||||
swapped: TensorId,
|
||||
original: TensorId,
|
||||
dims: (usize, usize),
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct RegisteredTensors {
|
||||
tensors: Vec<RegisterTensor>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub enum RegisterTensor {
|
||||
Normal(TensorIr, FuseType),
|
||||
QuantValues(TensorIr),
|
||||
QuantParams(TensorId),
|
||||
}
|
||||
|
||||
impl RegisterTensor {
|
||||
pub fn as_normal_tensor(&self) -> Option<(&TensorIr, &FuseType)> {
|
||||
match self {
|
||||
RegisterTensor::Normal(tensor_ir, precision) => Some((tensor_ir, precision)),
|
||||
RegisterTensor::QuantValues(_) => None,
|
||||
RegisterTensor::QuantParams(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RegisteredTensors {
|
||||
/// Iterate over all the registered tensors.
|
||||
pub fn iter(&self) -> impl Iterator<Item = &RegisterTensor> {
|
||||
self.tensors.iter()
|
||||
}
|
||||
|
||||
/// Consumes and iterate over all the registered tensors.
|
||||
pub fn into_iter(self) -> impl Iterator<Item = RegisterTensor> {
|
||||
self.tensors.into_iter()
|
||||
}
|
||||
|
||||
/// Returns the number of tensors registered.
|
||||
pub fn len(&self) -> usize {
|
||||
self.tensors.len()
|
||||
}
|
||||
|
||||
/// Retrieve the [tensor id](TensorId) at the given index.
|
||||
pub fn get_id(&self, index: usize) -> Option<TensorId> {
|
||||
self.tensors.get(index).map(|entry| match entry {
|
||||
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id,
|
||||
RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id,
|
||||
RegisterTensor::QuantParams(tensor_id) => *tensor_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Doesn't return quantized tensor.
|
||||
pub fn get_index(&self, tensor_id: TensorId) -> Option<usize> {
|
||||
self.tensors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_pos, entry)| match entry {
|
||||
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,
|
||||
RegisterTensor::QuantValues(_) => false,
|
||||
RegisterTensor::QuantParams(_) => false,
|
||||
})
|
||||
.map(|(pos, _)| pos)
|
||||
}
|
||||
|
||||
/// Get the index of a quantized tensor.
|
||||
pub fn get_index_quant(&self, tensor_id: TensorId) -> Option<usize> {
|
||||
self.tensors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_pos, entry)| match entry {
|
||||
RegisterTensor::Normal(..) => false,
|
||||
RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id == tensor_id,
|
||||
RegisterTensor::QuantParams(_) => false,
|
||||
})
|
||||
.map(|(pos, _)| pos)
|
||||
}
|
||||
|
||||
/// Doesn't return quantized tensor.
|
||||
pub fn get(&self, tensor_id: TensorId) -> Option<(&TensorIr, &FuseType)> {
|
||||
self.tensors
|
||||
.iter()
|
||||
.find(|entry| match entry {
|
||||
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,
|
||||
RegisterTensor::QuantValues(_) => false,
|
||||
RegisterTensor::QuantParams(_) => false,
|
||||
})
|
||||
.and_then(|entry| match entry {
|
||||
RegisterTensor::Normal(tensor_ir, fuse_precision) => {
|
||||
Some((tensor_ir, fuse_precision))
|
||||
}
|
||||
RegisterTensor::QuantValues(_) => None,
|
||||
RegisterTensor::QuantParams(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Insert a quantized tensor.
|
||||
///
|
||||
/// It will return the positions for both the value tensor and param tensor.
|
||||
pub fn insert_quant(&mut self, tensor: TensorIr) -> (usize, usize) {
|
||||
if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {
|
||||
RegisterTensor::QuantValues(tensor_ir) => tensor_ir == &tensor,
|
||||
_ => false,
|
||||
}) {
|
||||
let values = old.0;
|
||||
let params = values + 1;
|
||||
return (values, params);
|
||||
}
|
||||
|
||||
let params = RegisterTensor::QuantParams(tensor.id);
|
||||
let values = RegisterTensor::QuantValues(tensor);
|
||||
let pos_values = self.len();
|
||||
self.tensors.push(values);
|
||||
|
||||
let pos_params = self.len();
|
||||
self.tensors.push(params);
|
||||
|
||||
(pos_values, pos_params)
|
||||
}
|
||||
|
||||
/// Insert a normal tensor with the given [precision](FusePrecision) in the current block.
|
||||
pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> usize {
|
||||
if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {
|
||||
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id,
|
||||
_ => false,
|
||||
}) {
|
||||
return old.0;
|
||||
}
|
||||
|
||||
let value = RegisterTensor::Normal(tensor, precision);
|
||||
let pos = self.len();
|
||||
|
||||
self.tensors.push(value);
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
/// Update the already registered tensor with the given [tensor ir](TensorIr).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This function only works with normal tensors, not quantized tensors.
|
||||
pub fn update(&mut self, tensor: &TensorIr) {
|
||||
if let Some(entry) = self.tensors.iter_mut().find(|entry| match entry {
|
||||
RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id,
|
||||
_ => false,
|
||||
}) && let RegisterTensor::Normal(tensor_ir, _) = entry
|
||||
{
|
||||
tensor_ir.status = tensor.status
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,555 @@
|
||||
use super::{FuseResources, RegisteredTensors, TensorView};
|
||||
use crate::engine::{
|
||||
codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo, MultiBlockPos, UnaryFuseArgs},
|
||||
settings::FuseSettings,
|
||||
};
|
||||
use burn_ir::{TensorId, TensorIr, TensorStatus};
|
||||
use burn_std::{DType, Shape, quantization::QuantParam};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, btree_map::Entry};
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
/// A block containing all [operations](FuseOp) as well as reads and writes for each tensor along
|
||||
/// with the [fusion settings](FuseSettings).
|
||||
pub struct FuseBlock {
|
||||
/// Contains the [fusion settings](FuseSettings) associated to the current block.
|
||||
pub settings: FuseSettings,
|
||||
/// Contains all the [operations](FuseOp) registered in the current block.
|
||||
pub ops: Vec<FuseOp>,
|
||||
/// The reference shape of the current block.
|
||||
pub shape_ref: Shape,
|
||||
/// Contains all tensor inputs of the current block except for manually handled tensors.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Some reads might not have read operations registered, such as dequantization, but it's
|
||||
/// important to be registered here for vectorization. Input tensors that are not
|
||||
/// registered here must be vectorized manually.
|
||||
pub reads: BTreeMap<TensorId, Vec<FuseOp>>,
|
||||
/// Contains all tensor outputs of the current block except for manually handled tensors.
|
||||
/// We can have multiple writes when the same variable is reused after in another block.
|
||||
pub writes: BTreeMap<TensorId, Vec<FuseOp>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// It is responsible to build a [trace](FuseBlock).
|
||||
pub struct FuseBlockBuilder {
|
||||
pub settings: FuseSettings,
|
||||
locals: LocalVariablePool,
|
||||
pub ops: Vec<FuseOp>,
|
||||
reads: BTreeMap<TensorId, Vec<FuseOp>>,
|
||||
// Only for global registers.
|
||||
writes: BTreeMap<TensorId, Vec<FuseOp>>,
|
||||
bool_precision: FuseType,
|
||||
// Output declared in this block alone.
|
||||
outputs: RegisteredTensors,
|
||||
pub outputs_unhandled: Vec<FuseArg>,
|
||||
pub local_inputs: BTreeMap<TensorId, FuseArg>,
|
||||
/// The reference shape used by this block.
|
||||
pub shape_ref: Shape,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// How a quantized input can be read.
|
||||
pub enum QuantInput {
|
||||
/// If already dequantized, we cache the dequantization and returns the local variable
|
||||
/// corresponding to the float value.
|
||||
AlreadyDequantized { local: FuseArg },
|
||||
/// Otherwise we return the information necessary to dequantize the tensor.
|
||||
Quantized { values: FuseArg, params: FuseArg },
|
||||
}
|
||||
|
||||
impl FuseBlockBuilder {
|
||||
pub fn new(bool_precision: FuseType, settings: FuseSettings) -> Self {
|
||||
Self {
|
||||
bool_precision,
|
||||
settings,
|
||||
locals: Default::default(),
|
||||
ops: Default::default(),
|
||||
reads: Default::default(),
|
||||
writes: Default::default(),
|
||||
outputs: Default::default(),
|
||||
outputs_unhandled: Default::default(),
|
||||
local_inputs: Default::default(),
|
||||
shape_ref: Shape::new([]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an output tensor.
|
||||
pub fn output(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option<FuseArg> {
|
||||
if resources.indexed.contains_key(&tensor.id) {
|
||||
return None;
|
||||
}
|
||||
if matches!(tensor.dtype, DType::QFloat(..)) {
|
||||
return None;
|
||||
}
|
||||
let precision = tensor.dtype.into();
|
||||
|
||||
// Bool tensors are encoded as bool_precision.
|
||||
let precision_output = match precision {
|
||||
FuseType::Bool => self.bool_precision,
|
||||
_ => precision,
|
||||
};
|
||||
|
||||
let out = match self.locals.get(precision, tensor.id) {
|
||||
Some(local) => local,
|
||||
None => {
|
||||
let out = self.locals.create(precision, tensor.id);
|
||||
|
||||
self.outputs.insert(precision_output, tensor.clone());
|
||||
resources.outputs.insert(precision_output, tensor.clone());
|
||||
|
||||
out
|
||||
}
|
||||
};
|
||||
|
||||
Some(out)
|
||||
}
|
||||
|
||||
/// Register an input tensor.
|
||||
pub fn multi_block_variable(
|
||||
&mut self,
|
||||
block_pos: usize,
|
||||
tensor: &TensorIr,
|
||||
global: bool,
|
||||
) -> Option<FuseArg> {
|
||||
let precision = tensor.dtype.into();
|
||||
|
||||
if let Some(val) = self.local_inputs.get(&tensor.id) {
|
||||
return Some(val.clone());
|
||||
}
|
||||
|
||||
let val = match self.locals.get(precision, tensor.id) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let arg = if global {
|
||||
FuseArg::MultiBlockGlobal(
|
||||
MultiBlockPos {
|
||||
block_pos,
|
||||
block_local_pos: self.writes.len(),
|
||||
},
|
||||
val.precision(),
|
||||
)
|
||||
} else {
|
||||
FuseArg::MultiBlockLocal(
|
||||
MultiBlockPos {
|
||||
block_pos,
|
||||
block_local_pos: self.writes.len(),
|
||||
},
|
||||
val.precision(),
|
||||
)
|
||||
};
|
||||
|
||||
let ops = match self.writes.get_mut(&tensor.id) {
|
||||
Some(ops) => ops,
|
||||
None => {
|
||||
self.writes.insert(tensor.id, Vec::new());
|
||||
self.writes.get_mut(&tensor.id).unwrap()
|
||||
}
|
||||
};
|
||||
ops.push(FuseOp::Assign(UnaryFuseArgs {
|
||||
input: val,
|
||||
out: arg.clone(),
|
||||
}));
|
||||
|
||||
Some(arg)
|
||||
}
|
||||
|
||||
/// Register an input tensor.
|
||||
pub fn input(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option<FuseArg> {
|
||||
if resources.indexed.contains_key(&tensor.id) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if matches!(tensor.dtype, DType::QFloat(..)) {
|
||||
return None;
|
||||
}
|
||||
let precision = tensor.dtype.into();
|
||||
|
||||
// Bool tensors are encoded as bool_precision.
|
||||
let precision_input = match precision {
|
||||
FuseType::Bool => self.bool_precision,
|
||||
_ => precision,
|
||||
};
|
||||
|
||||
if let Some(val) = self.local_inputs.get(&tensor.id) {
|
||||
return Some(val.clone());
|
||||
}
|
||||
|
||||
let arg = match self.locals.get(precision, tensor.id) {
|
||||
Some(local) => {
|
||||
resources.inputs.update(tensor);
|
||||
|
||||
local
|
||||
}
|
||||
None => {
|
||||
let input = if resources.outputs.get_index(tensor.id).is_some() {
|
||||
if let Some(val) = resources.registers.get(&tensor.id) {
|
||||
return Some(val.clone());
|
||||
};
|
||||
|
||||
let pos = resources.buffers.insert(precision, tensor.clone());
|
||||
FuseArg::Output(pos, precision_input, LayoutInfo::Unknown)
|
||||
} else {
|
||||
let pos = resources.inputs.insert(precision_input, tensor.clone());
|
||||
FuseArg::Input(pos, precision_input, LayoutInfo::Unknown)
|
||||
};
|
||||
|
||||
let out = self.locals.create(precision, tensor.id);
|
||||
|
||||
let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) {
|
||||
e.insert(Vec::with_capacity(1));
|
||||
self.reads.get_mut(&tensor.id).unwrap()
|
||||
} else {
|
||||
self.reads.get_mut(&tensor.id).unwrap()
|
||||
};
|
||||
|
||||
reads.push(FuseOp::Assign(UnaryFuseArgs {
|
||||
input,
|
||||
out: out.clone(),
|
||||
}));
|
||||
|
||||
out
|
||||
}
|
||||
};
|
||||
|
||||
Some(arg)
|
||||
}
|
||||
|
||||
/// Register an input quantized tensor.
|
||||
pub fn input_quant(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
resources: &mut FuseResources,
|
||||
) -> Option<QuantInput> {
|
||||
if resources.indexed.contains_key(&tensor.id) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let precision = tensor.dtype.into();
|
||||
let precision_scales = match tensor.dtype {
|
||||
DType::QFloat(scheme) => match scheme.param {
|
||||
QuantParam::F32 => FuseType::F32,
|
||||
QuantParam::F16 => FuseType::F16,
|
||||
QuantParam::BF16 => FuseType::BF16,
|
||||
QuantParam::UE8M0 | QuantParam::UE4M3 => {
|
||||
unimplemented!("Unsupported fuse precision");
|
||||
}
|
||||
},
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let arg = match self.locals.get(precision, tensor.id) {
|
||||
Some(local) => {
|
||||
resources.inputs.update(tensor);
|
||||
QuantInput::AlreadyDequantized { local }
|
||||
}
|
||||
None => {
|
||||
let (new_input, q_index) = resources.inputs.insert_quant(tensor.clone());
|
||||
let input = FuseArg::Input(new_input, precision, LayoutInfo::Unknown);
|
||||
let scales = FuseArg::Input(q_index, precision_scales, LayoutInfo::Unknown);
|
||||
|
||||
// Important to flag that there is a read, even if no operation is registered.
|
||||
if let Entry::Vacant(e) = self.reads.entry(tensor.id) {
|
||||
e.insert(Vec::new());
|
||||
};
|
||||
|
||||
QuantInput::Quantized {
|
||||
values: input,
|
||||
params: scales,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(arg)
|
||||
}
|
||||
|
||||
/// Register an input with swapped dims.
|
||||
pub fn input_swap_dims(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
output: &TensorIr,
|
||||
dims: (usize, usize),
|
||||
resources: &mut FuseResources,
|
||||
) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(..)) {
|
||||
return None;
|
||||
}
|
||||
let precision = tensor.dtype.into();
|
||||
|
||||
// Bool tensors are encoded as bool_precision.
|
||||
let precision_input = match precision {
|
||||
FuseType::Bool => self.bool_precision,
|
||||
_ => precision,
|
||||
};
|
||||
|
||||
let input_index = match self.locals.get(precision, tensor.id) {
|
||||
Some(_) => {
|
||||
// Can't fused an already fused input.
|
||||
if resources.outputs.get(tensor.id).is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
match resources.inputs.get_index(tensor.id) {
|
||||
Some(index) => {
|
||||
resources.inputs.update(tensor);
|
||||
index
|
||||
}
|
||||
None => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => resources.inputs.insert(precision_input, tensor.clone()),
|
||||
};
|
||||
|
||||
let out = self.output(output, resources)?;
|
||||
let original = FuseArg::Input(input_index, precision_input, LayoutInfo::Unknown);
|
||||
|
||||
let broadcasted = output.shape[output.shape.rank() - 1] == 0;
|
||||
|
||||
resources.views.push(TensorView::SwapDims {
|
||||
swapped: output.id,
|
||||
original: tensor.id,
|
||||
dims,
|
||||
});
|
||||
|
||||
let input = FuseArg::InputSwapDims {
|
||||
original: Box::new(original),
|
||||
dims,
|
||||
broadcasted,
|
||||
};
|
||||
|
||||
let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) {
|
||||
e.insert(Vec::with_capacity(1));
|
||||
self.reads.get_mut(&tensor.id).unwrap()
|
||||
} else {
|
||||
self.reads.get_mut(&tensor.id).unwrap()
|
||||
};
|
||||
|
||||
reads.push(FuseOp::Assign(UnaryFuseArgs {
|
||||
input,
|
||||
out: out.clone(),
|
||||
}));
|
||||
|
||||
Some(out)
|
||||
}
|
||||
|
||||
/// Register an input that is reshaped.
|
||||
pub fn input_reshaped(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
output: &TensorIr,
|
||||
resources: &mut FuseResources,
|
||||
) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(..)) {
|
||||
return None;
|
||||
}
|
||||
let precision = tensor.dtype.into();
|
||||
|
||||
// Bool tensors are encoded as bool_precision.
|
||||
let precision_input = match precision {
|
||||
FuseType::Bool => self.bool_precision,
|
||||
_ => precision,
|
||||
};
|
||||
|
||||
let input_index = match self.locals.get(precision, tensor.id) {
|
||||
Some(_) => {
|
||||
// Can't fused an already fused input.
|
||||
if resources.outputs.get(tensor.id).is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
match resources.inputs.get_index(tensor.id) {
|
||||
Some(index) => {
|
||||
resources.inputs.update(tensor);
|
||||
index
|
||||
}
|
||||
None => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => resources.inputs.insert(precision_input, tensor.clone()),
|
||||
};
|
||||
|
||||
let out = self.output(output, resources)?;
|
||||
let original = FuseArg::Input(input_index, precision_input, LayoutInfo::Unknown);
|
||||
|
||||
let mut shape = Vec::new();
|
||||
|
||||
let index = resources.num_reshaped;
|
||||
resources.num_reshaped += 1;
|
||||
|
||||
let rank = output.shape.rank();
|
||||
|
||||
for i in 0..output.shape.rank() {
|
||||
let id = index * rank + i;
|
||||
shape.push(FuseArg::ScalarShape(id));
|
||||
}
|
||||
|
||||
resources.views.push(TensorView::Reshape {
|
||||
reshaped: output.id,
|
||||
original: tensor.id,
|
||||
reshape_pos: index,
|
||||
shape_relative: output.shape.clone(),
|
||||
});
|
||||
|
||||
let input = FuseArg::InputReshaped {
|
||||
original: Box::new(original),
|
||||
shape,
|
||||
broadcasted: output.shape[rank - 1] == 0,
|
||||
};
|
||||
|
||||
let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) {
|
||||
e.insert(Vec::with_capacity(1));
|
||||
self.reads.get_mut(&tensor.id).unwrap()
|
||||
} else {
|
||||
self.reads.get_mut(&tensor.id).unwrap()
|
||||
};
|
||||
|
||||
reads.push(FuseOp::Assign(UnaryFuseArgs {
|
||||
input,
|
||||
out: out.clone(),
|
||||
}));
|
||||
|
||||
Some(out)
|
||||
}
|
||||
|
||||
/// Build into a fuse block.
|
||||
pub fn build(
|
||||
&self,
|
||||
resources: &FuseResources,
|
||||
outputs: &mut RegisteredTensors,
|
||||
buffers: &mut Vec<TensorId>,
|
||||
) -> FuseBlock {
|
||||
let ops = self.ops.clone();
|
||||
let reads = self.reads.clone();
|
||||
let tensor_writes = self.tensor_writes(resources, buffers);
|
||||
|
||||
let mut writes = self.writes.clone();
|
||||
|
||||
for (tensor, precision) in tensor_writes
|
||||
.iter()
|
||||
.filter_map(|entry| entry.as_normal_tensor())
|
||||
{
|
||||
if let Some(local) = self.locals.get_any_precision(tensor.id) {
|
||||
let out_index = outputs.insert(*precision, tensor.clone());
|
||||
|
||||
let ops = match writes.get_mut(&tensor.id) {
|
||||
Some(ops) => ops,
|
||||
None => {
|
||||
writes.insert(tensor.id, Vec::new());
|
||||
writes.get_mut(&tensor.id).unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
ops.push(FuseOp::Assign(UnaryFuseArgs {
|
||||
input: local,
|
||||
out: FuseArg::Output(out_index, *precision, LayoutInfo::Unknown),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
FuseBlock {
|
||||
settings: self.settings,
|
||||
ops,
|
||||
shape_ref: self.shape_ref.clone(),
|
||||
reads,
|
||||
writes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the tensor that needs to be written to.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The buffers vector passed as input is only to track the intermediary buffer writes needed
|
||||
/// during execution.
|
||||
pub fn tensor_writes(
|
||||
&self,
|
||||
resources: &FuseResources,
|
||||
buffers: &mut Vec<TensorId>,
|
||||
) -> RegisteredTensors {
|
||||
let mut result = RegisteredTensors::default();
|
||||
|
||||
// All tensors where their latest representation is not read write should be written to since they
|
||||
// are going to be used after the fused kernel by other operations.
|
||||
for output in self.outputs.iter() {
|
||||
if let Some((tensor, _precision)) = output.as_normal_tensor() {
|
||||
// We get the latest representation from the resources, not just this block.
|
||||
if let Some((tensor, precision)) = resources.outputs.get(tensor.id) {
|
||||
if !matches!(tensor.status, TensorStatus::ReadWrite) {
|
||||
result.insert(*precision, tensor.clone());
|
||||
} else if resources.buffers.get(tensor.id).is_some()
|
||||
&& !buffers.contains(&tensor.id)
|
||||
{
|
||||
result.insert(*precision, tensor.clone());
|
||||
// We make sure we don't write multiple time in the same buffer, only the
|
||||
// earliest possible.
|
||||
buffers.push(tensor.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug)]
|
||||
pub struct LocalVariablePool {
|
||||
values: BTreeMap<FuseType, BTreeMap<TensorId, usize>>,
|
||||
}
|
||||
|
||||
impl LocalVariablePool {
|
||||
fn get(&self, precision: FuseType, tensor_id: TensorId) -> Option<FuseArg> {
|
||||
if let Some(indexes) = self.values.get(&precision)
|
||||
&& let Some(index) = indexes.get(&tensor_id)
|
||||
{
|
||||
return Some(FuseArg::BlockLocal {
|
||||
pos: *index,
|
||||
ty: precision,
|
||||
});
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn get_any_precision(&self, tensor_id: TensorId) -> Option<FuseArg> {
|
||||
for (precision, indexes) in self.values.iter() {
|
||||
if let Some(index) = indexes.get(&tensor_id) {
|
||||
return Some(FuseArg::BlockLocal {
|
||||
pos: *index,
|
||||
ty: *precision,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn create(&mut self, precision: FuseType, tensor_id: TensorId) -> FuseArg {
|
||||
if let Some(indexes) = self.values.get_mut(&precision) {
|
||||
let new_index = indexes.len();
|
||||
indexes.insert(tensor_id, new_index);
|
||||
return FuseArg::BlockLocal {
|
||||
pos: new_index,
|
||||
ty: precision,
|
||||
};
|
||||
}
|
||||
|
||||
let new_index = 0;
|
||||
self.values
|
||||
.insert(precision, BTreeMap::from_iter([(tensor_id, new_index)]));
|
||||
|
||||
FuseArg::BlockLocal {
|
||||
pos: new_index,
|
||||
ty: precision,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
use super::{
|
||||
super::{
|
||||
codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo},
|
||||
settings::FuseSettings,
|
||||
},
|
||||
FuseResources,
|
||||
block::FuseBlockBuilder,
|
||||
};
|
||||
use super::{FuseTrace, RegisteredTensors};
|
||||
use crate::engine::trace::block::QuantInput;
|
||||
use burn_fusion::stream::ScalarId;
|
||||
use burn_ir::{ScalarIr, TensorIr};
|
||||
use burn_std::{DType, Shape};
|
||||
use cubecl::quant::scheme::QuantParam;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// It is responsible to create a [trace](FuseTrace) composed of multiple [blocks](super::block::FuseBlock).
|
||||
///
|
||||
/// It mostly handles the [resources](KernelResources) needed by the generated fused kernel, and
|
||||
/// delegates most of the work to the [block builder](FuseBlockBuilder).
|
||||
pub struct TraceFuser {
|
||||
settings: FuseSettings,
|
||||
pub bool_precision: FuseType,
|
||||
// The tensors returned by the block that don't need to be written to global memory.
|
||||
block_current: FuseBlockBuilder,
|
||||
blocks_previous: Vec<FuseBlockBuilder>,
|
||||
resources: FuseResources,
|
||||
}
|
||||
|
||||
impl TraceFuser {
|
||||
/// Create a new trace builder with the given bool precision and [fuse settings](FuseSettings).
|
||||
pub fn new(bool_precision: FuseType, settings: FuseSettings) -> Self {
|
||||
Self {
|
||||
settings,
|
||||
bool_precision,
|
||||
block_current: FuseBlockBuilder::new(bool_precision, settings),
|
||||
blocks_previous: Default::default(),
|
||||
resources: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of blocks that are closed.
|
||||
pub fn num_previous_blocks(&self) -> usize {
|
||||
self.blocks_previous.len()
|
||||
}
|
||||
|
||||
/// Tag a tensor as dropped.
|
||||
pub fn fuse_dropped(&mut self, tensor: &TensorIr) {
|
||||
self.resources.outputs.update(tensor);
|
||||
self.resources.inputs.update(tensor);
|
||||
self.resources.dropped.insert(tensor.id);
|
||||
}
|
||||
|
||||
/// Register an operation.
|
||||
pub fn fuse_operation(&mut self, op: FuseOp) {
|
||||
self.block_current.ops.push(op);
|
||||
}
|
||||
|
||||
/// The number of operations fused.
|
||||
pub fn num_ops_fused(&self) -> u32 {
|
||||
let mut num_ops_fused = 0;
|
||||
|
||||
for block in self.blocks_previous.iter() {
|
||||
num_ops_fused += block.ops.len();
|
||||
}
|
||||
|
||||
num_ops_fused += self.block_current.ops.len();
|
||||
num_ops_fused as u32
|
||||
}
|
||||
|
||||
/// Close the current block with the given reference shape and creates a new one with new [fusion settings](FuseSettings).
|
||||
pub fn next_block(&mut self, shape_ref: Shape, settings: FuseSettings) {
|
||||
let mut block_new = FuseBlockBuilder::new(self.bool_precision, settings);
|
||||
core::mem::swap(&mut self.block_current, &mut block_new);
|
||||
block_new.shape_ref = shape_ref;
|
||||
self.blocks_previous.push(block_new);
|
||||
self.settings = settings;
|
||||
}
|
||||
|
||||
// Estimate how many bindings are in use right now. This can return more than the actual number
|
||||
// but should never return less.
|
||||
pub fn estimate_bindings(&self) -> u32 {
|
||||
let mut buffers = Vec::new();
|
||||
let mut estimation = 1; // Metadata takes one.
|
||||
|
||||
// We assume we are not going to write multiple times in the same output buffer.
|
||||
for b in self.blocks_previous.iter() {
|
||||
estimation += b.tensor_writes(&self.resources, &mut buffers).len() as u32;
|
||||
}
|
||||
|
||||
estimation += self
|
||||
.block_current
|
||||
.tensor_writes(&self.resources, &mut buffers)
|
||||
.len() as u32;
|
||||
estimation += self.resources.inputs.len() as u32;
|
||||
// One buffer per scalar type for now.
|
||||
estimation += self.resources.scalars.len() as u32;
|
||||
|
||||
estimation
|
||||
}
|
||||
|
||||
/// Tag the [tensor](TensorIr) as received from a previous block.
|
||||
///
|
||||
/// This will avoid reading the input again and instead use le local version when possible.
|
||||
pub fn block_local_input(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
block_pos: usize,
|
||||
global: bool,
|
||||
) -> FuseArg {
|
||||
let block = &mut self.blocks_previous[block_pos];
|
||||
|
||||
let src_arg = match block.multi_block_variable(block_pos, tensor, global) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// We try to read the input if not present.
|
||||
block.input(tensor, &mut self.resources);
|
||||
block
|
||||
.multi_block_variable(block_pos, tensor, global)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
self.resources.outputs.update(tensor);
|
||||
|
||||
if global {
|
||||
self.resources.registers.insert(tensor.id, src_arg.clone());
|
||||
}
|
||||
|
||||
self.block_current
|
||||
.local_inputs
|
||||
.insert(tensor.id, src_arg.clone());
|
||||
src_arg
|
||||
}
|
||||
|
||||
/// Register an output tensor that won't be automatically synced into global memory.
|
||||
///
|
||||
/// It is therefore the responsibility of the operation to write the result to given tensor.
|
||||
pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
|
||||
let arg = self
|
||||
.output(tensor)
|
||||
.expect("Can't add a new output that is already used in an index operation");
|
||||
|
||||
self.resources.outputs_unhandled.push(arg.clone());
|
||||
self.block_current.outputs_unhandled.push(arg.clone());
|
||||
arg
|
||||
}
|
||||
|
||||
/// Register an input tensor that won't be automatically read into a local variable.
|
||||
///
|
||||
/// It is therefore the responsibility of the operation to read the given tensor.
|
||||
pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {
|
||||
if self.resources.indexed.contains_key(&tensor.id) {
|
||||
panic!("Can't add a new input that is already used in an index operation");
|
||||
}
|
||||
|
||||
self.resources.outputs.update(tensor);
|
||||
|
||||
let precision = tensor.dtype.into();
|
||||
// Bool tensors are encoded as bool_precision.
|
||||
let precision_input = match precision {
|
||||
FuseType::Bool => self.bool_precision,
|
||||
_ => precision,
|
||||
};
|
||||
let new_input = self
|
||||
.resources
|
||||
.inputs
|
||||
.insert(precision_input, tensor.clone());
|
||||
let arg = FuseArg::Input(new_input, precision_input, LayoutInfo::Unknown);
|
||||
|
||||
self.resources.inputs_unhandled.push(tensor.id);
|
||||
arg
|
||||
}
|
||||
|
||||
/// Register an input tensor.
|
||||
pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> {
|
||||
if self.resources.indexed.contains_key(&tensor.id) {
|
||||
panic!("Can't add a new input that is already used in an index operation");
|
||||
}
|
||||
self.resources.outputs.update(tensor);
|
||||
|
||||
let precision = tensor.dtype.into();
|
||||
let precision_scales = match tensor.dtype {
|
||||
DType::QFloat(scheme) => match scheme.param {
|
||||
QuantParam::F32 => FuseType::F32,
|
||||
QuantParam::F16 => FuseType::F16,
|
||||
QuantParam::BF16 => FuseType::BF16,
|
||||
QuantParam::UE8M0 | QuantParam::UE4M3 => {
|
||||
unimplemented!("Unsupported fuse precision");
|
||||
}
|
||||
},
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let (new_input, q_index) = self.resources.inputs.insert_quant(tensor.clone());
|
||||
let input = FuseArg::Input(new_input, precision, LayoutInfo::Unknown);
|
||||
let scales = FuseArg::Input(q_index, precision_scales, LayoutInfo::Unknown);
|
||||
|
||||
self.resources.inputs_unhandled.push(tensor.id);
|
||||
Some((input, scales))
|
||||
}
|
||||
|
||||
/// Register an input tensor.
|
||||
pub fn input(&mut self, tensor: &TensorIr) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(_)) {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.resources.outputs.update(tensor);
|
||||
|
||||
self.block_current.input(tensor, &mut self.resources)
|
||||
}
|
||||
|
||||
/// Register an input tensor.
|
||||
pub fn input_quantized(&mut self, tensor: &TensorIr) -> Option<QuantInput> {
|
||||
self.resources.outputs.update(tensor);
|
||||
self.block_current.input_quant(tensor, &mut self.resources)
|
||||
}
|
||||
|
||||
/// Register an output tensor.
|
||||
pub fn output(&mut self, tensor: &TensorIr) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(_)) {
|
||||
return None;
|
||||
}
|
||||
self.block_current.output(tensor, &mut self.resources)
|
||||
}
|
||||
|
||||
/// Register an input that will be accessed using custom indexing with no vectorization.
|
||||
pub fn input_indexed(&mut self, tensor: &TensorIr) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(_)) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(val) = self.resources.indexed.get(&tensor.id) {
|
||||
self.resources.outputs.update(tensor);
|
||||
return Some(val.clone());
|
||||
};
|
||||
|
||||
if self.resources.inputs.get(tensor.id).is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self.resources.outputs.get(tensor.id).is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let input = self.input_unhandled(tensor);
|
||||
self.resources.indexed.insert(tensor.id, input.clone());
|
||||
|
||||
Some(input)
|
||||
}
|
||||
|
||||
/// Register an input with swapped dims.
|
||||
pub fn input_swap_dims(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
output: &TensorIr,
|
||||
dims: (usize, usize),
|
||||
) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(_)) {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.resources.outputs.update(tensor);
|
||||
self.block_current
|
||||
.input_swap_dims(tensor, output, dims, &mut self.resources)
|
||||
}
|
||||
|
||||
/// Register an input that is reshaped.
|
||||
pub fn input_reshaped(&mut self, tensor: &TensorIr, output: &TensorIr) -> Option<FuseArg> {
|
||||
if matches!(tensor.dtype, DType::QFloat(_)) {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.resources.outputs.update(tensor);
|
||||
self.block_current
|
||||
.input_reshaped(tensor, output, &mut self.resources)
|
||||
}
|
||||
|
||||
/// Register a scalar value.
|
||||
pub fn scalar(&mut self, elem: &ScalarIr, dtype: DType) -> FuseArg {
|
||||
let precision = dtype.into();
|
||||
let id = if let ScalarIr::UInt(value) = elem {
|
||||
ScalarId { value: *value }
|
||||
} else {
|
||||
unreachable!() // should always be u64
|
||||
};
|
||||
|
||||
// Bool scalars are encoded as bool_precision.
|
||||
let precision = match precision {
|
||||
FuseType::Bool => self.bool_precision,
|
||||
_ => precision,
|
||||
};
|
||||
let new_index = self.resources.scalars.len();
|
||||
|
||||
self.resources.scalars.push((precision, id.value));
|
||||
FuseArg::Scalar(new_index, precision)
|
||||
}
|
||||
|
||||
/// Finish fusing and returns the created trace.
|
||||
pub fn finish(&mut self, shape_ref: Shape) -> FuseTrace {
|
||||
let mut resources = self.resources.clone();
|
||||
let mut outputs = RegisteredTensors::default();
|
||||
let mut buffers = Vec::new();
|
||||
|
||||
for tensor in resources.buffers.iter() {
|
||||
let (tensor, ty) = tensor.as_normal_tensor().unwrap();
|
||||
outputs.insert(*ty, tensor.clone());
|
||||
}
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
let mut register_block = |block: &FuseBlockBuilder| {
|
||||
let block = block.build(&self.resources, &mut outputs, &mut buffers);
|
||||
blocks.push(block);
|
||||
};
|
||||
|
||||
for block in self.blocks_previous.iter() {
|
||||
register_block(block);
|
||||
}
|
||||
self.block_current.shape_ref = shape_ref;
|
||||
register_block(&self.block_current);
|
||||
|
||||
// We update the output tensors registered to be the ones that are written to in global
|
||||
// memory.
|
||||
resources.outputs = outputs;
|
||||
|
||||
FuseTrace { blocks, resources }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
pub(crate) mod block;
|
||||
|
||||
mod base;
|
||||
mod fuser;
|
||||
|
||||
pub use base::*;
|
||||
pub use fuser::*;
|
||||
@@ -0,0 +1,11 @@
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
pub mod optim;
|
||||
|
||||
mod base;
|
||||
|
||||
pub(crate) mod engine;
|
||||
pub(crate) mod tune;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,63 @@
|
||||
use crate::optim::{
|
||||
elemwise::{ElemwiseOptimization, ElemwiseOptimizationState},
|
||||
matmul::{MatmulOptimization, MatmulOptimizationState},
|
||||
reduce::{ReduceOptimization, ReduceOptimizationState},
|
||||
reduce_broadcasted::{ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationState},
|
||||
};
|
||||
use cubecl::Runtime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Fusion optimization type for cubecl.
|
||||
///
|
||||
/// More optimization variants should be added here.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum CubeOptimization<R: Runtime> {
|
||||
ElementWise(ElemwiseOptimization<R>),
|
||||
Matmul(MatmulOptimization<R>),
|
||||
Reduce(ReduceOptimization<R>),
|
||||
ReduceBroadcasted(ReduceBroadcastedOptimization<R>),
|
||||
}
|
||||
|
||||
impl<R: Runtime> core::fmt::Debug for CubeOptimization<R> {
    /// Formats the optimization through its serializable state.
    ///
    /// The state's own `Debug` implementation is invoked with the caller's
    /// formatter, so flags such as the alternate (`{:#?}`) form are honored
    /// instead of being reset to the defaults.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Debug::fmt(&self.to_opt_state(), f)
    }
}
|
||||
|
||||
impl<R: Runtime> CubeOptimization<R> {
|
||||
/// Serializes the current optimization to its state.
|
||||
pub fn to_opt_state(&self) -> CubeOptimizationState {
|
||||
match self {
|
||||
Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()),
|
||||
Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()),
|
||||
Self::Reduce(value) => CubeOptimizationState::Reduce(value.to_state()),
|
||||
Self::ReduceBroadcasted(value) => {
|
||||
CubeOptimizationState::ReduceBroadcasted(value.to_state())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> burn_fusion::NumOperations for CubeOptimization<R> {
    /// Number of operations fused by the wrapped optimization.
    fn len(&self) -> usize {
        match self {
            Self::ElementWise(optim) => optim.num_ops_fused(),
            Self::Matmul(optim) => optim.num_ops_fused(),
            Self::Reduce(optim) => optim.num_ops_fused(),
            Self::ReduceBroadcasted(optim) => optim.num_ops_fused(),
        }
    }
}
|
||||
|
||||
/// Fusion optimization state type for cubecl.
|
||||
///
|
||||
/// More optimization variants should be added here.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum CubeOptimizationState {
|
||||
ElementWise(ElemwiseOptimizationState),
|
||||
Matmul(MatmulOptimizationState),
|
||||
Reduce(ReduceOptimizationState),
|
||||
ReduceBroadcasted(ReduceBroadcastedOptimizationState),
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
use super::optimization::ElemwiseOptimization;
|
||||
use crate::{
|
||||
engine::{
|
||||
codegen::ir::FuseType,
|
||||
fuser::TraceOperationFuser,
|
||||
settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},
|
||||
},
|
||||
optim::CubeOptimization,
|
||||
};
|
||||
use burn_fusion::OperationFuser;
|
||||
use burn_std::Shape;
|
||||
use cubecl::Runtime;
|
||||
|
||||
/// Fuses element wise operations.
|
||||
pub struct ElementWiseFuser<R: Runtime> {
|
||||
fuser: TraceOperationFuser,
|
||||
device: R::Device,
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for ElementWiseFuser<R> {
    /// Clones both the trace fuser and the device.
    fn clone(&self) -> Self {
        let Self { fuser, device } = self;

        Self {
            fuser: fuser.clone(),
            device: device.clone(),
        }
    }
}
|
||||
|
||||
impl<R: Runtime> ElementWiseFuser<R> {
|
||||
pub fn shape_id(&self) -> Shape {
|
||||
self.fuser.current_output_shape.clone()
|
||||
}
|
||||
pub fn new(device: R::Device, bool_precision: FuseType) -> Self {
|
||||
let client = R::client(&device);
|
||||
let props = client.properties();
|
||||
let max_bindings = props.hardware.max_bindings;
|
||||
|
||||
Self {
|
||||
fuser: TraceOperationFuser::new(
|
||||
max_bindings,
|
||||
bool_precision,
|
||||
FuseSettings {
|
||||
broadcast: true,
|
||||
output_shape_updates: true,
|
||||
inplace: true,
|
||||
vectorization: VectorizationSetting::Activated,
|
||||
ref_layout: RefLayoutSetting::Any,
|
||||
},
|
||||
),
|
||||
device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for ElementWiseFuser<R> {
|
||||
fn fuse(&mut self, operation: &burn_ir::OperationIr) {
|
||||
self.fuser.fuse(operation);
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> CubeOptimization<R> {
|
||||
let client = R::client(&self.device);
|
||||
let trace = self.fuser.finish();
|
||||
let elementwise = ElemwiseOptimization::new(trace, client, self.device.clone(), self.len());
|
||||
|
||||
CubeOptimization::ElementWise(elementwise)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.fuser.reset()
|
||||
}
|
||||
|
||||
fn status(&self) -> burn_fusion::FuserStatus {
|
||||
self.fuser.status()
|
||||
}
|
||||
|
||||
fn properties(&self) -> burn_fusion::FuserProperties {
|
||||
self.fuser.properties()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.fuser.len()
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod fuser;
|
||||
mod optimization;
|
||||
|
||||
pub use fuser::*;
|
||||
pub use optimization::*;
|
||||
@@ -0,0 +1,140 @@
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::{
|
||||
codegen::{
|
||||
io::ref_len,
|
||||
ir::{
|
||||
FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsLaunch, RefLayout,
|
||||
multi_block_variables_init,
|
||||
},
|
||||
kernel::{fuse_on_write, init_locals},
|
||||
},
|
||||
launch::{
|
||||
FuseTraceLauncher,
|
||||
runner::{TraceRunner, Vectorization},
|
||||
},
|
||||
trace::FuseTrace,
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use cubecl::{CubeDim, calculate_cube_count_elemwise, client::ComputeClient, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(new)]
|
||||
/// Fuse element wise operations into a single kernel.
|
||||
pub struct ElemwiseOptimization<R: Runtime> {
|
||||
pub(crate) trace: FuseTrace,
|
||||
client: ComputeClient<R>,
|
||||
device: R::Device,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
/// State for the [elemwise optimization](ElemwiseOptimization).
|
||||
pub struct ElemwiseOptimizationState {
|
||||
trace: FuseTrace,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl<R: Runtime> ElemwiseOptimization<R> {
|
||||
/// Execute the optimization.
|
||||
pub fn execute<BT: CubeElement>(&self, context: &mut Context<'_, CubeFusionHandle<R>>) {
|
||||
let launcher = FuseTraceLauncher::new(&self.trace, &ElemwiseRunner);
|
||||
|
||||
match launcher.launch::<BT>(&self.client, &self.device, context) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
panic!("{err:?} - {:?}", self.trace);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of element wise operations fused.
|
||||
pub fn num_ops_fused(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Create an optimization from its [state](ElemwiseOptimizationState).
|
||||
pub fn from_state(device: &R::Device, state: ElemwiseOptimizationState) -> Self {
|
||||
Self {
|
||||
trace: state.trace,
|
||||
len: state.len,
|
||||
client: R::client(device),
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the optimization to its [state](ElemwiseOptimizationState).
|
||||
pub fn to_state(&self) -> ElemwiseOptimizationState {
|
||||
ElemwiseOptimizationState {
|
||||
trace: self.trace.clone(),
|
||||
len: self.len,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ElemwiseRunner;
|
||||
|
||||
impl<R: Runtime> Vectorization<R> for ElemwiseRunner {}
|
||||
impl<R: Runtime> TraceRunner<R> for ElemwiseRunner {
|
||||
type Error = LaunchError; // No error possible
|
||||
|
||||
fn run<'a>(
|
||||
&'a self,
|
||||
client: &'a ComputeClient<R>,
|
||||
inputs: GlobalArgsLaunch<'a, R>,
|
||||
outputs: GlobalArgsLaunch<'a, R>,
|
||||
configs: &[FuseBlockConfig],
|
||||
) -> Result<(), Self::Error> {
|
||||
let config = &configs[0];
|
||||
let shape = match &config.ref_layout {
|
||||
RefLayout::Concrete(arg) => match arg {
|
||||
FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank),
|
||||
FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank),
|
||||
_ => panic!("Invalid concreate ref layout"),
|
||||
},
|
||||
RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank),
|
||||
};
|
||||
let working_units = shape.iter().product::<usize>() / config.width;
|
||||
let cube_dim = CubeDim::new(client, working_units);
|
||||
let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
|
||||
let address_type = inputs
|
||||
.required_address_type()
|
||||
.max(outputs.required_address_type());
|
||||
|
||||
unsafe {
|
||||
elemwise_fuse::launch_unchecked(
|
||||
client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
address_type,
|
||||
inputs,
|
||||
outputs,
|
||||
config.clone(),
|
||||
)?;
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
fn elemwise_fuse(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
) {
|
||||
// We write no values for this fusion.
|
||||
let values = Registry::<FuseArg, Line<f32>>::new();
|
||||
let args = comptime![Vec::<FuseArg>::new()];
|
||||
let pos = ABSOLUTE_POS;
|
||||
|
||||
multi_block_variables_init(config, &mut outputs.variables);
|
||||
|
||||
let mut locals = init_locals(inputs, outputs, config);
|
||||
let length = ref_len(inputs, outputs, &locals, config);
|
||||
|
||||
if pos < length {
|
||||
fuse_on_write::<f32>(inputs, outputs, &mut locals, pos, values, args, config)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,548 @@
|
||||
use crate::engine::codegen::{
|
||||
io::ref_line_size,
|
||||
ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, LocalArgs, multi_block_variables_init},
|
||||
kernel::init_locals,
|
||||
view::{FusedOutput, GlobalInput, GlobalInputExpand},
|
||||
};
|
||||
use cubecl::{
|
||||
intrinsic,
|
||||
prelude::*,
|
||||
quant::scheme::{QuantLevel, QuantScheme},
|
||||
std::{
|
||||
FastDivmod,
|
||||
quant::{
|
||||
RunWithQuantType,
|
||||
view::{QuantizedView, run_with_quant_type},
|
||||
},
|
||||
tensor::{
|
||||
View, ViewExpand,
|
||||
layout::{Coords1d, Coords2d, VirtualLayout},
|
||||
},
|
||||
},
|
||||
};
|
||||
use cubek::matmul::{
|
||||
components::global::memory::{
|
||||
BatchLayout, BlockScaledLayout, GlobalLayout, GlobalLayoutConfig, GlobalLayoutExpand,
|
||||
GlobalScaleLayout, GlobalScaleLayoutExpand, NoopLayout,
|
||||
},
|
||||
definition::MatrixLayout,
|
||||
launch::{BatchedCoords, MatmulArgs},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Marker type implementing `MatmulArgs` on top of the fusion arguments.
#[derive(Clone)]
pub struct FusedMatmulArgs;
|
||||
|
||||
#[derive(CubeLaunch, CubeType)]
|
||||
pub struct FusedMatmulInput {
|
||||
global: GlobalArgs,
|
||||
#[cube(comptime)]
|
||||
config: FuseBlockConfig,
|
||||
#[cube(comptime)]
|
||||
a: MatmulArg,
|
||||
#[cube(comptime)]
|
||||
b: MatmulArg,
|
||||
#[cube(comptime)]
|
||||
c: Option<MatmulArg>,
|
||||
#[cube(comptime)]
|
||||
out: FuseArg,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl MatmulArgs for FusedMatmulArgs {
|
||||
type Output<EO: Numeric> = GlobalArgs;
|
||||
type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = FusedMatmulInput;
|
||||
type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = FusedMatmulState;
|
||||
type Config = ();
|
||||
|
||||
fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
inputs: &Self::Input<Lhs, Rhs, EO>,
|
||||
outputs: &mut Self::Output<EO>,
|
||||
_config: (),
|
||||
#[comptime] lhs_layout_config: GlobalLayoutConfig,
|
||||
#[comptime] rhs_layout_config: GlobalLayoutConfig,
|
||||
#[comptime] out_layout_config: GlobalLayoutConfig,
|
||||
) -> Self::State<Lhs, Rhs, EO> {
|
||||
multi_block_variables_init(&inputs.config, &mut outputs.variables);
|
||||
|
||||
let mut locals = init_locals(&inputs.global, outputs, &inputs.config);
|
||||
let rank = comptime![inputs.config.rank];
|
||||
|
||||
let mut batch_shape = Sequence::new();
|
||||
let mut batch_strides_out = Sequence::new();
|
||||
|
||||
#[unroll]
|
||||
for i in 0..rank - 2 {
|
||||
batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i] as u32));
|
||||
batch_strides_out.push(locals.ref_strides[i]);
|
||||
}
|
||||
|
||||
let batch_lhs = input_batch_layout(
|
||||
&inputs.global,
|
||||
&batch_shape,
|
||||
comptime![inputs.a.clone()],
|
||||
comptime![inputs.config.clone()],
|
||||
);
|
||||
let batch_rhs = input_batch_layout(
|
||||
&inputs.global,
|
||||
&batch_shape,
|
||||
comptime![inputs.b.clone()],
|
||||
comptime![inputs.config.clone()],
|
||||
);
|
||||
let batch_acc = match comptime![inputs.c.clone()] {
|
||||
Some(c) => Option::Some(input_batch_layout(
|
||||
&inputs.global,
|
||||
&batch_shape,
|
||||
comptime![c],
|
||||
comptime![inputs.config.clone()],
|
||||
)),
|
||||
None => Option::new_None(),
|
||||
};
|
||||
let batch_out = BatchLayout::new(batch_strides_out, batch_shape.clone());
|
||||
|
||||
FusedMatmulState::new(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
batch_lhs,
|
||||
batch_rhs,
|
||||
batch_acc,
|
||||
VirtualLayout::new::<BatchLayout>(batch_out),
|
||||
batch_shape,
|
||||
&inputs.config,
|
||||
lhs_layout_config,
|
||||
rhs_layout_config,
|
||||
out_layout_config,
|
||||
)
|
||||
}
|
||||
|
||||
fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
) -> View<Line<Lhs>, BatchedCoords> {
|
||||
global_view(
|
||||
&state.inputs,
|
||||
&state.locals,
|
||||
&state.batch_shape,
|
||||
comptime![state.a.clone()],
|
||||
comptime![state.config.clone()],
|
||||
state.lhs_layout_config,
|
||||
)
|
||||
}
|
||||
|
||||
fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
batch: usize,
|
||||
) -> usize {
|
||||
state.a_batch.to_source_pos(batch)
|
||||
}
|
||||
|
||||
fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
) -> View<Line<Rhs>, BatchedCoords> {
|
||||
global_view(
|
||||
&state.inputs,
|
||||
&state.locals,
|
||||
&state.batch_shape,
|
||||
comptime![state.b.clone()],
|
||||
comptime![state.config.clone()],
|
||||
comptime![state.rhs_layout_config],
|
||||
)
|
||||
}
|
||||
|
||||
fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
batch: usize,
|
||||
) -> usize {
|
||||
state.b_batch.to_source_pos(batch)
|
||||
}
|
||||
|
||||
fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
) -> Option<View<Line<EO>, BatchedCoords>> {
|
||||
match comptime![state.c.clone()] {
|
||||
Some(c) => {
|
||||
let view = global_view(
|
||||
&state.inputs,
|
||||
&state.locals,
|
||||
&state.batch_shape,
|
||||
c,
|
||||
comptime![state.config.clone()],
|
||||
comptime![state.out_layout_config],
|
||||
);
|
||||
Option::Some(view)
|
||||
}
|
||||
None => Option::new_None(),
|
||||
}
|
||||
}
|
||||
|
||||
fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
batch: usize,
|
||||
) -> usize {
|
||||
match state.c_batch {
|
||||
Some(c_batch) => c_batch.to_source_pos(batch),
|
||||
None => batch,
|
||||
}
|
||||
}
|
||||
|
||||
fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &mut Self::State<Lhs, Rhs, EO>,
|
||||
) -> View<Line<EO>, BatchedCoords, ReadWrite> {
|
||||
let rank = comptime![state.config.rank];
|
||||
|
||||
let shape_row = state.locals.ref_shape[rank - 2] as u32;
|
||||
let shape_col = state.locals.ref_shape[rank - 1] as u32;
|
||||
|
||||
let stride_row = state.locals.ref_strides[rank - 2];
|
||||
let stride_col = state.locals.ref_strides[rank - 1];
|
||||
|
||||
let layout = GlobalLayout::new(
|
||||
VirtualLayout::new::<NoopLayout>(NoopLayout::new()),
|
||||
shape_row,
|
||||
shape_col,
|
||||
stride_row,
|
||||
stride_col,
|
||||
ref_line_size(&state.locals),
|
||||
1u32,
|
||||
state.out_layout_config,
|
||||
);
|
||||
let mut buffer = FusedOutput::new(
|
||||
&state.inputs,
|
||||
&mut state.outputs,
|
||||
&mut state.locals,
|
||||
comptime![state.out.clone()],
|
||||
comptime![state.config.clone()],
|
||||
);
|
||||
View::new_mut::<FusedOutput, Coords1d>(&mut buffer, layout)
|
||||
}
|
||||
|
||||
fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
|
||||
state: &Self::State<Lhs, Rhs, EO>,
|
||||
batch: usize,
|
||||
) -> usize {
|
||||
state.out_batch.to_source_pos(batch)
|
||||
}
|
||||
|
||||
fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(_state: &Self::State<Lhs, Rhs, EO>) {
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn global_view<E: Numeric>(
|
||||
inputs: &GlobalArgs,
|
||||
locals: &LocalArgs,
|
||||
batch_shape: &Sequence<FastDivmod<u32>>,
|
||||
#[comptime] arg: MatmulArg,
|
||||
#[comptime] config: FuseBlockConfig,
|
||||
#[comptime] layout_config: GlobalLayoutConfig,
|
||||
) -> View<Line<E>, BatchedCoords> {
|
||||
let rank = comptime![config.rank];
|
||||
let data = comptime![arg.data().clone()];
|
||||
let data_tensor = match comptime![data.clone()] {
|
||||
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
|
||||
_ => panic!("Input must be concrete"),
|
||||
};
|
||||
|
||||
let mut shape_row = data_tensor.tensor.shape(rank - 2) as u32;
|
||||
let mut shape_col = data_tensor.tensor.shape(rank - 1) as u32;
|
||||
let mut packing = comptime![1];
|
||||
|
||||
if arg.scheme().is_some() {
|
||||
let scheme = arg.scheme().unwrap();
|
||||
let num_quants = scheme.num_quants() as u32;
|
||||
comptime![packing = num_quants];
|
||||
match comptime![layout_config.matrix_layout] {
|
||||
MatrixLayout::RowMajor => shape_col *= num_quants,
|
||||
MatrixLayout::ColMajor => shape_row *= num_quants,
|
||||
};
|
||||
}
|
||||
|
||||
let shape = (shape_row, shape_col);
|
||||
|
||||
// Noop for normal inputs because batch offset is cached, quantized uses logical batches
|
||||
let batch_layout = match comptime![arg.clone()] {
|
||||
MatmulArg::Normal(_) => VirtualLayout::new::<NoopLayout>(NoopLayout::new()),
|
||||
MatmulArg::Quantized { data, .. } => {
|
||||
let data_arg = comptime![MatmulArg::Normal(data)];
|
||||
input_batch_layout(inputs, batch_shape, data_arg, comptime![config.clone()])
|
||||
}
|
||||
};
|
||||
|
||||
let data_layout = global_layout(
|
||||
inputs,
|
||||
shape,
|
||||
batch_layout,
|
||||
arg.data().clone(),
|
||||
config.clone(),
|
||||
data_tensor.tensor.line_size(),
|
||||
layout_config,
|
||||
packing,
|
||||
);
|
||||
let data_buf = GlobalInput::new(inputs, locals, data, comptime![config.clone()], None);
|
||||
|
||||
match comptime![arg.clone()] {
|
||||
MatmulArg::Normal(_) => View::new::<GlobalInput, Coords1d>(&data_buf, data_layout),
|
||||
MatmulArg::Quantized { scales, scheme, .. } => {
|
||||
let scales_layout = match comptime![scheme.level] {
|
||||
QuantLevel::Tensor => GlobalScaleLayout::new_PerTensor(shape),
|
||||
QuantLevel::Block(block_size) => {
|
||||
let block_size = comptime![block_size.as_dim::<2>()];
|
||||
|
||||
let scales_arg = comptime![MatmulArg::Normal(scales.clone())];
|
||||
let batch_layout = input_batch_layout(
|
||||
inputs,
|
||||
batch_shape,
|
||||
scales_arg,
|
||||
comptime![config.clone()],
|
||||
);
|
||||
|
||||
let scales_layout = global_layout(
|
||||
inputs,
|
||||
shape,
|
||||
batch_layout,
|
||||
comptime![scales.clone()],
|
||||
comptime![config.clone()],
|
||||
1usize,
|
||||
layout_config,
|
||||
1u32,
|
||||
);
|
||||
GlobalScaleLayout::new_BlockScaled(BlockScaledLayout::new(
|
||||
shape,
|
||||
scales_layout,
|
||||
comptime![(block_size[0] as u32, block_size[1] as u32)],
|
||||
))
|
||||
}
|
||||
};
|
||||
let scales_buf = GlobalInput::new(inputs, locals, scales, config, None);
|
||||
create_quant_view_dynamic(data_buf, data_layout, scales_buf, scales_layout, scheme)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn input_batch_layout(
|
||||
inputs: &GlobalArgs,
|
||||
batch_shape: &Sequence<FastDivmod<u32>>,
|
||||
#[comptime] arg: MatmulArg,
|
||||
#[comptime] config: FuseBlockConfig,
|
||||
) -> VirtualLayout<usize, usize> {
|
||||
let rank = comptime![config.rank];
|
||||
match comptime![arg.clone()] {
|
||||
MatmulArg::Normal(arg) => {
|
||||
let data_tensor = match comptime![arg.clone()] {
|
||||
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
|
||||
_ => panic!("Input must be concrete"),
|
||||
};
|
||||
|
||||
let mut batch_strides = Sequence::new();
|
||||
#[unroll]
|
||||
for i in 0..rank - 2 {
|
||||
let shape = data_tensor.tensor.shape(i);
|
||||
let stride = select(shape == 1, 0, data_tensor.tensor.stride(i));
|
||||
batch_strides.push(stride);
|
||||
}
|
||||
|
||||
VirtualLayout::new::<BatchLayout>(BatchLayout::new(batch_strides, batch_shape.clone()))
|
||||
}
|
||||
MatmulArg::Quantized { .. } => VirtualLayout::new::<NoopLayout>(NoopLayout::new()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn global_layout(
|
||||
inputs: &GlobalArgs,
|
||||
shape: Coords2d,
|
||||
batch_layout: VirtualLayout<usize, usize>,
|
||||
#[comptime] arg: FuseArg,
|
||||
#[comptime] config: FuseBlockConfig,
|
||||
#[comptime] line_size: LineSize,
|
||||
#[comptime] layout_config: GlobalLayoutConfig,
|
||||
#[comptime] packing: u32,
|
||||
) -> GlobalLayout {
|
||||
let rank = comptime![config.rank];
|
||||
let data_tensor = match comptime![arg.clone()] {
|
||||
FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
|
||||
_ => panic!("Input must be concrete"),
|
||||
};
|
||||
|
||||
let (shape_row, shape_col) = shape;
|
||||
|
||||
let stride_row = data_tensor.tensor.stride(rank - 2);
|
||||
let stride_col = data_tensor.tensor.stride(rank - 1);
|
||||
|
||||
GlobalLayout::new(
|
||||
batch_layout,
|
||||
shape_row,
|
||||
shape_col,
|
||||
stride_row,
|
||||
stride_col,
|
||||
line_size,
|
||||
packing,
|
||||
layout_config,
|
||||
)
|
||||
}
|
||||
|
||||
struct CreateQuantView<'a, E: Numeric> {
|
||||
scope: &'a mut Scope,
|
||||
data_buf: GlobalInputExpand,
|
||||
data_layout: GlobalLayoutExpand,
|
||||
scales_buf: GlobalInputExpand,
|
||||
scales_layout: GlobalScaleLayoutExpand,
|
||||
scheme: QuantScheme,
|
||||
_ty: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<'a, E: Numeric> RunWithQuantType for CreateQuantView<'a, E> {
|
||||
type Output = ViewExpand<Line<E>, BatchedCoords>;
|
||||
|
||||
fn execute<Q: CubePrimitive, S: CubePrimitive>(self) -> Self::Output {
|
||||
create_quant_view::expand::<E, Q, S>(
|
||||
self.scope,
|
||||
self.data_buf,
|
||||
self.data_layout,
|
||||
self.scales_buf,
|
||||
self.scales_layout,
|
||||
self.scheme,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(unused)]
|
||||
fn create_quant_view_dynamic<E: Numeric>(
|
||||
data_buf: GlobalInput,
|
||||
data_layout: GlobalLayout,
|
||||
scales_buf: GlobalInput,
|
||||
scales_layout: GlobalScaleLayout,
|
||||
#[comptime] scheme: QuantScheme,
|
||||
) -> View<Line<E>, BatchedCoords> {
|
||||
intrinsic!(|scope| {
|
||||
let func = CreateQuantView {
|
||||
scope,
|
||||
data_buf,
|
||||
data_layout,
|
||||
scales_buf,
|
||||
scales_layout,
|
||||
scheme,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
run_with_quant_type(func, scheme)
|
||||
})
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn create_quant_view<E: Numeric, Q: CubePrimitive, S: CubePrimitive>(
|
||||
data_buf: GlobalInput,
|
||||
data_layout: GlobalLayout,
|
||||
scales_buf: GlobalInput,
|
||||
scales_layout: GlobalScaleLayout,
|
||||
#[comptime] scheme: QuantScheme,
|
||||
) -> View<Line<E>, BatchedCoords> {
|
||||
let data_view: View<Line<Q>, BatchedCoords> =
|
||||
View::new::<GlobalInput, Coords1d>(&data_buf, data_layout);
|
||||
let scales_view: View<S, BatchedCoords> =
|
||||
View::new::<GlobalInput, Coords1d>(&scales_buf, scales_layout);
|
||||
QuantizedView::new(data_view, scales_view, scheme).view()
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
pub struct FusedMatmulState {
|
||||
inputs: GlobalArgs,
|
||||
outputs: GlobalArgs,
|
||||
locals: LocalArgs,
|
||||
a_batch: VirtualLayout<Coords1d, Coords1d>,
|
||||
b_batch: VirtualLayout<Coords1d, Coords1d>,
|
||||
c_batch: Option<VirtualLayout<Coords1d, Coords1d>>,
|
||||
out_batch: VirtualLayout<Coords1d, Coords1d>,
|
||||
#[cube(comptime)]
|
||||
config: FuseBlockConfig,
|
||||
#[cube(comptime)]
|
||||
a: MatmulArg,
|
||||
#[cube(comptime)]
|
||||
b: MatmulArg,
|
||||
#[cube(comptime)]
|
||||
c: Option<MatmulArg>,
|
||||
#[cube(comptime)]
|
||||
out: FuseArg,
|
||||
#[cube(comptime)]
|
||||
lhs_layout_config: GlobalLayoutConfig,
|
||||
#[cube(comptime)]
|
||||
rhs_layout_config: GlobalLayoutConfig,
|
||||
#[cube(comptime)]
|
||||
out_layout_config: GlobalLayoutConfig,
|
||||
batch_shape: Sequence<FastDivmod<u32>>,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl FusedMatmulState {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
inputs: &FusedMatmulInput,
|
||||
outputs: &mut GlobalArgs,
|
||||
locals: &mut LocalArgs,
|
||||
a_batch: VirtualLayout<usize, usize>,
|
||||
b_batch: VirtualLayout<usize, usize>,
|
||||
c_batch: Option<VirtualLayout<usize, usize>>,
|
||||
out_batch: VirtualLayout<usize, usize>,
|
||||
batch_shape: Sequence<FastDivmod<u32>>,
|
||||
#[comptime] config: &FuseBlockConfig,
|
||||
#[comptime] lhs_layout_config: GlobalLayoutConfig,
|
||||
#[comptime] rhs_layout_config: GlobalLayoutConfig,
|
||||
#[comptime] out_layout_config: GlobalLayoutConfig,
|
||||
) -> FusedMatmulState {
|
||||
FusedMatmulState {
|
||||
inputs: inputs.global.clone(),
|
||||
outputs: outputs.clone(),
|
||||
config: comptime![config.clone()],
|
||||
locals: locals.clone(),
|
||||
a_batch,
|
||||
b_batch,
|
||||
c_batch,
|
||||
out_batch,
|
||||
a: comptime![inputs.a.clone()],
|
||||
b: comptime![inputs.b.clone()],
|
||||
c: comptime![inputs.c.clone()],
|
||||
out: comptime![inputs.out.clone()],
|
||||
lhs_layout_config,
|
||||
rhs_layout_config,
|
||||
out_layout_config,
|
||||
batch_shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
/// Argument to a matmul operation.
|
||||
pub enum MatmulArg {
|
||||
Normal(FuseArg),
|
||||
Quantized {
|
||||
data: FuseArg,
|
||||
scales: FuseArg,
|
||||
precision: FuseType,
|
||||
scheme: QuantScheme,
|
||||
},
|
||||
}
|
||||
|
||||
impl MatmulArg {
|
||||
pub fn data(&self) -> &FuseArg {
|
||||
match self {
|
||||
MatmulArg::Normal(arg) => arg,
|
||||
MatmulArg::Quantized { data, .. } => data,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scheme(&self) -> Option<&QuantScheme> {
|
||||
match self {
|
||||
MatmulArg::Normal(_) => None,
|
||||
MatmulArg::Quantized { scheme, .. } => Some(scheme),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn precision(&self) -> FuseType {
|
||||
match self {
|
||||
MatmulArg::Normal(arg) => arg.precision(),
|
||||
MatmulArg::Quantized { precision, .. } => *precision,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
use super::optimization::{FusedMatmul, MatmulOptimization};
|
||||
use crate::{
|
||||
engine::{codegen::ir::FuseType, fuser::TraceOperationFuser, settings::FuseSettings},
|
||||
optim::CubeOptimization,
|
||||
optim::matmul::args::MatmulArg,
|
||||
};
|
||||
use burn_fusion::{FuserStatus, OperationFuser};
|
||||
use burn_ir::{FloatOperationIr, OperationIr};
|
||||
use burn_std::DType;
|
||||
use cubecl::Runtime;
|
||||
|
||||
/// Fused element wise operations that are normally memory bound.
|
||||
pub struct MatmulFuser<R: Runtime> {
|
||||
fuser: TraceOperationFuser,
|
||||
fuser_fallback: TraceOperationFuser,
|
||||
device: R::Device,
|
||||
matmul: Option<FusedMatmul>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for MatmulFuser<R> {
    /// Clones every field.
    fn clone(&self) -> Self {
        let Self {
            fuser,
            fuser_fallback,
            device,
            matmul,
        } = self;

        Self {
            fuser: fuser.clone(),
            fuser_fallback: fuser_fallback.clone(),
            device: device.clone(),
            matmul: matmul.clone(),
        }
    }
}
|
||||
|
||||
impl<R: Runtime> MatmulFuser<R> {
|
||||
pub fn new(device: R::Device, bool_precision: FuseType) -> Self {
|
||||
let client = R::client(&device);
|
||||
let props = client.properties();
|
||||
let max_bindings = props.hardware.max_bindings;
|
||||
let settings_matmul = FuseSettings {
|
||||
output_shape_updates: false,
|
||||
..Default::default()
|
||||
};
|
||||
let settings_fallback = FuseSettings::default();
|
||||
|
||||
Self {
|
||||
fuser: TraceOperationFuser::new(max_bindings, bool_precision, settings_matmul),
|
||||
fuser_fallback: TraceOperationFuser::new(
|
||||
max_bindings,
|
||||
bool_precision,
|
||||
settings_fallback,
|
||||
),
|
||||
device,
|
||||
matmul: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for MatmulFuser<R> {
|
||||
fn fuse(&mut self, operation: &OperationIr) {
|
||||
if let FuserStatus::Closed = self.fuser.status() {
|
||||
return;
|
||||
}
|
||||
|
||||
if self.matmul.is_none() {
|
||||
if let OperationIr::Float(_, FloatOperationIr::Matmul(op)) = operation {
|
||||
// Precision shouldn't be hardcoded but I don't know how to get float precision of the backend
|
||||
let lhs = match op.lhs.dtype {
|
||||
DType::QFloat(scheme) => {
|
||||
let (data, scales) = self.fuser.input_quantized_unhandled(&op.lhs).unwrap();
|
||||
MatmulArg::Quantized {
|
||||
data,
|
||||
scales,
|
||||
precision: op.out.dtype.into(),
|
||||
scheme,
|
||||
}
|
||||
}
|
||||
_ => MatmulArg::Normal(self.fuser.input_unhandled(&op.lhs)),
|
||||
};
|
||||
let rhs = match op.rhs.dtype {
|
||||
DType::QFloat(scheme) => {
|
||||
let (data, scales) = self.fuser.input_quantized_unhandled(&op.rhs).unwrap();
|
||||
MatmulArg::Quantized {
|
||||
data,
|
||||
scales,
|
||||
precision: op.out.dtype.into(),
|
||||
scheme,
|
||||
}
|
||||
}
|
||||
_ => MatmulArg::Normal(self.fuser.input_unhandled(&op.rhs)),
|
||||
};
|
||||
|
||||
let out = self.fuser.output_unhandled(&op.out);
|
||||
|
||||
self.matmul = Some(FusedMatmul::new(
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
op.clone().into(),
|
||||
Default::default(),
|
||||
));
|
||||
} else {
|
||||
self.fuser.close();
|
||||
self.fuser_fallback.close();
|
||||
}
|
||||
} else {
|
||||
let can_register =
|
||||
self.fuser.can_fuse(operation) && self.fuser_fallback.can_fuse(operation);
|
||||
|
||||
match can_register {
|
||||
true => {
|
||||
self.fuser.fuse(operation);
|
||||
self.fuser_fallback.fuse(operation);
|
||||
}
|
||||
false => {
|
||||
self.fuser.close();
|
||||
self.fuser_fallback.close();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> CubeOptimization<R> {
|
||||
let client = R::client(&self.device);
|
||||
let trace = self.fuser.finish();
|
||||
let trace_fallback = self.fuser_fallback.finish();
|
||||
|
||||
let matmul = MatmulOptimization::new(
|
||||
trace,
|
||||
trace_fallback,
|
||||
client,
|
||||
self.device.clone(),
|
||||
self.len(),
|
||||
self.matmul.as_ref().unwrap().clone(),
|
||||
);
|
||||
|
||||
CubeOptimization::Matmul(matmul)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.fuser.reset();
|
||||
self.fuser_fallback.reset();
|
||||
self.matmul = None;
|
||||
}
|
||||
|
||||
fn status(&self) -> burn_fusion::FuserStatus {
|
||||
self.fuser.status()
|
||||
}
|
||||
|
||||
fn properties(&self) -> burn_fusion::FuserProperties {
|
||||
let mut properties = self.fuser.properties();
|
||||
properties.score += 1;
|
||||
properties
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
// Matmul operation isn't registered in the fuser
|
||||
self.fuser.len() + 1
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod fuser;
|
||||
mod optimization;
|
||||
|
||||
pub(crate) mod args;
|
||||
pub(crate) mod tune;
|
||||
|
||||
pub use fuser::*;
|
||||
pub use optimization::*;
|
||||
@@ -0,0 +1,649 @@
|
||||
use super::args::FusedMatmulInputLaunch;
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::tune::fused_matmul_autotune;
|
||||
use crate::{
|
||||
CubeFusionHandle, FallbackOperation,
|
||||
engine::{
|
||||
codegen::ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout},
|
||||
launch::{
|
||||
FuseTraceLauncher, HandleInput, LaunchPlan,
|
||||
runner::{TraceRunner, Vectorization, VectorizationAxis},
|
||||
},
|
||||
trace::{FuseTrace, TraceError, TuneOutput},
|
||||
},
|
||||
optim::{
|
||||
elemwise::ElemwiseRunner,
|
||||
matmul::args::{FusedMatmulArgs, MatmulArg},
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::BinaryOpIr;
|
||||
use cubecl::{
|
||||
client::ComputeClient,
|
||||
prelude::*,
|
||||
std::tensor::{MatrixBatchLayout, matrix_batch_layout},
|
||||
};
|
||||
use cubek::matmul::{
|
||||
components::tile::{cmma::CmmaMatmul, mma::MmaMatmul},
|
||||
definition::{
|
||||
MatmulElems, MatmulGlobalElems, MatmulLineSizes, MatmulProblem, MatmulSetupError,
|
||||
MatrixLayout,
|
||||
},
|
||||
launch::launch_kernel_virtual,
|
||||
routines::{
|
||||
BlueprintStrategy, Routine,
|
||||
double_buffering::{CyclicDoubleBufferingAlgorithm, DoubleBufferingArgs},
|
||||
double_unit::DoubleUnitAlgorithm,
|
||||
ordered_double_buffering::{OrderedDoubleBufferingAlgorithm, OrderedSelectionArgs},
|
||||
simple::{SimpleAlgorithm, SimpleArgs},
|
||||
simple_unit::SimpleUnitAlgorithm,
|
||||
vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Fuse matmul operation followed by elemwise operations into a single kernel.
///
/// The optimization data lives behind an [`Arc`] so that `execute` can hand a cheap clone of
/// it to a [`MatmulOptimizationTuneArg`].
|
||||
pub struct MatmulOptimization<R: Runtime> {
|
||||
/// Shared traces, client, device and matmul description.
pub(crate) info: Arc<MatmulOptimizationInfo<R>>,
|
||||
}
|
||||
|
||||
/// Argument used to execute (or autotune) a [`MatmulOptimization`]: the shared optimization
/// data plus the fallback matmul operation.
pub struct MatmulOptimizationTuneArg<R: Runtime> {
|
||||
/// Shared traces, client, device and matmul description.
pub(crate) info: Arc<MatmulOptimizationInfo<R>>,
|
||||
/// Operation run first by `execute_fallback`, before the fallback trace is launched.
pub(crate) fallback: Box<dyn FallbackOperation<R>>,
|
||||
}
|
||||
|
||||
/// Data shared between a [`MatmulOptimization`] and its [`MatmulOptimizationTuneArg`].
pub(crate) struct MatmulOptimizationInfo<R: Runtime> {
|
||||
/// Trace launched with a `FusedMatmulLaunch` runner by `execute_fused`.
trace: FuseTrace,
|
||||
/// Trace launched with `ElemwiseRunner` by `execute_fallback`, after the fallback matmul ran.
trace_fallback: FuseTrace,
|
||||
/// Compute client used for every launch.
pub(crate) client: ComputeClient<R>,
|
||||
/// Device the optimization runs on.
pub(crate) device: R::Device,
|
||||
/// Value reported by `num_ops_fused`.
pub(crate) len: usize,
|
||||
/// Description of the fused matmul (arguments, output and original operation).
pub(crate) matmul: FusedMatmul,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
/// Serializable state of a [matmul optimization](MatmulOptimization).
///
/// Produced by `MatmulOptimization::to_state` and consumed by `MatmulOptimization::from_state`;
/// the client and device are not part of the state and are re-created from the device passed
/// to `from_state`.
|
||||
pub struct MatmulOptimizationState {
|
||||
// Trace of the fused path.
trace: FuseTrace,
|
||||
// Trace of the fallback path.
trace_fallback: FuseTrace,
|
||||
// Description of the fused matmul.
matmul: FusedMatmul,
|
||||
// Number of operations fused.
len: usize,
|
||||
}
|
||||
|
||||
impl<R: Runtime> MatmulOptimizationInfo<R> {
    /// Returns the number of output buffers added by fusion.
    ///
    /// The count is taken from the outputs registered in the fallback trace.
    pub fn num_output_buffers(&self) -> usize {
        let outputs = &self.trace_fallback.resources.outputs;
        outputs.len()
    }

    /// Number of operations fused.
    pub fn num_ops_fused(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }
}
|
||||
|
||||
impl<R: Runtime> MatmulOptimizationTuneArg<R> {
|
||||
pub(crate) fn execute_fused<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
selector: FusedMatmulSelector,
|
||||
) -> Result<TuneOutput<R>, TraceError<FusedMatmulError>> {
|
||||
let launch = FusedMatmulLaunch::new(&self.info.matmul, selector);
|
||||
let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
|
||||
|
||||
launcher.launch::<BT>(&self.info.client, &self.info.device, context)
|
||||
}
|
||||
|
||||
pub fn execute_fallback<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
) -> TuneOutput<R> {
|
||||
self.fallback.run(context);
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
let mut output = TuneOutput::Checked {
|
||||
handles: Default::default(),
|
||||
};
|
||||
#[cfg(not(feature = "autotune-checks"))]
|
||||
let output = TuneOutput::UnChecked(core::marker::PhantomData);
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
if let TuneOutput::Checked { handles } = &mut output {
|
||||
let out_desc = context.tensors.get(&self.info.matmul.op.out.id).unwrap();
|
||||
let handle_out = context
|
||||
.handles
|
||||
.get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);
|
||||
|
||||
handles.insert(
|
||||
self.info.matmul.op.out.id,
|
||||
(out_desc.shape.dims.clone(), handle_out.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
let launcher = FuseTraceLauncher::new(&self.info.trace_fallback, &ElemwiseRunner);
|
||||
let output_write = launcher
|
||||
.launch::<BT>(&self.info.client, &self.info.device, context)
|
||||
.unwrap();
|
||||
|
||||
output.merge(output_write)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> MatmulOptimization<R> {
|
||||
pub fn new(
|
||||
trace: FuseTrace,
|
||||
trace_fallback: FuseTrace,
|
||||
client: ComputeClient<R>,
|
||||
device: R::Device,
|
||||
len: usize,
|
||||
matmul: FusedMatmul,
|
||||
) -> Self {
|
||||
let info = MatmulOptimizationInfo {
|
||||
trace,
|
||||
trace_fallback,
|
||||
client,
|
||||
device,
|
||||
len,
|
||||
matmul,
|
||||
};
|
||||
|
||||
Self {
|
||||
info: Arc::new(info),
|
||||
}
|
||||
}
|
||||
/// Execute the optimization.
|
||||
pub fn execute<BT: CubeElement>(
|
||||
&mut self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
|
||||
) {
|
||||
// The index of the fallback matmul is always 0.
|
||||
let fallback = fallback(0);
|
||||
let arg = MatmulOptimizationTuneArg {
|
||||
info: self.info.clone(),
|
||||
fallback,
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
fused_matmul_autotune::<R, BT>(arg, context);
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
if arg
|
||||
.execute_fused::<BT>(context, FusedMatmulSelector::default())
|
||||
.is_err()
|
||||
{
|
||||
arg.execute_fallback::<BT>(context);
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of operations fused.
|
||||
pub fn num_ops_fused(&self) -> usize {
|
||||
self.info.num_ops_fused()
|
||||
}
|
||||
|
||||
/// Create an optimization from its [state](MatmulOptimizationState).
|
||||
pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self {
|
||||
let info = MatmulOptimizationInfo {
|
||||
trace: state.trace,
|
||||
trace_fallback: state.trace_fallback,
|
||||
len: state.len,
|
||||
client: R::client(device),
|
||||
device: device.clone(),
|
||||
matmul: state.matmul.clone(),
|
||||
};
|
||||
|
||||
Self {
|
||||
info: Arc::new(info),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the optimization to its [state](MatmulOptimizationState).
|
||||
pub fn to_state(&self) -> MatmulOptimizationState {
|
||||
MatmulOptimizationState {
|
||||
trace: self.info.trace.clone(),
|
||||
trace_fallback: self.info.trace_fallback.clone(),
|
||||
matmul: self.info.matmul.clone(),
|
||||
len: self.info.len,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Selects which matmul routine is used when launching the fused kernel.
///
/// Each variant is dispatched to the matching `cubek` routine by
/// `FusedMatmulLaunch::matmul_fused`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
|
||||
pub enum FusedMatmulSelector {
|
||||
/// `SimpleAlgorithm`; `multi_rows` is forwarded as `SimpleArgs::multi_rows`.
Simple {
|
||||
multi_rows: bool,
|
||||
tile_matmul: AcceleratedTileKind,
|
||||
},
|
||||
/// `CyclicDoubleBufferingAlgorithm`; `specialized` is forwarded as
/// `DoubleBufferingArgs::specialized`.
DoubleBuffering {
|
||||
specialized: bool,
|
||||
tile_matmul: AcceleratedTileKind,
|
||||
},
|
||||
/// `OrderedDoubleBufferingAlgorithm`.
OrderedDoubleBuffering {
|
||||
tile_matmul: AcceleratedTileKind,
|
||||
},
|
||||
/// `SimpleVecMatAlgorithm`.
SimpleVecMat,
|
||||
/// `DoubleVecMatAlgorithm`.
DoubleVecMat,
|
||||
/// `SimpleUnitAlgorithm`.
SimpleUnit,
|
||||
/// `DoubleUnitAlgorithm`.
DoubleUnit,
|
||||
}
|
||||
|
||||
impl FusedMatmulSelector {
|
||||
/// Not efficient, but only called once when initializing the tunables.
|
||||
pub fn name(&self) -> String {
|
||||
let name = match self {
|
||||
FusedMatmulSelector::Simple {
|
||||
multi_rows,
|
||||
tile_matmul,
|
||||
} => match multi_rows {
|
||||
false => format!("simple_{tile_matmul:?}"),
|
||||
true => format!("simple_multirows_{tile_matmul:?}"),
|
||||
},
|
||||
FusedMatmulSelector::DoubleBuffering {
|
||||
specialized,
|
||||
tile_matmul,
|
||||
} => match specialized {
|
||||
false => format!("double_buffering_{tile_matmul:?}"),
|
||||
true => format!("double_buffering_specialized_{tile_matmul:?}"),
|
||||
},
|
||||
FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => {
|
||||
format!("double_buffering_ordered_{tile_matmul:?}").to_lowercase()
|
||||
}
|
||||
FusedMatmulSelector::SimpleVecMat => "simple_vec_mat".into(),
|
||||
FusedMatmulSelector::DoubleVecMat => "double_buffering_vec_mat".into(),
|
||||
FusedMatmulSelector::SimpleUnit => "simple_unit".into(),
|
||||
FusedMatmulSelector::DoubleUnit => "double_buffering_unit".into(),
|
||||
};
|
||||
|
||||
format!("fused_{name}")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FusedMatmulSelector {
|
||||
fn default() -> Self {
|
||||
FusedMatmulSelector::Simple {
|
||||
multi_rows: false,
|
||||
tile_matmul: AcceleratedTileKind::Cmma,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Description of a matmul that is fused with the operations of a trace.
///
/// `#[derive(new)]` generates `FusedMatmul::new(lhs, rhs, out, op, selector)`, so the field
/// order below is part of the constructor's signature.
#[derive(new, Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct FusedMatmul {
|
||||
/// Left-hand side argument (normal or quantized).
pub(crate) lhs: MatmulArg,
|
||||
/// Right-hand side argument (normal or quantized).
pub(crate) rhs: MatmulArg,
|
||||
/// Output argument of the matmul.
out: FuseArg,
|
||||
/// The original matmul operation; its tensor ids are used to look tensors up at launch
/// and when building the autotune key.
pub(crate) op: BinaryOpIr,
|
||||
/// Selector stored with the matmul. NOTE(review): the launch code in this file
/// dispatches on `FusedMatmulLaunch::selector`, not on this field.
pub(crate) selector: FusedMatmulSelector,
|
||||
}
|
||||
|
||||
/// Trace runner that launches a [`FusedMatmul`] with a given [`FusedMatmulSelector`].
#[derive(new)]
|
||||
pub struct FusedMatmulLaunch<'a> {
|
||||
/// The matmul to launch.
pub(crate) matmul: &'a FusedMatmul,
|
||||
/// The routine to launch it with.
pub(crate) selector: FusedMatmulSelector,
|
||||
}
|
||||
|
||||
/// Errors produced when launching a fused matmul.
#[derive(Debug)]
|
||||
pub enum FusedMatmulError {
|
||||
/// The matmul launch itself reported a setup error.
LaunchError(MatmulSetupError),
|
||||
/// The inputs cannot be handled by the fused kernel; the message explains why.
InvalidInput(&'static str),
|
||||
}
|
||||
|
||||
impl From<MatmulSetupError> for FusedMatmulError {
|
||||
fn from(value: MatmulSetupError) -> Self {
|
||||
Self::LaunchError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> Vectorization<R> for FusedMatmulLaunch<'a> {
|
||||
fn axis(&self, plan: &LaunchPlan<'_, R>) -> VectorizationAxis {
|
||||
let lhs_id = self.matmul.op.lhs.id;
|
||||
let rhs_id = self.matmul.op.rhs.id;
|
||||
|
||||
let mut tensor_lhs = None;
|
||||
let mut tensor_rhs = None;
|
||||
|
||||
for input in plan.handle_inputs.iter() {
|
||||
match input {
|
||||
HandleInput::Normal(input) => {
|
||||
if input.relative_id == lhs_id {
|
||||
tensor_lhs = Some((input.global_ir.id, &input.handle.strides));
|
||||
}
|
||||
if input.relative_id == rhs_id {
|
||||
tensor_rhs = Some((input.global_ir.id, &input.handle.strides));
|
||||
}
|
||||
}
|
||||
HandleInput::QuantValues(input) => {
|
||||
if input.relative_id == lhs_id {
|
||||
tensor_lhs = Some((input.global_ir.id, &input.handle.strides));
|
||||
}
|
||||
if input.relative_id == rhs_id {
|
||||
tensor_rhs = Some((input.global_ir.id, &input.handle.strides));
|
||||
}
|
||||
}
|
||||
HandleInput::QuantParams(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
let (lhs_id_global, lhs_strides) = tensor_lhs.unwrap();
|
||||
let (rhs_id_global, rhs_strides) = tensor_rhs.unwrap();
|
||||
|
||||
let mut axis = VectorizationAxis::default();
|
||||
|
||||
if let MatrixBatchLayout::MildlyPermuted { transposed, .. } =
|
||||
matrix_batch_layout(lhs_strides, self.matmul.lhs.scheme())
|
||||
&& transposed
|
||||
{
|
||||
axis.insert(lhs_id_global, lhs_strides.len() - 2);
|
||||
}
|
||||
|
||||
if let MatrixBatchLayout::MildlyPermuted { transposed, .. } =
|
||||
matrix_batch_layout(rhs_strides, self.matmul.rhs.scheme())
|
||||
&& transposed
|
||||
{
|
||||
axis.insert(rhs_id_global, rhs_strides.len() - 2);
|
||||
}
|
||||
|
||||
axis
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> TraceRunner<R> for FusedMatmulLaunch<'_> {
|
||||
type Error = FusedMatmulError;
|
||||
|
||||
fn run<'a>(
|
||||
&'a self,
|
||||
client: &'a ComputeClient<R>,
|
||||
inputs: GlobalArgsLaunch<'a, R>,
|
||||
outputs: GlobalArgsLaunch<'a, R>,
|
||||
configs: &'a [FuseBlockConfig],
|
||||
) -> Result<(), FusedMatmulError> {
|
||||
let global_elems = MatmulGlobalElems {
|
||||
lhs: self.matmul.lhs.precision().into_type(),
|
||||
rhs: self.matmul.rhs.precision().into_type(),
|
||||
out: self.matmul.out.precision().into_type(),
|
||||
};
|
||||
let dtypes = MatmulElems::from_globals(&global_elems);
|
||||
self.matmul_fused(client, inputs, outputs, &configs[0], dtypes)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// Which tile matmul to use for accelerated algorithms
///
/// The `with_tile_kind!` macro maps each variant to a concrete tile matmul type.
|
||||
pub enum AcceleratedTileKind {
|
||||
/// Mapped to `CmmaMatmul`; this is the default.
#[default]
|
||||
Cmma,
|
||||
/// Mapped to `MmaMatmul`.
Mma,
|
||||
}
|
||||
|
||||
// Dispatches on an `AcceleratedTileKind` value.
//
// `$kind` is the value to match on, `$T` is the name of a type alias that is bound to the
// concrete tile matmul (`CmmaMatmul` or `MmaMatmul`) inside each arm, and `$launch` is a
// zero-argument callable expression that may refer to `$T`. The macro evaluates to whatever
// `$launch` returns.
macro_rules! with_tile_kind {
|
||||
($kind: expr, $T: ident, $launch: expr) => {
|
||||
match $kind {
|
||||
AcceleratedTileKind::Cmma => {
|
||||
// Alias visible to `$launch` because `$T` comes from the call site.
type $T = CmmaMatmul;
|
||||
($launch)()
|
||||
}
|
||||
AcceleratedTileKind::Mma => {
|
||||
type $T = MmaMatmul;
|
||||
($launch)()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl FusedMatmulLaunch<'_> {
|
||||
fn matmul_fused<'a, R: Runtime>(
|
||||
&'a self,
|
||||
client: &'a ComputeClient<R>,
|
||||
inputs: GlobalArgsLaunch<'a, R>,
|
||||
outputs: GlobalArgsLaunch<'a, R>,
|
||||
config: &'a FuseBlockConfig,
|
||||
dtypes: MatmulElems,
|
||||
) -> Result<(), FusedMatmulError> {
|
||||
let lhs_shape = inputs.shape(self.matmul.lhs.data());
|
||||
let rhs_shape = inputs.shape(self.matmul.rhs.data());
|
||||
let out_shape = outputs.shape_ref(&config.ref_layout, config.rank);
|
||||
|
||||
let lhs_strides = inputs.strides(self.matmul.lhs.data());
|
||||
let lhs_scheme = self.matmul.lhs.scheme();
|
||||
let rhs_strides = inputs.strides(self.matmul.rhs.data());
|
||||
let rhs_scheme = self.matmul.rhs.scheme();
|
||||
|
||||
if matrix_batch_layout(&lhs_strides, lhs_scheme) == MatrixBatchLayout::HighlyPermuted {
|
||||
return Err(FusedMatmulError::InvalidInput(
|
||||
"Lhs needs to be contiguous, but can't when fusing.",
|
||||
));
|
||||
}
|
||||
if matrix_batch_layout(&rhs_strides, rhs_scheme) == MatrixBatchLayout::HighlyPermuted {
|
||||
return Err(FusedMatmulError::InvalidInput(
|
||||
"Rhs needs to be contiguous, but can't when fusing.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut line_sizes = MatmulLineSizes {
|
||||
lhs: inputs.line_size(self.matmul.lhs.data()),
|
||||
rhs: inputs.line_size(self.matmul.rhs.data()),
|
||||
out: match &config.ref_layout {
|
||||
RefLayout::Concrete(arg) => match arg {
|
||||
FuseArg::Input(..) => inputs.line_size(arg),
|
||||
FuseArg::Output(..) => outputs.line_size(arg),
|
||||
_ => panic!("Invalid ref layout"),
|
||||
},
|
||||
RefLayout::Virtual(_) => 1,
|
||||
},
|
||||
};
|
||||
|
||||
let address_type = inputs
|
||||
.required_address_type()
|
||||
.max(outputs.required_address_type());
|
||||
|
||||
if line_sizes.out == 1 && (line_sizes.lhs > 1 || line_sizes.rhs > 1) {
|
||||
return Err(FusedMatmulError::InvalidInput(
|
||||
"Output line size of 1 removes the gain from fusion",
|
||||
));
|
||||
}
|
||||
|
||||
if let MatmulArg::Quantized { scheme, .. } = self.matmul.lhs {
|
||||
line_sizes.lhs *= scheme.num_quants();
|
||||
}
|
||||
if let MatmulArg::Quantized { scheme, .. } = self.matmul.rhs {
|
||||
line_sizes.rhs *= scheme.num_quants();
|
||||
}
|
||||
|
||||
let out_strides = MatrixLayout::RowMajor.to_strides(&out_shape);
|
||||
let problem = MatmulProblem::from_shapes_and_strides(
|
||||
lhs_shape,
|
||||
rhs_shape,
|
||||
out_shape,
|
||||
lhs_strides,
|
||||
rhs_strides,
|
||||
out_strides,
|
||||
dtypes.as_global_elems(),
|
||||
address_type,
|
||||
self.matmul.lhs.scheme(),
|
||||
self.matmul.rhs.scheme(),
|
||||
);
|
||||
|
||||
match self.selector {
|
||||
FusedMatmulSelector::Simple {
|
||||
multi_rows,
|
||||
tile_matmul,
|
||||
} => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<
|
||||
R,
|
||||
SimpleAlgorithm<Accelerated>,
|
||||
>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&BlueprintStrategy::Inferred(SimpleArgs { multi_rows }),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
}),
|
||||
FusedMatmulSelector::DoubleBuffering {
|
||||
specialized,
|
||||
tile_matmul,
|
||||
} => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<
|
||||
R,
|
||||
CyclicDoubleBufferingAlgorithm<Accelerated>,
|
||||
>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized }),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
}),
|
||||
FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => {
|
||||
let row_count = match self.matmul.lhs.precision() {
|
||||
FuseType::F16 | FuseType::BF16 => 8,
|
||||
_ => 4,
|
||||
};
|
||||
|
||||
with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<
|
||||
R,
|
||||
OrderedDoubleBufferingAlgorithm<Accelerated>,
|
||||
>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&BlueprintStrategy::Inferred(OrderedSelectionArgs {
|
||||
row_count: Some(row_count),
|
||||
rows_per_plane: Some(2),
|
||||
partition_k: Some(2),
|
||||
}),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
})
|
||||
}
|
||||
FusedMatmulSelector::SimpleUnit => {
|
||||
match launch_inner_fix_dtype::<R, SimpleUnitAlgorithm>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&Default::default(),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
}
|
||||
}
|
||||
FusedMatmulSelector::DoubleUnit => {
|
||||
match launch_inner_fix_dtype::<R, DoubleUnitAlgorithm>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&Default::default(),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
}
|
||||
}
|
||||
FusedMatmulSelector::SimpleVecMat => {
|
||||
match launch_inner_fix_dtype::<R, SimpleVecMatAlgorithm>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&Default::default(),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
}
|
||||
}
|
||||
FusedMatmulSelector::DoubleVecMat => {
|
||||
match launch_inner_fix_dtype::<R, DoubleVecMatAlgorithm>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
config.clone(),
|
||||
self.matmul.lhs.clone(),
|
||||
self.matmul.rhs.clone(),
|
||||
None,
|
||||
self.matmul.out.clone(),
|
||||
),
|
||||
outputs,
|
||||
problem,
|
||||
line_sizes,
|
||||
&Default::default(),
|
||||
) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedMatmulError::LaunchError(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn launch_inner_fix_dtype<'a, R: Runtime, A: Routine<()>>(
|
||||
client: &ComputeClient<R>,
|
||||
input: FusedMatmulInputLaunch<'a, R>,
|
||||
output: GlobalArgsLaunch<'a, R>,
|
||||
problem: MatmulProblem,
|
||||
line_sizes: MatmulLineSizes,
|
||||
blueprint_strategy: &BlueprintStrategy<(), A>,
|
||||
) -> Result<(), MatmulSetupError> {
|
||||
launch_kernel_virtual::<FusedMatmulArgs, R, A>(
|
||||
client,
|
||||
input,
|
||||
output,
|
||||
(),
|
||||
problem,
|
||||
line_sizes,
|
||||
blueprint_strategy,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
use super::optimization::MatmulOptimizationTuneArg;
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::trace::TuneOutput,
|
||||
optim::matmul::{AcceleratedTileKind, FusedMatmulSelector},
|
||||
tune::{TuneContext, TuneInput},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use cubecl::{
|
||||
AutotuneKey, CubeElement, CubeTuneId, Runtime,
|
||||
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
|
||||
};
|
||||
use cubek::matmul::{
|
||||
definition::MatmulKind,
|
||||
launch::{MatmulAutotuneKey, MatmulGlobalScale, should_tune_double_buffering},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Autotune key of a fused matmul: the plain matmul key extended with information about the
/// fused operations. Built by `create_key`.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
pub struct FusedMatmulAutotuneKey {
|
||||
// Key of the underlying matmul problem.
matmul_key: MatmulAutotuneKey,
|
||||
// Filled from `MatmulOptimizationInfo::num_output_buffers`.
// NOTE(review): `anchor` presumably buckets the value — confirm against the derive macro.
#[autotune(anchor)]
|
||||
num_out_buffers: usize,
|
||||
// Filled from `MatmulOptimizationInfo::num_ops_fused`.
#[autotune(anchor)]
|
||||
num_ops: usize,
|
||||
}
|
||||
|
||||
/// Executes autotune on matmul operations
|
||||
pub fn fused_matmul_autotune<R: Runtime, BT: CubeElement>(
|
||||
optimization: MatmulOptimizationTuneArg<R>,
|
||||
context: &mut Context<CubeFusionHandle<R>>,
|
||||
) {
|
||||
static TUNER: LocalTuner<FusedMatmulAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
const PRIORITY_MAX: i8 = 3;
|
||||
const PRIORITY_HIGH: i8 = 2;
|
||||
const PRIORITY_MEDIUM: i8 = 1;
|
||||
const PRIORITY_MIN: i8 = 0;
|
||||
|
||||
let cmma = TuneGroup::<FusedMatmulAutotuneKey>::new("cmma", |key| {
|
||||
if matches!(
|
||||
key.matmul_key.analysis.kind,
|
||||
MatmulKind::General
|
||||
// Those variants are just because the unit alternatives aren't very good yet.
|
||||
| MatmulKind::VecMat | MatmulKind::MatVec
|
||||
) {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MEDIUM
|
||||
}
|
||||
});
|
||||
|
||||
let mma = TuneGroup::<FusedMatmulAutotuneKey>::new("mma", |key| {
|
||||
if matches!(
|
||||
key.matmul_key.analysis.kind,
|
||||
// General is usually bad, but I think shapes like 16x8196 would be classed as
|
||||
// general and are very good with MMA
|
||||
// Should highly degenerated matrices that aren't VecMat have their own class?
|
||||
MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec
|
||||
) {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MEDIUM
|
||||
}
|
||||
});
|
||||
|
||||
let odd = TuneGroup::<FusedMatmulAutotuneKey>::new("odd", |key| {
|
||||
if key.matmul_key.definition.lhs_pow2_factor == 0
|
||||
|| key.matmul_key.definition.rhs_pow2_factor == 0
|
||||
{
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MIN
|
||||
}
|
||||
});
|
||||
|
||||
let unit = TuneGroup::<FusedMatmulAutotuneKey>::new("unit", |key| {
|
||||
if !matches!(key.matmul_key.analysis.kind, MatmulKind::General)
|
||||
|| matches!(
|
||||
key.matmul_key.analysis.scale_global,
|
||||
MatmulGlobalScale::Small
|
||||
)
|
||||
{
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MIN
|
||||
}
|
||||
});
|
||||
|
||||
fn double_buffering_priority(key: &FusedMatmulAutotuneKey, max: i8, min: i8) -> i8 {
|
||||
if should_tune_double_buffering(key.num_out_buffers > 1, &key.matmul_key) {
|
||||
max
|
||||
} else {
|
||||
min
|
||||
}
|
||||
}
|
||||
|
||||
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>).with(Tunable::new(
|
||||
"fused_matmul_fallback",
|
||||
tune_fallback::<R, BT>,
|
||||
)); // First one should always work.
|
||||
|
||||
// Unit matmuls
|
||||
for (selector, double_buf) in [
|
||||
(FusedMatmulSelector::SimpleUnit, false),
|
||||
(FusedMatmulSelector::DoubleUnit, true),
|
||||
(FusedMatmulSelector::SimpleVecMat, false),
|
||||
(FusedMatmulSelector::DoubleVecMat, true),
|
||||
] {
|
||||
set = set.with(
|
||||
Tunable::new(selector.name(), move |input| {
|
||||
tune_fused::<R, BT>(input, selector)
|
||||
})
|
||||
.group(&unit, move |key| match double_buf {
|
||||
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
|
||||
false => PRIORITY_MAX,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Accelerated matmuls
|
||||
for (tile_matmul, group) in [
|
||||
(AcceleratedTileKind::Cmma, &cmma),
|
||||
(AcceleratedTileKind::Mma, &mma),
|
||||
] {
|
||||
for (selector, double_buf, extra_group) in [
|
||||
(
|
||||
FusedMatmulSelector::Simple {
|
||||
multi_rows: false,
|
||||
tile_matmul,
|
||||
},
|
||||
false,
|
||||
None,
|
||||
),
|
||||
(
|
||||
FusedMatmulSelector::Simple {
|
||||
multi_rows: true,
|
||||
tile_matmul,
|
||||
},
|
||||
false,
|
||||
None,
|
||||
),
|
||||
(
|
||||
FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul },
|
||||
true,
|
||||
None,
|
||||
),
|
||||
(
|
||||
FusedMatmulSelector::DoubleBuffering {
|
||||
specialized: false,
|
||||
tile_matmul,
|
||||
},
|
||||
true,
|
||||
None,
|
||||
),
|
||||
(
|
||||
FusedMatmulSelector::DoubleBuffering {
|
||||
specialized: true,
|
||||
tile_matmul,
|
||||
},
|
||||
true,
|
||||
Some(&odd),
|
||||
),
|
||||
] {
|
||||
let mut tunable = Tunable::new(selector.name(), move |input| {
|
||||
tune_fused::<R, BT>(input, selector)
|
||||
})
|
||||
.group(group, move |key| match double_buf {
|
||||
true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),
|
||||
false => PRIORITY_MAX,
|
||||
});
|
||||
if let Some(group) = extra_group {
|
||||
tunable = tunable.group(group, |_| PRIORITY_MAX);
|
||||
}
|
||||
set = set.with(tunable);
|
||||
}
|
||||
}
|
||||
|
||||
set
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&optimization.info.client, &optimization.info.device),
|
||||
&optimization.info.client.clone(),
|
||||
tunables,
|
||||
TuneInput::new(context, optimization),
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn create_key<R: Runtime>(
|
||||
input: &TuneInput<R, MatmulOptimizationTuneArg<R>>,
|
||||
) -> FusedMatmulAutotuneKey {
|
||||
let opt = input.optimization();
|
||||
let context = match input.context() {
|
||||
TuneContext::Original(context) => context,
|
||||
TuneContext::Fork(_) => panic!("Not supported when generating key"),
|
||||
};
|
||||
|
||||
let lhs = context.tensors.get(&opt.info.matmul.op.lhs.id).unwrap();
|
||||
let rhs = context.tensors.get(&opt.info.matmul.op.rhs.id).unwrap();
|
||||
let out = context.tensors.get(&opt.info.matmul.op.out.id).unwrap();
|
||||
|
||||
let lhs_strides = context
|
||||
.handles
|
||||
.get_handle(&lhs.id, &burn_ir::TensorStatus::ReadOnly)
|
||||
.strides;
|
||||
let rhs_strides = context
|
||||
.handles
|
||||
.get_handle(&rhs.id, &burn_ir::TensorStatus::ReadOnly)
|
||||
.strides;
|
||||
|
||||
let key = MatmulAutotuneKey::generate(
|
||||
&opt.info.client,
|
||||
&lhs.shape,
|
||||
&rhs.shape,
|
||||
&lhs_strides,
|
||||
&rhs_strides,
|
||||
lhs.dtype.into(),
|
||||
rhs.dtype.into(),
|
||||
out.dtype.into(),
|
||||
opt.info.matmul.lhs.scheme(),
|
||||
opt.info.matmul.rhs.scheme(),
|
||||
);
|
||||
FusedMatmulAutotuneKey::new(key, opt.info.num_output_buffers(), opt.info.num_ops_fused())
|
||||
}
|
||||
|
||||
fn input_gen<R: Runtime>(
|
||||
_key: &FusedMatmulAutotuneKey,
|
||||
input: &TuneInput<R, MatmulOptimizationTuneArg<R>>,
|
||||
) -> TuneInput<R, MatmulOptimizationTuneArg<R>> {
|
||||
input.clone()
|
||||
}
|
||||
|
||||
fn tune_fused<R: Runtime, BT: CubeElement>(
|
||||
input: TuneInput<R, MatmulOptimizationTuneArg<R>>,
|
||||
selector: FusedMatmulSelector,
|
||||
) -> Result<TuneOutput<R>, String> {
|
||||
let optimization = input.optimization();
|
||||
let context = input.context();
|
||||
|
||||
match context {
|
||||
TuneContext::Original(context) => match optimization.execute_fused::<BT>(context, selector)
|
||||
{
|
||||
Ok(out) => Ok(out),
|
||||
Err(_) => {
|
||||
return tune_fallback::<R, BT>(input);
|
||||
}
|
||||
},
|
||||
TuneContext::Fork(mut context_owned) => {
|
||||
optimization.execute_fused::<BT>(&mut context_owned.as_context(), selector)
|
||||
}
|
||||
}
|
||||
.map_err(|e| format!("{e:?}"))
|
||||
}
|
||||
|
||||
fn tune_fallback<R: Runtime, BT: CubeElement>(
|
||||
input: TuneInput<R, MatmulOptimizationTuneArg<R>>,
|
||||
) -> Result<TuneOutput<R>, String> {
|
||||
let optimization = input.optimization();
|
||||
let context = input.context();
|
||||
|
||||
Ok(match context {
|
||||
TuneContext::Original(context) => optimization.execute_fallback::<BT>(context),
|
||||
TuneContext::Fork(mut context_owned) => {
|
||||
optimization.execute_fallback::<BT>(&mut context_owned.as_context())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
pub mod elemwise;
|
||||
pub mod matmul;
|
||||
pub mod reduce;
|
||||
pub mod reduce_broadcasted;
|
||||
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,208 @@
|
||||
use crate::engine::codegen::{
|
||||
io::{ref_buffer_len, ref_len, ref_line_size, ref_shape, ref_stride},
|
||||
ir::{FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsExpand, LocalArgs, LocalArgsExpand},
|
||||
kernel::{fuse_on_read, fuse_on_write, init_locals},
|
||||
};
|
||||
use cubecl::prelude::*;
|
||||
use cubek::reduce::components::args::{ReduceArgs, ReduceDType};
|
||||
|
||||
/// Marker type implementing `ReduceArgs` for reductions whose reads and writes go through
/// the fusion engine.
#[derive(Clone)]
|
||||
pub struct FusedReduceArgs;
|
||||
|
||||
/// Input side of a fused reduce.
#[derive(CubeType, CubeLaunch)]
|
||||
pub struct FusedReduceInput {
|
||||
/// Global arguments, passed to `init_locals` when the state is initialized.
pub global: GlobalArgs,
|
||||
/// Block configuration used to initialize the read-side locals.
#[cube(comptime)]
|
||||
pub config: FuseBlockConfig,
|
||||
/// Fuse argument of the reduce input.
/// NOTE(review): presumably becomes `FusedReduceState::input` — confirm in `FusedReduceState::new`.
#[cube(comptime)]
|
||||
pub arg: FuseArg,
|
||||
}
|
||||
|
||||
/// Output side of a fused reduce.
#[derive(CubeType, CubeLaunch)]
|
||||
pub struct FusedReduceOutput {
|
||||
/// Global arguments, passed mutably to `init_locals` when the state is initialized.
pub global: GlobalArgs,
|
||||
/// Block configuration used to initialize the write-side locals.
#[cube(comptime)]
|
||||
pub config: FuseBlockConfig,
|
||||
/// Fuse argument of the reduce output.
/// NOTE(review): presumably becomes `FusedReduceState::out` — confirm in `FusedReduceState::new`.
#[cube(comptime)]
|
||||
pub arg: FuseArg,
|
||||
}
|
||||
|
||||
/// State threaded through the `ReduceArgs` callbacks of [`FusedReduceArgs`].
///
/// The global arguments and locals are held as raw pointers and dereferenced inside `unsafe`
/// blocks by the read/write/len callbacks.
pub struct FusedReduceState {
|
||||
// Global inputs; only read through this pointer in the callbacks visible here.
inputs: *const GlobalArgs,
|
||||
// Global outputs; dereferenced mutably by `read_input` and `write_output`.
outputs: *mut GlobalArgs,
|
||||
// Locals passed to `fuse_on_read`.
locals_on_read: *mut LocalArgs,
|
||||
// Locals passed to `fuse_on_write`.
locals_on_write: *mut LocalArgs,
|
||||
// Block configuration used on the read side.
config_on_read: FuseBlockConfig,
|
||||
// Block configuration used on the write side.
config_on_write: FuseBlockConfig,
|
||||
// TODO: Should be a list when multiple blocks are there.
|
||||
// Argument read by `read_input`.
input: FuseArg,
|
||||
// Argument under which `write_output` registers the reduced value.
out: FuseArg,
|
||||
}
|
||||
|
||||
/// Field-for-field counterpart of [`FusedReduceState`] that uses the `*Expand` argument types
/// instead of raw pointers.
/// NOTE(review): presumably the expand-time representation used by the `#[cube]` macro — confirm.
#[derive(Clone)]
|
||||
pub struct FusedReduceStateExpand {
|
||||
// Expand form of the global inputs.
inputs: GlobalArgsExpand,
|
||||
// Expand form of the global outputs.
outputs: GlobalArgsExpand,
|
||||
// Expand form of the read-side locals.
locals_on_read: LocalArgsExpand,
|
||||
// Expand form of the write-side locals.
locals_on_write: LocalArgsExpand,
|
||||
// Block configuration used on the read side.
config_on_read: FuseBlockConfig,
|
||||
// Block configuration used on the write side.
config_on_write: FuseBlockConfig,
|
||||
// Argument of the reduce input.
input: FuseArg,
|
||||
// Argument of the reduce output.
out: FuseArg,
|
||||
}
|
||||
|
||||
// Implements the reduce argument protocol on top of the fusion helpers: reads go through
// `fuse_on_read` with the read-side locals/config, writes go through `fuse_on_write` with
// the write-side locals/config.
//
// NOTE(review): every `unsafe` block below dereferences a raw pointer stored in
// `FusedReduceState`; soundness presumably relies on this code only ever running in its
// macro-expanded form — confirm.
#[cube]
|
||||
impl ReduceArgs for FusedReduceArgs {
|
||||
type Input<E: Numeric> = FusedReduceInput;
|
||||
type Output<E: Numeric> = FusedReduceOutput;
|
||||
type State<P: ReduceDType> = FusedReduceState;
|
||||
|
||||
/// Initialises one set of locals per block configuration (read and write) and bundles
/// them with the input/output arguments into the state.
fn init_state<P: ReduceDType>(
|
||||
input: &Self::Input<P::In>,
|
||||
output: &mut Self::Output<P::Out>,
|
||||
) -> Self::State<P> {
|
||||
let mut locals_read = init_locals(&input.global, &mut output.global, &input.config);
|
||||
let mut locals_write = init_locals(&input.global, &mut output.global, &output.config);
|
||||
// TODO Add stuff from previous blocks to the local of each block.
|
||||
FusedReduceState::new(input, output, &mut locals_read, &mut locals_write)
|
||||
}
|
||||
|
||||
/// Reads the line at `index` by calling `fuse_on_read` for the single argument
/// `state.input` and returning the first (only) produced value.
fn read_input<P: ReduceDType>(state: &Self::State<P>, index: usize) -> Line<P::In> {
|
||||
let value = fuse_on_read::<P::In>(
|
||||
unsafe { &(*state.inputs) },
|
||||
unsafe { &mut (*state.outputs) },
|
||||
unsafe { &mut (*state.locals_on_read) },
|
||||
index,
|
||||
comptime! {
|
||||
let mut sequence = Sequence::new();
|
||||
// TODO: Register local arguments from previous blocks.
|
||||
sequence.push(state.input.clone());
|
||||
sequence
|
||||
},
|
||||
&state.config_on_read,
|
||||
)[0];
|
||||
value
|
||||
}
|
||||
|
||||
/// Ignores both arguments and always returns an empty line of size 1; the output is
/// never actually read back.
fn read_output<P: ReduceDType>(_state: &Self::State<P>, _index: usize) -> Line<P::Out> {
|
||||
Line::empty(1usize)
|
||||
}
|
||||
|
||||
/// Registers `value` under `state.out` and calls `fuse_on_write` at `index` with that
/// single argument.
fn write_output<P: ReduceDType>(state: &mut Self::State<P>, index: usize, value: Line<P::Out>) {
|
||||
let mut values = Registry::<FuseArg, Line<P::Out>>::new();
|
||||
let mut args = comptime![Vec::<FuseArg>::new()];
|
||||
|
||||
values.insert(comptime![state.out.clone()], value);
|
||||
comptime![args.push(state.out.clone())];
|
||||
fuse_on_write(
|
||||
unsafe { &(*state.inputs) },
|
||||
unsafe { &mut (*state.outputs) },
|
||||
unsafe { &mut (*state.locals_on_write) },
|
||||
index,
|
||||
values,
|
||||
args,
|
||||
&state.config_on_write,
|
||||
);
|
||||
}
|
||||
|
||||
/// `ref_len` evaluated with the read-side locals and configuration.
fn len_input<P: ReduceDType>(state: &Self::State<P>) -> usize {
|
||||
ref_len(
|
||||
unsafe { &(*state.inputs) },
|
||||
unsafe { &(*state.outputs) },
|
||||
unsafe { &(*state.locals_on_read) },
|
||||
&state.config_on_read,
|
||||
)
|
||||
}
|
||||
|
||||
/// `ref_len` evaluated with the write-side locals and configuration.
fn len_output<P: ReduceDType>(state: &Self::State<P>) -> usize {
|
||||
ref_len(
|
||||
unsafe { &(*state.inputs) },
|
||||
unsafe { &(*state.outputs) },
|
||||
unsafe { &(*state.locals_on_write) },
|
||||
&state.config_on_write,
|
||||
)
|
||||
}
|
||||
|
||||
/// `ref_buffer_len` evaluated with the read-side locals and configuration.
fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> usize {
|
||||
ref_buffer_len(
|
||||
unsafe { &(*state.inputs) },
|
||||
unsafe { &(*state.outputs) },
|
||||
unsafe { &(*state.locals_on_read) },
|
||||
&state.config_on_read,
|
||||
)
|
||||
}
|
||||
|
||||
/// `ref_buffer_len` evaluated with the write-side locals and configuration.
fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> usize {
|
||||
ref_buffer_len(
|
||||
unsafe { &(*state.inputs) },
|
||||
unsafe { &(*state.outputs) },
|
||||
unsafe { &(*state.locals_on_write) },
|
||||
&state.config_on_write,
|
||||
)
|
||||
}
|
||||
|
||||
/// Rank taken from the read-side block configuration.
fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> usize {
|
||||
state.config_on_read.rank.runtime()
|
||||
}
|
||||
|
||||
/// Rank taken from the write-side block configuration.
fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> usize {
|
||||
state.config_on_write.rank.runtime()
|
||||
}
|
||||
|
||||
/// `ref_shape` of dimension `dim` using the read-side locals.
fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
|
||||
ref_shape(unsafe { &(*state.locals_on_read) }, dim)
|
||||
}
|
||||
|
||||
/// `ref_shape` of dimension `dim` using the write-side locals.
fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
|
||||
ref_shape(unsafe { &(*state.locals_on_write) }, dim)
|
||||
}
|
||||
|
||||
/// `ref_stride` of dimension `dim` using the read-side locals.
fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
|
||||
ref_stride(unsafe { &(*state.locals_on_read) }, dim)
|
||||
}
|
||||
|
||||
/// `ref_stride` of dimension `dim` using the write-side locals.
fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {
|
||||
ref_stride(unsafe { &(*state.locals_on_write) }, dim)
|
||||
}
|
||||
|
||||
/// `ref_line_size` of the read-side locals (a compile-time value).
fn line_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(LineSize) {
|
||||
ref_line_size(unsafe { &(*state.locals_on_read) })
|
||||
}
|
||||
|
||||
/// `ref_line_size` of the write-side locals (a compile-time value).
fn line_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(LineSize) {
|
||||
ref_line_size(unsafe { &(*state.locals_on_write) })
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl FusedReduceState {
|
||||
/// Builds the state from the kernel input/output arguments and the two sets of locals.
///
/// The global arguments and locals are stored by address; the block configurations and
/// fuse arguments are cloned at compile time.
pub fn new(
|
||||
inputs: &FusedReduceInput,
|
||||
outputs: &mut FusedReduceOutput,
|
||||
locals_on_read: &mut LocalArgs,
|
||||
locals_on_write: &mut LocalArgs,
|
||||
) -> FusedReduceState {
|
||||
FusedReduceState {
|
||||
inputs: &inputs.global,
|
||||
outputs: &mut outputs.global,
|
||||
locals_on_read,
|
||||
locals_on_write,
|
||||
config_on_read: comptime![inputs.config.clone()],
|
||||
config_on_write: comptime![outputs.config.clone()],
|
||||
input: comptime![inputs.arg.clone()],
|
||||
out: comptime![outputs.arg.clone()],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Written by hand (no derive): ties `FusedReduceState` to its expand-time representation.
impl CubeType for FusedReduceState {
|
||||
type ExpandType = FusedReduceStateExpand;
|
||||
}
|
||||
|
||||
impl IntoMut for FusedReduceStateExpand {
|
||||
// Identity conversion: the state is returned unchanged and the scope is not used.
fn into_mut(self, _context: &mut Scope) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// Empty impl: only the trait's default behaviour is used.
impl CubeDebug for FusedReduceStateExpand {}
|
||||
@@ -0,0 +1,328 @@
|
||||
use super::{
|
||||
ReduceSettings,
|
||||
optimization::{FusedReduce, ReduceInstruction, ReduceOptimization},
|
||||
};
|
||||
use crate::{
|
||||
engine::{
|
||||
codegen::ir::FuseType,
|
||||
fuser::TraceOperationFuser,
|
||||
settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},
|
||||
},
|
||||
optim::CubeOptimization,
|
||||
};
|
||||
use burn_fusion::{FuserStatus, OperationFuser};
|
||||
use burn_ir::{NumericOperationIr, OperationIr, ReduceDimOpIr};
|
||||
use burn_std::Shape;
|
||||
use cubecl::Runtime;
|
||||
|
||||
/// Fuses element-wise operations around a reduce operation.
///
/// Three traces are built in parallel: the fully fused one, plus two element-wise fallback
/// traces (operations seen before and after the reduce) that are used when the fused kernel
/// cannot run.
|
||||
pub struct ReduceFuser<R: Runtime> {
|
||||
/// Main fuser; receives every operation and the reduce block itself.
pub(crate) fuser: TraceOperationFuser,
|
||||
/// Element-wise fuser for the operations registered before the reduce.
pub(crate) fuser_read_fallback: TraceOperationFuser,
|
||||
/// Element-wise fuser for the operations registered after the reduce.
fuser_write_fallback: TraceOperationFuser,
|
||||
/// Settings given to `next_block` when the reduce opens the write block.
settings_write: FuseSettings,
|
||||
/// Device used to obtain the compute client.
pub(crate) device: R::Device,
|
||||
/// The reduce operation, once one has been registered.
pub(crate) reduce: Option<FusedReduce>,
|
||||
/// Policy deciding whether operations may be fused after the reduce.
settings: ReduceSettings,
|
||||
}
|
||||
|
||||
// Implemented by hand so that cloning does not require `R: Clone`.
impl<R: Runtime> Clone for ReduceFuser<R> {
    fn clone(&self) -> Self {
        let Self {
            fuser,
            fuser_read_fallback,
            fuser_write_fallback,
            settings_write,
            device,
            reduce,
            settings,
        } = self;

        Self {
            fuser: fuser.clone(),
            fuser_read_fallback: fuser_read_fallback.clone(),
            fuser_write_fallback: fuser_write_fallback.clone(),
            // Both settings types are `Copy`.
            settings_write: *settings_write,
            device: device.clone(),
            reduce: reduce.clone(),
            settings: *settings,
        }
    }
}
|
||||
|
||||
/// Summary of what a [`ReduceFuser`] currently holds, as returned by
/// `ReduceFuser::reduce_info`.
#[derive(Debug)]
|
||||
pub enum ReduceFuserInfo {
|
||||
/// A reduce was registered: shape of its input tensor and the reduced axis.
FusedReduce { shape_input_id: Shape, axis: usize },
|
||||
/// No reduce yet: current output shape of the read-fallback (element-wise) fuser.
FusedElemwise { shape_id: Shape },
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReduceFuser<R> {
|
||||
pub fn new(device: R::Device, bool_precision: FuseType, settings: ReduceSettings) -> Self {
|
||||
let client = R::client(&device);
|
||||
let props = client.properties();
|
||||
let max_bindings = props.hardware.max_bindings;
|
||||
let settings_read = FuseSettings {
|
||||
// Inplace would work, but not when we have a concrete output to write too.
|
||||
inplace: true,
|
||||
ref_layout: RefLayoutSetting::OnlyContiguous,
|
||||
broadcast: false,
|
||||
output_shape_updates: true,
|
||||
vectorization: VectorizationSetting::Activated,
|
||||
};
|
||||
let settings_write = FuseSettings {
|
||||
inplace: false,
|
||||
output_shape_updates: false,
|
||||
vectorization: VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos: 0 },
|
||||
broadcast: false,
|
||||
ref_layout: RefLayoutSetting::OnlyContiguous,
|
||||
};
|
||||
let settings_fallback = FuseSettings::default();
|
||||
|
||||
Self {
|
||||
fuser: TraceOperationFuser::new(max_bindings, bool_precision, settings_read),
|
||||
fuser_read_fallback: TraceOperationFuser::new(
|
||||
max_bindings,
|
||||
bool_precision,
|
||||
settings_fallback,
|
||||
),
|
||||
fuser_write_fallback: TraceOperationFuser::new(
|
||||
max_bindings,
|
||||
bool_precision,
|
||||
settings_fallback,
|
||||
),
|
||||
settings_write,
|
||||
device,
|
||||
reduce: None,
|
||||
settings,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reduce_info(&self) -> ReduceFuserInfo {
|
||||
match &self.reduce {
|
||||
Some(reduce) => {
|
||||
let shape_input_id = reduce.op.input.shape.clone();
|
||||
let axis = reduce.axis;
|
||||
|
||||
ReduceFuserInfo::FusedReduce {
|
||||
shape_input_id,
|
||||
axis,
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let shape_id = self.fuser_read_fallback.current_output_shape.clone();
|
||||
ReduceFuserInfo::FusedElemwise { shape_id }
|
||||
}
|
||||
}
|
||||
}
|
||||
fn on_reduce(&mut self, op: &ReduceDimOpIr, inst: ReduceInstruction) {
|
||||
// TODO: Fix: we need to have fuse-on-read with an identity block.
|
||||
//
|
||||
// if self.fuser.num_ops == 0 && false {
|
||||
// self.fuser.current_output_shape = op.input.shape.dims.clone();
|
||||
// } else if self.fuser.current_output_shape != op.input.shape.dims {
|
||||
|
||||
if self.fuser.current_output_shape != op.input.shape {
|
||||
self.fuser.close();
|
||||
self.fuser_read_fallback.close();
|
||||
return;
|
||||
}
|
||||
|
||||
let [input] = self
|
||||
.fuser
|
||||
.next_block([&op.input], self.settings_write, false);
|
||||
|
||||
let output = self.fuser.output_unhandled(&op.out);
|
||||
let axis = op.axis;
|
||||
|
||||
let fuse_on_write_activated = match self.settings {
|
||||
ReduceSettings::Always => true,
|
||||
// We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise
|
||||
// vectorization is impossible. Only [LineMode::Perpendicular] supports vectorization.
|
||||
//
|
||||
// We could still fuse some output operations, but it would probably lead to worse performance.
|
||||
ReduceSettings::OnlyParallel => axis != op.input.shape.rank() - 1,
|
||||
ReduceSettings::Never => false,
|
||||
};
|
||||
|
||||
if !fuse_on_write_activated {
|
||||
self.fuser.close();
|
||||
}
|
||||
|
||||
let acc = match inst {
|
||||
ReduceInstruction::Mean | ReduceInstruction::Prod | ReduceInstruction::Sum => {
|
||||
match input.precision() {
|
||||
FuseType::F16 | FuseType::BF16 => FuseType::F32,
|
||||
FuseType::I16 | FuseType::I8 => FuseType::I32,
|
||||
FuseType::U16 | FuseType::U8 => FuseType::U32,
|
||||
_ => input.precision(),
|
||||
}
|
||||
}
|
||||
_ => input.precision(),
|
||||
};
|
||||
|
||||
self.reduce = Some(FusedReduce {
|
||||
input,
|
||||
output,
|
||||
acc,
|
||||
axis,
|
||||
op: op.clone(),
|
||||
use_planes: false,
|
||||
shared: false,
|
||||
inst,
|
||||
});
|
||||
|
||||
self.fuser_read_fallback.close();
|
||||
}
|
||||
|
||||
fn on_elemwise_read(&mut self, operation: &OperationIr) {
|
||||
let can_register =
|
||||
self.fuser.can_fuse(operation) && self.fuser_read_fallback.can_fuse(operation);
|
||||
|
||||
match can_register {
|
||||
true => {
|
||||
self.fuser.fuse(operation);
|
||||
self.fuser_read_fallback.fuse(operation);
|
||||
}
|
||||
false => {
|
||||
self.fuser.close();
|
||||
self.fuser_read_fallback.close();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn on_elemwise_write(&mut self, operation: &OperationIr) {
|
||||
let can_register =
|
||||
self.fuser.can_fuse(operation) && self.fuser_write_fallback.can_fuse(operation);
|
||||
|
||||
match can_register {
|
||||
true => {
|
||||
self.fuser.fuse(operation);
|
||||
self.fuser_write_fallback.fuse(operation);
|
||||
}
|
||||
false => {
|
||||
self.fuser.close();
|
||||
self.fuser_write_fallback.close();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for ReduceFuser<R> {
|
||||
fn fuse(&mut self, operation: &OperationIr) {
|
||||
if let FuserStatus::Closed = self.fuser.status() {
|
||||
return;
|
||||
}
|
||||
|
||||
if self.reduce.is_none() {
|
||||
if let OperationIr::NumericFloat(_, op) = operation {
|
||||
match op {
|
||||
NumericOperationIr::SumDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Sum);
|
||||
}
|
||||
NumericOperationIr::MeanDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Mean);
|
||||
}
|
||||
NumericOperationIr::ProdDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Prod);
|
||||
}
|
||||
NumericOperationIr::ArgMax(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::ArgMax);
|
||||
}
|
||||
NumericOperationIr::ArgMin(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::ArgMin);
|
||||
}
|
||||
NumericOperationIr::MinDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Min);
|
||||
}
|
||||
NumericOperationIr::MaxDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Max);
|
||||
}
|
||||
NumericOperationIr::MaxAbsDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::MaxAbs);
|
||||
}
|
||||
_ => {
|
||||
self.on_elemwise_read(operation);
|
||||
}
|
||||
};
|
||||
} else if let OperationIr::NumericInt(_, op) = operation {
|
||||
match op {
|
||||
NumericOperationIr::SumDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Sum);
|
||||
}
|
||||
NumericOperationIr::MeanDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Mean);
|
||||
}
|
||||
NumericOperationIr::ProdDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Prod);
|
||||
}
|
||||
NumericOperationIr::ArgMax(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::ArgMax);
|
||||
}
|
||||
NumericOperationIr::ArgMin(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::ArgMin);
|
||||
}
|
||||
NumericOperationIr::MinDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Min);
|
||||
}
|
||||
NumericOperationIr::MaxDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::Max);
|
||||
}
|
||||
NumericOperationIr::MaxAbsDim(op) => {
|
||||
self.on_reduce(op, ReduceInstruction::MaxAbs);
|
||||
}
|
||||
_ => {
|
||||
self.on_elemwise_read(operation);
|
||||
}
|
||||
};
|
||||
} else {
|
||||
self.on_elemwise_read(operation);
|
||||
}
|
||||
} else {
|
||||
self.on_elemwise_write(operation);
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> CubeOptimization<R> {
|
||||
let client = R::client(&self.device);
|
||||
let trace = self.fuser.finish();
|
||||
let trace_read_fallback = self.fuser_read_fallback.finish();
|
||||
let trace_write_fallback = self.fuser_write_fallback.finish();
|
||||
let fuse_reduce = self.reduce.as_ref().unwrap();
|
||||
|
||||
let reduce = ReduceOptimization::new(
|
||||
trace,
|
||||
trace_read_fallback,
|
||||
trace_write_fallback,
|
||||
client,
|
||||
self.device.clone(),
|
||||
self.len(),
|
||||
self.fuser_read_fallback.len(),
|
||||
fuse_reduce.clone(),
|
||||
self.settings,
|
||||
);
|
||||
|
||||
CubeOptimization::Reduce(reduce)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.fuser.reset();
|
||||
self.fuser_read_fallback.reset();
|
||||
self.fuser_write_fallback.reset();
|
||||
self.reduce = None;
|
||||
}
|
||||
|
||||
fn status(&self) -> burn_fusion::FuserStatus {
|
||||
self.fuser.status()
|
||||
}
|
||||
|
||||
fn properties(&self) -> burn_fusion::FuserProperties {
|
||||
let mut properties = self.fuser.properties();
|
||||
|
||||
if self.reduce.is_some() {
|
||||
properties.ready = true;
|
||||
properties.score += 1;
|
||||
} else {
|
||||
properties.ready = false;
|
||||
};
|
||||
|
||||
properties
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.fuser.len() + if self.reduce.is_some() { 1 } else { 0 }
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod fuser;
|
||||
mod optimization;
|
||||
|
||||
pub(crate) mod args;
|
||||
pub(crate) mod tune;
|
||||
|
||||
pub use fuser::*;
|
||||
pub use optimization::*;
|
||||
@@ -0,0 +1,492 @@
|
||||
use super::args::{
|
||||
FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,
|
||||
};
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::tune::fused_reduce_autotune;
|
||||
use crate::{
|
||||
CubeFusionHandle, FallbackOperation,
|
||||
engine::{
|
||||
codegen::ir::{
|
||||
FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout,
|
||||
multi_block_variables_init,
|
||||
},
|
||||
launch::{
|
||||
FuseTraceLauncher,
|
||||
runner::{TraceRunner, Vectorization},
|
||||
},
|
||||
trace::{FuseTrace, TraceError, TuneOutput},
|
||||
},
|
||||
optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use burn_ir::ReduceDimOpIr;
|
||||
use burn_std::DType;
|
||||
use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};
|
||||
use cubek::reduce::{
|
||||
LineMode, ReduceDtypes, ReduceError,
|
||||
components::instructions::ReduceOperationConfig,
|
||||
init_tensors,
|
||||
launch::{RoutineStrategy, reduce_kernel_virtual},
|
||||
routines::{
|
||||
ReduceBlueprint, ReduceLaunchSettings, ReduceLineSettings, ReduceProblem, Routine,
|
||||
cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};
|
||||
|
||||
/// Optimization made of a reduce fused with surrounding element-wise operations.
pub struct ReduceOptimization<R: Runtime> {
|
||||
/// Shared so that it can be cloned cheaply into the tuning argument on every execution.
pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
|
||||
}
|
||||
|
||||
/// Everything needed to run a [`ReduceOptimization`], either fused or through the fallback.
pub(crate) struct ReduceOptimizationInfo<R: Runtime> {
|
||||
/// Fully fused trace (launched by `execute_fused`).
pub(crate) trace: FuseTrace,
|
||||
/// Element-wise trace launched before the fallback reduce.
trace_read_fallback: FuseTrace,
|
||||
/// Element-wise trace launched after the fallback reduce.
trace_write_fallback: FuseTrace,
|
||||
/// Compute client used for every launch.
pub(crate) client: ComputeClient<R>,
|
||||
/// Device passed to the trace launcher.
pub(crate) device: R::Device,
|
||||
/// Total number of fused operations (reported by `num_ops_fused`).
pub(crate) len: usize,
|
||||
/// Number of operations fused on the read side; also the index of the fallback reduce.
pub(crate) len_read: usize,
|
||||
/// Description of the reduce itself.
pub(crate) reduce: FusedReduce,
|
||||
/// Fuse-on-write policy that was in effect when the optimization was built.
settings: ReduceSettings,
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReduceOptimizationInfo<R> {
|
||||
pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
|
||||
let client = R::client(device);
|
||||
|
||||
Self {
|
||||
trace: state.trace,
|
||||
trace_read_fallback: state.trace_read_fallback,
|
||||
trace_write_fallback: state.trace_write_fallback,
|
||||
client,
|
||||
device: device.clone(),
|
||||
len: state.len,
|
||||
len_read: state.len_read,
|
||||
reduce: state.reduce,
|
||||
settings: state.settings,
|
||||
}
|
||||
}
|
||||
pub fn to_state(&self) -> ReduceOptimizationState {
|
||||
ReduceOptimizationState {
|
||||
trace: self.trace.clone(),
|
||||
trace_read_fallback: self.trace_read_fallback.clone(),
|
||||
trace_write_fallback: self.trace_write_fallback.clone(),
|
||||
len: self.len,
|
||||
len_read: self.len_read,
|
||||
reduce: self.reduce.clone(),
|
||||
settings: self.settings,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Copy, Clone)]
|
||||
pub enum ReduceSettings {
|
||||
Always,
|
||||
/// We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise
|
||||
/// vectorization is impossible. Only [LineMode::Perpendicular] supports vectorization.
|
||||
///
|
||||
/// We could still fuse some output operations, but it would probably lead to worse performance.
|
||||
OnlyParallel,
|
||||
Never,
|
||||
}
|
||||
|
||||
/// Argument handed to the execution strategies (fused or fallback) of a reduce optimization.
pub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {
|
||||
/// Shared optimization data.
pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
|
||||
/// Operation run in place of the reduce by `execute_fallback`.
pub(crate) fallback: Arc<Box<dyn FallbackOperation<R>>>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for ReduceOptimizationTuneArg<R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
info: self.info.clone(),
|
||||
fallback: self.fallback.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reduction performed by a fused reduce; each variant is mapped to the
/// `ReduceOperationConfig` variant of the same name at launch time.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
|
||||
pub enum ReduceInstruction {
|
||||
/// Registered for `ArgMax` operations.
ArgMax,
|
||||
/// Registered for `ArgMin` operations.
ArgMin,
|
||||
/// Registered for `MeanDim` operations.
Mean,
|
||||
/// Registered for `ProdDim` operations.
Prod,
|
||||
/// Registered for `SumDim` operations.
Sum,
|
||||
/// Registered for `MaxDim` operations.
Max,
|
||||
/// Registered for `MinDim` operations.
Min,
|
||||
/// Registered for `MaxAbsDim` operations.
MaxAbs,
|
||||
}
|
||||
|
||||
/// A reduce fallback that can be run against a fusion context.
// NOTE(review): same shape as `FallbackOperation`, which is what the rest of this file
// uses; this trait is not referenced in the visible code — confirm it is still needed.
pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
|
||||
/// Runs the fallback using the handles stored in `context`.
fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
|
||||
}
|
||||
|
||||
/// Serializable form of a [`ReduceOptimization`]: every field of the runtime info except
/// the compute client and the device.
#[derive(Serialize, Deserialize)]
|
||||
pub struct ReduceOptimizationState {
|
||||
/// Fully fused trace.
pub(crate) trace: FuseTrace,
|
||||
/// Element-wise trace for the operations before the reduce.
pub(crate) trace_read_fallback: FuseTrace,
|
||||
/// Element-wise trace for the operations after the reduce.
pub(crate) trace_write_fallback: FuseTrace,
|
||||
/// Description of the reduce.
pub(crate) reduce: FusedReduce,
|
||||
/// Total number of fused operations.
pub(crate) len: usize,
|
||||
/// Number of operations fused on the read side.
pub(crate) len_read: usize,
|
||||
/// Fuse-on-write policy.
pub(crate) settings: ReduceSettings,
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for ReduceOptimizationState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"{{ len_read: {}, len_total: {} }}",
|
||||
self.len_read, self.len
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Description of the reduce operation at the centre of a fused reduce kernel.
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FusedReduce {
|
||||
/// Fuse argument the reduce reads from.
pub(crate) input: FuseArg,
|
||||
/// Fuse argument the reduce writes to.
pub(crate) output: FuseArg,
|
||||
/// Precision used for accumulation.
pub(crate) acc: FuseType,
|
||||
/// Axis being reduced.
pub(crate) axis: usize,
|
||||
/// Original reduce operation (input/output tensor descriptions).
pub(crate) op: ReduceDimOpIr,
|
||||
// NOTE(review): initialised to `false` by the fuser and not read in this file —
// presumably a strategy hint; confirm.
pub(crate) use_planes: bool,
|
||||
// NOTE(review): initialised to `false` by the fuser and not read in this file —
// presumably a strategy hint; confirm.
pub(crate) shared: bool,
|
||||
/// Reduction instruction to execute.
pub(crate) inst: ReduceInstruction,
|
||||
}
|
||||
|
||||
/// Trace runner that launches a fused reduce with a given routine strategy.
///
/// The constructor is generated by `#[derive(new)]`.
#[derive(new)]
|
||||
pub struct FusedReduceLaunch<'a> {
|
||||
// Reduce being launched.
reduce: &'a FusedReduce,
|
||||
// Routine (unit, plane or cube) and its blueprint strategy.
strategy: RoutineStrategy,
|
||||
}
|
||||
|
||||
/// Errors produced when launching a fused reduce.
#[derive(Debug)]
|
||||
pub enum FusedReduceError {
|
||||
/// Error reported by the underlying reduce routine or kernel launch.
Reduce(ReduceError),
|
||||
// NOTE(review): not constructed in the visible code — presumably reports a rejected
// strategy selection with a static message; confirm.
InvalidSelection(Box<&'static str>),
|
||||
// NOTE(review): not constructed in the visible code — presumably reports unsupported
// inputs; confirm.
InvalidInput,
|
||||
}
|
||||
|
||||
impl From<ReduceError> for FusedReduceError {
|
||||
fn from(value: ReduceError) -> Self {
|
||||
Self::Reduce(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReduceOptimizationTuneArg<R> {
|
||||
pub fn execute_fused<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
strategy: RoutineStrategy,
|
||||
) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {
|
||||
let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);
|
||||
let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
|
||||
launcher.launch::<BT>(&self.info.client, &self.info.device, context)
|
||||
}
|
||||
|
||||
pub fn execute_fallback<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
) -> TuneOutput<R> {
|
||||
let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);
|
||||
|
||||
#[allow(unused_mut)] // It is used when `autotune-checks` is activated.
|
||||
let mut output_read = launcher
|
||||
.launch::<BT>(&self.info.client, &self.info.device, context)
|
||||
.unwrap();
|
||||
|
||||
self.fallback.run(context);
|
||||
|
||||
#[cfg(feature = "autotune-checks")]
|
||||
if let TuneOutput::Checked { handles } = &mut output_read {
|
||||
let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();
|
||||
let handle_out = context
|
||||
.handles
|
||||
.get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);
|
||||
|
||||
handles.insert(
|
||||
self.info.reduce.op.out.id,
|
||||
(out_desc.shape.dims.clone(), handle_out.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);
|
||||
|
||||
let output_write = launcher
|
||||
.launch::<BT>(&self.info.client, &self.info.device, context)
|
||||
.unwrap();
|
||||
|
||||
output_read.merge(output_write)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl<R: Runtime> ReduceOptimization<R> {
|
||||
pub fn new(
|
||||
trace: FuseTrace,
|
||||
trace_read_fallback: FuseTrace,
|
||||
trace_write_fallback: FuseTrace,
|
||||
client: ComputeClient<R>,
|
||||
device: R::Device,
|
||||
len: usize,
|
||||
len_read: usize,
|
||||
reduce: FusedReduce,
|
||||
settings: ReduceSettings,
|
||||
) -> Self {
|
||||
let info = ReduceOptimizationInfo {
|
||||
trace,
|
||||
trace_read_fallback,
|
||||
trace_write_fallback,
|
||||
client,
|
||||
device,
|
||||
len,
|
||||
len_read,
|
||||
reduce,
|
||||
settings,
|
||||
};
|
||||
|
||||
Self {
|
||||
info: Arc::new(info),
|
||||
}
|
||||
}
|
||||
/// Execute the optimization.
|
||||
pub fn execute<BT: CubeElement>(
|
||||
&mut self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
|
||||
) {
|
||||
// The index of the fallback reduce is the number of ops fused as read.
|
||||
let fallback = fallback(self.info.len_read);
|
||||
let arg = ReduceOptimizationTuneArg {
|
||||
info: self.info.clone(),
|
||||
fallback: Arc::new(fallback),
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
fused_reduce_autotune::<R, BT>(arg, context);
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
if arg
|
||||
.execute_fused::<BT>(
|
||||
context,
|
||||
RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
arg.execute_fallback::<BT>(context);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_output_buffers(&self) -> usize {
|
||||
self.info.trace_read_fallback.resources.outputs.len()
|
||||
}
|
||||
|
||||
pub fn to_state(&self) -> ReduceOptimizationState {
|
||||
ReduceOptimizationState {
|
||||
trace: self.info.trace.clone(),
|
||||
trace_read_fallback: self.info.trace_read_fallback.clone(),
|
||||
trace_write_fallback: self.info.trace_write_fallback.clone(),
|
||||
reduce: self.info.reduce.clone(),
|
||||
len: self.info.len,
|
||||
len_read: self.info.len_read,
|
||||
settings: self.info.settings,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
|
||||
let client = R::client(device);
|
||||
|
||||
let info = ReduceOptimizationInfo {
|
||||
trace: state.trace,
|
||||
trace_read_fallback: state.trace_read_fallback,
|
||||
trace_write_fallback: state.trace_write_fallback,
|
||||
reduce: state.reduce,
|
||||
len: state.len,
|
||||
len_read: state.len_read,
|
||||
client,
|
||||
device: device.clone(),
|
||||
settings: state.settings,
|
||||
};
|
||||
|
||||
Self {
|
||||
info: Arc::new(info),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of output buffers added by fusion.
|
||||
pub fn num_ops_fused(&self) -> usize {
|
||||
self.info.len
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Implement better vectorization here.
// Empty impl: the trait's default vectorization behaviour is used as-is.
|
||||
impl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}
|
||||
|
||||
impl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {
|
||||
type Error = FusedReduceError;
|
||||
|
||||
fn run<'a>(
|
||||
&'a self,
|
||||
client: &'a ComputeClient<R>,
|
||||
inputs: GlobalArgsLaunch<'a, R>,
|
||||
outputs: GlobalArgsLaunch<'a, R>,
|
||||
configs: &'a [FuseBlockConfig],
|
||||
) -> Result<(), FusedReduceError> {
|
||||
let [config_read, config_write] = [&configs[0], &configs[1]];
|
||||
let shape = match &config_read.ref_layout {
|
||||
RefLayout::Concrete(FuseArg::Output(..)) => {
|
||||
outputs.shape_ref(&config_read.ref_layout, config_read.rank)
|
||||
}
|
||||
_ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),
|
||||
};
|
||||
let reduce_count: usize = shape
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })
|
||||
.product();
|
||||
|
||||
let line_mode = match self.reduce.axis == config_read.rank - 1 {
|
||||
true => LineMode::Parallel,
|
||||
false => LineMode::Perpendicular,
|
||||
};
|
||||
let address_type = inputs
|
||||
.required_address_type()
|
||||
.max(outputs.required_address_type());
|
||||
|
||||
let settings = ReduceLineSettings {
|
||||
line_mode,
|
||||
line_size_input: config_read.width,
|
||||
line_size_output: config_write.width,
|
||||
};
|
||||
let problem = ReduceProblem {
|
||||
vector_size: shape[self.reduce.axis],
|
||||
vector_count: reduce_count,
|
||||
axis: self.reduce.axis,
|
||||
dtypes: ReduceDtypes {
|
||||
input: self.reduce.op.input.dtype.into(),
|
||||
output: self.reduce.op.out.dtype.into(),
|
||||
accumulation: self.reduce.acc.into_elem().into(),
|
||||
},
|
||||
address_type,
|
||||
};
|
||||
|
||||
let (blueprint, settings) = match self.strategy.clone() {
|
||||
RoutineStrategy::Unit(strategy) => {
|
||||
let routine = UnitRoutine;
|
||||
routine.prepare(client, problem, settings, strategy)?
|
||||
}
|
||||
RoutineStrategy::Plane(strategy) => {
|
||||
let routine = PlaneRoutine;
|
||||
routine.prepare(client, problem, settings, strategy)?
|
||||
}
|
||||
RoutineStrategy::Cube(strategy) => {
|
||||
let routine = CubeRoutine;
|
||||
routine.prepare(client, problem, settings, strategy)?
|
||||
}
|
||||
};
|
||||
|
||||
let kwargs = ReduceKwArgs {
|
||||
client,
|
||||
inputs,
|
||||
outputs,
|
||||
axis: self.reduce.axis,
|
||||
config_fuse_read: config_read.clone(),
|
||||
config_fuse_write: config_write.clone(),
|
||||
input: self.reduce.input.clone(),
|
||||
output: self.reduce.output.clone(),
|
||||
blueprint,
|
||||
settings,
|
||||
};
|
||||
let result = launch_reduce_mixed_precision(
|
||||
kwargs,
|
||||
self.reduce.inst,
|
||||
self.reduce.op.input.dtype,
|
||||
self.reduce.op.out.dtype,
|
||||
DType::from(self.reduce.acc.into_elem()),
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Bundle of everything needed to launch `reduce_kernel_fused`, passed by value through
/// the launch helpers.
struct ReduceKwArgs<'a, 'b, Run: Runtime> {
|
||||
// Client the kernel is launched on.
client: &'b ComputeClient<Run>,
|
||||
// Global input arguments of the fused trace.
inputs: GlobalArgsLaunch<'a, Run>,
|
||||
// Global output arguments of the fused trace.
outputs: GlobalArgsLaunch<'a, Run>,
|
||||
// Axis being reduced (passed to the kernel as a scalar).
axis: usize,
|
||||
// Compile-time blueprint produced by the routine's `prepare`.
blueprint: ReduceBlueprint,
|
||||
// Launch settings (cube count, cube dim, address type) produced by `prepare`.
settings: ReduceLaunchSettings,
|
||||
// Configuration of the read block.
config_fuse_read: FuseBlockConfig,
|
||||
// Configuration of the write block.
config_fuse_write: FuseBlockConfig,
|
||||
// Fuse argument the reduce reads.
input: FuseArg,
|
||||
// Fuse argument the reduce writes.
output: FuseArg,
|
||||
}
|
||||
|
||||
fn launch_reduce_mixed_precision<Run: Runtime>(
|
||||
kwargs: ReduceKwArgs<'_, '_, Run>,
|
||||
instruction: ReduceInstruction,
|
||||
dtype_input: DType,
|
||||
dtype_output: DType,
|
||||
dtype_acc: DType,
|
||||
) -> Result<(), LaunchError> {
|
||||
let config = match instruction {
|
||||
ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
|
||||
ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
|
||||
ReduceInstruction::Prod => ReduceOperationConfig::Prod,
|
||||
ReduceInstruction::Mean => ReduceOperationConfig::Mean,
|
||||
ReduceInstruction::Sum => ReduceOperationConfig::Sum,
|
||||
ReduceInstruction::Max => ReduceOperationConfig::Max,
|
||||
ReduceInstruction::Min => ReduceOperationConfig::Min,
|
||||
ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
|
||||
};
|
||||
launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)
|
||||
}
|
||||
|
||||
fn launch_reduce<Run: Runtime>(
|
||||
kwargs: ReduceKwArgs<'_, '_, Run>,
|
||||
inst: ReduceOperationConfig,
|
||||
dtype_input: DType,
|
||||
dtype_output: DType,
|
||||
dtype_acc: DType,
|
||||
) -> Result<(), LaunchError> {
|
||||
unsafe {
|
||||
reduce_kernel_fused::launch_unchecked::<Run>(
|
||||
kwargs.client,
|
||||
kwargs.settings.cube_count,
|
||||
kwargs.settings.cube_dim,
|
||||
kwargs.settings.address_type,
|
||||
FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),
|
||||
FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),
|
||||
ScalarArg::new(kwargs.axis),
|
||||
kwargs.blueprint,
|
||||
inst,
|
||||
dtype_input.into(),
|
||||
dtype_output.into(),
|
||||
dtype_acc.into(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Fused reduce kernel.
///
/// Initialises the multi-block variables for the read and write block configurations,
/// wraps the arguments with `init_tensors` using [`FusedReduceArgs`], and runs
/// `reduce_kernel_virtual` along `axis_reduce`.
///
/// `In`, `Out` and `Acc` are bound at launch through the three `#[define(..)]` storage-type
/// parameters, which are otherwise unused.
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub fn reduce_kernel_fused<In: Numeric, Out: Numeric, Acc: Numeric>(
|
||||
input: &FusedReduceInput,
|
||||
output: &mut FusedReduceOutput,
|
||||
axis_reduce: usize,
|
||||
#[comptime] blueprint: ReduceBlueprint,
|
||||
#[comptime] config: ReduceOperationConfig,
|
||||
#[define(In)] _input_dtype: StorageType,
|
||||
#[define(Out)] _output_dtype: StorageType,
|
||||
#[define(Acc)] _acc_dtype: StorageType,
|
||||
) {
|
||||
// Both blocks register their variables on the output's global arguments.
multi_block_variables_init(&input.config, &mut output.global.variables);
|
||||
multi_block_variables_init(&output.config, &mut output.global.variables);
|
||||
|
||||
let (input, mut output) = init_tensors::<FusedReduceArgs, In, Out>(input, output);
|
||||
|
||||
reduce_kernel_virtual::<In, Out, Acc>(&input, &mut output, axis_reduce, blueprint, config);
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
use super::optimization::ReduceOptimizationTuneArg;
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::trace::TuneOutput,
|
||||
tune::{TuneContext, TuneInput},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use cubecl::{
|
||||
AutotuneKey, CubeElement, CubeTuneId, Runtime,
|
||||
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
|
||||
};
|
||||
use cubek::reduce::{
|
||||
launch::{RoutineStrategy, tune_key::ReduceAutotuneKey},
|
||||
routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Autotune key for standard fused reduction operations.
|
||||
///
|
||||
/// Records metadata about the fusion graph (IO and ops) alongside
|
||||
/// the core reduction parameters to ensure stable kernel selection.
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
pub struct FusedReduceAutotuneKey {
|
||||
reduce_key: ReduceAutotuneKey,
|
||||
#[autotune(anchor)]
|
||||
fuse_num_reads: usize,
|
||||
#[autotune(anchor)]
|
||||
fuse_num_writes: usize,
|
||||
#[autotune(anchor)]
|
||||
fuse_num_ops: usize,
|
||||
}
|
||||
|
||||
/// Executes autotuning for fused reduction operations.
|
||||
///
|
||||
/// This tuner evaluates different hardware-specific strategies (Plane, Cube, Unit)
|
||||
/// and assigns priorities based on the `vector_count` of the reduction.
|
||||
pub fn fused_reduce_autotune<R: Runtime, BT: CubeElement>(
|
||||
arg: ReduceOptimizationTuneArg<R>,
|
||||
context: &mut Context<CubeFusionHandle<R>>,
|
||||
) {
|
||||
static TUNER: LocalTuner<FusedReduceAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
const PRIORITY_MAX: i8 = 2;
|
||||
const PRIORITY_MIN: i8 = 1;
|
||||
|
||||
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);
|
||||
let group = TuneGroup::<FusedReduceAutotuneKey>::new("fused_reduce", |_key| PRIORITY_MAX);
|
||||
|
||||
// Fallback implementation for robustness.
|
||||
set = set.with(Tunable::new(
|
||||
"fused_reduce_fallback",
|
||||
tune_fallback::<R, BT>,
|
||||
));
|
||||
|
||||
// Define properties to categorize hardware strategies.
|
||||
enum ReduceProps {
|
||||
GreatWithLowReduceCount,
|
||||
GreatWithHighReduceCount,
|
||||
Balanced,
|
||||
}
|
||||
|
||||
let strategies = [
|
||||
(
|
||||
"fused_unit",
|
||||
RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
|
||||
ReduceProps::GreatWithHighReduceCount,
|
||||
),
|
||||
(
|
||||
"fused_plane",
|
||||
RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy {
|
||||
independent: true,
|
||||
})),
|
||||
ReduceProps::Balanced,
|
||||
),
|
||||
(
|
||||
"fused_cube",
|
||||
RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy {
|
||||
// Two steps reduction doesn't work with fuse-on-write, we can't activate plane
|
||||
// when using the cube algo.
|
||||
use_planes: false,
|
||||
})),
|
||||
ReduceProps::GreatWithLowReduceCount,
|
||||
),
|
||||
];
|
||||
|
||||
for (name, strategy, props) in strategies {
|
||||
let tunable = Tunable::new(name, move |input| tune_reduce::<R, BT>(input, &strategy))
|
||||
.group(&group, move |key| match props {
|
||||
ReduceProps::GreatWithLowReduceCount => {
|
||||
if key.reduce_key.vector_count < 128 {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MIN
|
||||
}
|
||||
}
|
||||
ReduceProps::GreatWithHighReduceCount => {
|
||||
if key.reduce_key.vector_count > 64 {
|
||||
PRIORITY_MAX
|
||||
} else {
|
||||
PRIORITY_MIN
|
||||
}
|
||||
}
|
||||
ReduceProps::Balanced => PRIORITY_MAX,
|
||||
});
|
||||
|
||||
set = set.with(tunable);
|
||||
}
|
||||
|
||||
set
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&arg.info.client, &arg.info.device),
|
||||
&arg.info.client.clone(),
|
||||
tunables,
|
||||
TuneInput::new(context, arg),
|
||||
);
|
||||
}
|
||||
|
||||
/// Creates the autotune key by extracting tensor metadata and fusion block statistics.
|
||||
pub(crate) fn create_key<R: Runtime>(
|
||||
input: &TuneInput<R, ReduceOptimizationTuneArg<R>>,
|
||||
) -> FusedReduceAutotuneKey {
|
||||
let opt = input.optimization();
|
||||
let context = match input.context() {
|
||||
TuneContext::Original(context) => context,
|
||||
TuneContext::Fork(_) => panic!("Forked context not supported for key generation"),
|
||||
};
|
||||
|
||||
let input_tensor = context.tensors.get(&opt.info.reduce.op.input.id).unwrap();
|
||||
let out_tensor = context.tensors.get(&opt.info.reduce.op.out.id).unwrap();
|
||||
let acc = opt.info.reduce.acc.into_elem();
|
||||
|
||||
let key = ReduceAutotuneKey::generate(
|
||||
input_tensor.dtype.into(),
|
||||
out_tensor.dtype.into(),
|
||||
acc,
|
||||
&input_tensor.shape,
|
||||
opt.info.reduce.axis == input_tensor.shape.rank() - 1,
|
||||
opt.info.reduce.axis,
|
||||
);
|
||||
|
||||
// Assume the fusion contains at least a read and a write block.
|
||||
let read_block = &opt.info.trace.blocks[0];
|
||||
let write_block = &opt.info.trace.blocks[1];
|
||||
|
||||
FusedReduceAutotuneKey::new(
|
||||
key,
|
||||
read_block.reads.len() + write_block.reads.len(),
|
||||
read_block.writes.len() + write_block.writes.len(),
|
||||
read_block.ops.len() + write_block.ops.len(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Identity generator for tuning inputs.
|
||||
fn input_gen<R: Runtime>(
|
||||
_key: &FusedReduceAutotuneKey,
|
||||
input: &TuneInput<R, ReduceOptimizationTuneArg<R>>,
|
||||
) -> TuneInput<R, ReduceOptimizationTuneArg<R>> {
|
||||
input.clone()
|
||||
}
|
||||
|
||||
/// Executes a fused reduction optimization.
|
||||
fn tune_reduce<R: Runtime, BT: CubeElement>(
|
||||
input: TuneInput<R, ReduceOptimizationTuneArg<R>>,
|
||||
strategy: &RoutineStrategy,
|
||||
) -> Result<TuneOutput<R>, String> {
|
||||
let optimization = input.optimization();
|
||||
|
||||
match input.context() {
|
||||
TuneContext::Original(context) => {
|
||||
optimization.execute_fused::<BT>(context, strategy.clone())
|
||||
}
|
||||
TuneContext::Fork(mut context_owned) => {
|
||||
optimization.execute_fused::<BT>(&mut context_owned.as_context(), strategy.clone())
|
||||
}
|
||||
}
|
||||
.map_err(|e| format!("{e:?}"))
|
||||
}
|
||||
|
||||
/// Executes the fallback path for a reduction optimization.
|
||||
fn tune_fallback<R: Runtime, BT: CubeElement>(
|
||||
input: TuneInput<R, ReduceOptimizationTuneArg<R>>,
|
||||
) -> Result<TuneOutput<R>, String> {
|
||||
let optimization = input.optimization();
|
||||
|
||||
match input.context() {
|
||||
TuneContext::Original(context) => optimization.execute_fallback::<BT>(context),
|
||||
TuneContext::Fork(mut context_owned) => {
|
||||
optimization.execute_fallback::<BT>(&mut context_owned.as_context())
|
||||
}
|
||||
};
|
||||
|
||||
Ok(TuneOutput::UnChecked(std::marker::PhantomData))
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
use crate::{
|
||||
engine::codegen::ir::FuseType,
|
||||
optim::{
|
||||
CubeOptimization,
|
||||
reduce::{ReduceFuser, ReduceFuserInfo, ReduceSettings},
|
||||
reduce_broadcasted::{
|
||||
ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationInfo,
|
||||
fuser::{
|
||||
block::{ReduceBlockFuser, ReduceBlockFusionAnalysis, ReduceBroadcastedStatus},
|
||||
full::ReduceBroadcastedFullFuser,
|
||||
full_analyzer::FullFuserAnalyzer,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
use burn_fusion::{FuserProperties, FuserStatus, OperationFuser};
|
||||
use burn_ir::OperationIr;
|
||||
use cubecl::Runtime;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Fuses element wise operations around a reduce operation.
|
||||
pub struct ReduceBroadcastedFuser<R: Runtime> {
|
||||
blocks: Vec<ReduceBlockFuser<R>>,
|
||||
fuser_default: ReduceFuser<R>,
|
||||
num_ops: usize,
|
||||
state: ReduceBroadcastedStatus,
|
||||
max_bindings: u32,
|
||||
bool_precision: FuseType,
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for ReduceBroadcastedFuser<R> {
    /// Clones every field of the fuser.
    fn clone(&self) -> Self {
        let Self {
            blocks,
            fuser_default,
            num_ops,
            state,
            max_bindings,
            bool_precision,
        } = self;

        Self {
            blocks: blocks.clone(),
            fuser_default: fuser_default.clone(),
            num_ops: *num_ops,
            state: state.clone(),
            max_bindings: *max_bindings,
            bool_precision: *bool_precision,
        }
    }
}
|
||||
|
||||
impl<R: Runtime> ReduceBroadcastedFuser<R> {
|
||||
pub fn new(device: R::Device, bool_precision: FuseType) -> Self {
|
||||
let fuser = ReduceFuser::new(device, bool_precision, ReduceSettings::Always);
|
||||
let max_bindings = fuser.fuser.max_bindings;
|
||||
let block = ReduceBlockFuser::new(fuser.clone());
|
||||
|
||||
Self {
|
||||
blocks: vec![block],
|
||||
fuser_default: fuser,
|
||||
num_ops: 0,
|
||||
state: ReduceBroadcastedStatus::Starting,
|
||||
max_bindings,
|
||||
bool_precision,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> OperationFuser<CubeOptimization<R>> for ReduceBroadcastedFuser<R> {
|
||||
fn fuse(&mut self, operation: &OperationIr) {
|
||||
if matches!(
|
||||
&self.state,
|
||||
ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
let block = self.blocks.last_mut().unwrap();
|
||||
let analyze = block.analyze(operation, &self.state, &self.fuser_default);
|
||||
|
||||
let info = match analyze {
|
||||
ReduceBlockFusionAnalysis::Accept => {
|
||||
block.fuse(operation);
|
||||
self.num_ops += 1;
|
||||
block.fuser.reduce_info()
|
||||
}
|
||||
ReduceBlockFusionAnalysis::Refuse => {
|
||||
self.state = ReduceBroadcastedStatus::Closed;
|
||||
return;
|
||||
}
|
||||
ReduceBlockFusionAnalysis::NewBlockRequired => {
|
||||
let info = block.fuser.reduce_info();
|
||||
let mut block = ReduceBlockFuser::new(self.fuser_default.clone());
|
||||
block.fuse(operation);
|
||||
self.num_ops += 1;
|
||||
self.blocks.push(block);
|
||||
info
|
||||
}
|
||||
};
|
||||
|
||||
match info {
|
||||
ReduceFuserInfo::FusedReduce {
|
||||
shape_input_id,
|
||||
axis,
|
||||
} => {
|
||||
// Only support last axis for now.
|
||||
if axis != shape_input_id.len() - 1 {
|
||||
self.state = ReduceBroadcastedStatus::Abort;
|
||||
} else {
|
||||
self.state = ReduceBroadcastedStatus::Init {
|
||||
shape_id: shape_input_id,
|
||||
axis,
|
||||
};
|
||||
}
|
||||
}
|
||||
ReduceFuserInfo::FusedElemwise { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> CubeOptimization<R> {
|
||||
let analyzer = FullFuserAnalyzer::new(&self.blocks);
|
||||
let mut full =
|
||||
ReduceBroadcastedFullFuser::new(self.max_bindings, self.bool_precision, analyzer);
|
||||
let mut num_ops = 0;
|
||||
let fallbacks = self
|
||||
.blocks
|
||||
.iter_mut()
|
||||
.map(|block| block.finish(&mut num_ops, &mut full))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let broadcasted = Arc::new(full.finish());
|
||||
let info = Arc::new(ReduceBroadcastedOptimizationInfo {
|
||||
fallbacks,
|
||||
broadcasted,
|
||||
});
|
||||
CubeOptimization::ReduceBroadcasted(ReduceBroadcastedOptimization { info, num_ops })
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
let block = ReduceBlockFuser::new(self.fuser_default.clone());
|
||||
self.blocks = vec![block];
|
||||
self.num_ops = 0;
|
||||
self.state = ReduceBroadcastedStatus::Starting;
|
||||
}
|
||||
|
||||
fn status(&self) -> FuserStatus {
|
||||
match self.state {
|
||||
ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort => {
|
||||
return FuserStatus::Closed;
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
|
||||
let fuser = self.blocks.last().unwrap();
|
||||
fuser.fuser.status()
|
||||
}
|
||||
|
||||
fn properties(&self) -> FuserProperties {
|
||||
let ready = match self.state {
|
||||
ReduceBroadcastedStatus::Starting | ReduceBroadcastedStatus::Abort => false,
|
||||
ReduceBroadcastedStatus::Closed => {
|
||||
if self.blocks.len() == 1 {
|
||||
!self.blocks[0].is_elemwise()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
_ => true,
|
||||
};
|
||||
let mut props = FuserProperties { score: 0, ready };
|
||||
for block in self.blocks.iter() {
|
||||
let p = block.properties();
|
||||
props.score += p.score;
|
||||
props.ready = p.ready && props.ready;
|
||||
}
|
||||
props
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.num_ops
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn_ir::{
        BaseOperationIr, BinaryOpIr, CreationOpIr, ReduceDimOpIr, TensorId, TensorIr, TensorStatus,
    };
    use burn_std::{DType, Shape};

    use super::*;

    type Run = cubecl::TestRuntime;

    /// Fuses `ones -> sum_dim -> add -> add -> sum_dim` and checks the fuser after each step.
    #[test]
    fn reduce_broadcast_workflow_1() {
        let device: <Run as Runtime>::Device = Default::default();
        let mut fuser = ReduceBroadcastedFuser::<Run>::new(device, FuseType::I32);
        let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite);
        let (tensor2_out, tensor2) = tensor(1, &[1, 0], TensorStatus::ReadWrite);

        fuser.fuse(&ones(tensor1_out));
        fuser.fuse(&sum_dim(tensor1, tensor2_out, 1));

        let status = fuser.status();
        assert_eq!(2, fuser.len());
        assert_eq!(status, FuserStatus::Open);
        assert_eq!(
            fuser.properties(),
            FuserProperties {
                score: 2,
                ready: true
            }
        );

        // An existing tensor
        let (_tensor3_out, tensor3) = tensor(2, &[1, 0], TensorStatus::ReadWrite);
        // A new tensor
        let (tensor4_out, tensor4) = tensor(3, &[1, 0], TensorStatus::ReadWrite);
        fuser.fuse(&add(tensor2, tensor3, tensor4_out));

        let status = fuser.status();
        assert_eq!(3, fuser.len());
        assert_eq!(status, FuserStatus::Open);
        assert_eq!(
            fuser.properties(),
            FuserProperties {
                score: 3,
                ready: true
            }
        );

        // An existing tensor
        let (_tensor5_out, tensor5) = tensor(4, &[1, 2], TensorStatus::ReadWrite);
        // A new tensor
        let (tensor6_out, tensor6) = tensor(5, &[1, 2], TensorStatus::ReadWrite);
        fuser.fuse(&add(tensor4, tensor5, tensor6_out));

        let status = fuser.status();
        assert_eq!(4, fuser.len());
        assert_eq!(status, FuserStatus::Open);
        assert_eq!(
            fuser.properties(),
            FuserProperties {
                score: 4,
                ready: true
            }
        );

        let (tensor7_out, _tensor7) = tensor(6, &[1, 0], TensorStatus::ReadWrite);
        fuser.fuse(&sum_dim(tensor6, tensor7_out, 1));

        // The status is read again so the assertion reflects the state after the last fuse
        // instead of re-checking the value captured before it.
        let status = fuser.status();
        assert_eq!(5, fuser.len());
        assert_eq!(status, FuserStatus::Open);
        assert_eq!(
            fuser.properties(),
            FuserProperties {
                score: 5,
                ready: true
            }
        );

        let _optimization = fuser.finish();
    }

    /// Fuses `ones -> add -> sum_dim -> add`, where the last add reuses an existing tensor.
    #[test]
    fn reduce_broadcast_workflow_2() {
        let device: <Run as Runtime>::Device = Default::default();
        let mut fuser = ReduceBroadcastedFuser::<Run>::new(device, FuseType::I32);
        let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite);
        // An existing tensor
        let (_tensor2_out, mut tensor2) = tensor(2, &[1, 2], TensorStatus::ReadOnly);
        let (tensor3_out, tensor3) = tensor(3, &[1, 2], TensorStatus::ReadWrite);

        // First reduce output
        let (tensor4_out, tensor4) = tensor(1, &[1, 0], TensorStatus::ReadWrite);

        fuser.fuse(&ones(tensor1_out));
        fuser.fuse(&add(tensor1, tensor2.clone(), tensor3_out));
        fuser.fuse(&sum_dim(tensor3, tensor4_out, 1));

        let status = fuser.status();
        assert_eq!(3, fuser.len());
        assert_eq!(status, FuserStatus::Open);
        assert_eq!(
            fuser.properties(),
            FuserProperties {
                score: 3,
                ready: true
            }
        );

        // A new tensor
        let (tensor5_out, _tensor5) = tensor(5, &[1, 2], TensorStatus::ReadWrite);
        // Last time we use tensor2.
        tensor2.status = TensorStatus::ReadWrite;
        fuser.fuse(&add(tensor4, tensor2, tensor5_out));

        let status = fuser.status();
        assert_eq!(4, fuser.len());
        assert_eq!(status, FuserStatus::Open);
        assert_eq!(
            fuser.properties(),
            FuserProperties {
                score: 4,
                ready: true
            }
        );

        let _optimization = fuser.finish();
    }

    /// Builds the float `Ones` creation operation writing into `out`.
    fn ones(out: TensorIr) -> OperationIr {
        OperationIr::BaseFloat(BaseOperationIr::Ones(CreationOpIr { out }))
    }

    /// Builds an `F32` sum reduction of `input` along `axis` writing into `out`.
    fn sum_dim(input: TensorIr, out: TensorIr, axis: usize) -> OperationIr {
        OperationIr::NumericFloat(
            DType::F32,
            burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr { input, out, axis }),
        )
    }

    /// Builds an `F32` addition `lhs + rhs` writing into `out`.
    fn add(lhs: TensorIr, rhs: TensorIr, out: TensorIr) -> OperationIr {
        OperationIr::NumericFloat(
            DType::F32,
            burn_ir::NumericOperationIr::Add(BinaryOpIr { lhs, rhs, out }),
        )
    }

    /// Returns the same `F32` tensor twice: first with the `NotInit` status (to be used as an
    /// operation output), then with the given `status` (to be used as an operation input).
    fn tensor(id: u64, shape: &[usize], status: TensorStatus) -> (TensorIr, TensorIr) {
        let not_init = TensorIr {
            id: TensorId::new(id),
            shape: Shape::from(shape),
            status: TensorStatus::NotInit,
            dtype: DType::F32,
        };
        let initialized = TensorIr {
            status,
            ..not_init.clone()
        };

        (not_init, initialized)
    }
}
|
||||
@@ -0,0 +1,214 @@
|
||||
use crate::optim::{
|
||||
CubeOptimization,
|
||||
elemwise::ElemwiseOptimization,
|
||||
reduce::{FusedReduce, ReduceFuser, ReduceFuserInfo},
|
||||
reduce_broadcasted::{ReduceBlockOptimInfo, fuser::full::ReduceBroadcastedFullFuser},
|
||||
};
|
||||
use burn_fusion::{FuserProperties, OperationFuser};
|
||||
use burn_ir::OperationIr;
|
||||
use burn_std::Shape;
|
||||
use cubecl::Runtime;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Responsible for fusing a single reduce block or elementwise block.
|
||||
///
|
||||
/// When the block kind is reduce, it supports fuse-on-read and fuse-on-write fusion.
|
||||
/// Broadcasting isn't supported; another block should handle it instead.
|
||||
pub struct ReduceBlockFuser<R: Runtime> {
|
||||
/// We use [ReduceFuser] for both elementwise and reduce blocks, keeping only the
|
||||
/// fuse-on-read trace if the block is tagged as elementwise.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// A single elementwise block can only exist at the end of a full [ReduceBlockFuser],
|
||||
/// otherwise the optimization will be included in the reduce fusion block.
|
||||
pub fuser: ReduceFuser<R>,
|
||||
pub(crate) ops: Vec<OperationIr>,
|
||||
pub(crate) kind: ReduceBlockKind,
|
||||
}
|
||||
|
||||
/// The current state of the fusion process.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ReduceBroadcastedStatus {
|
||||
/// Fusion is starting; no reduction has been fused yet.
|
||||
Starting,
|
||||
/// Fusion is initialized with at least one reduce operation.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Subsequent reduce operations must be compatible with the previous reduction to fuse.
|
||||
Init { shape_id: Shape, axis: usize },
|
||||
/// No more operations can be fused.
|
||||
Closed,
|
||||
/// Invalid axis.
|
||||
Abort,
|
||||
}
|
||||
|
||||
/// Tells whether a [ReduceBlockFuser] is able to accept an [OperationIr].
#[derive(Clone, Copy, Debug)]
pub enum ReduceBlockFusionAnalysis {
    /// The current block can fuse the operation; call [ReduceBlockFuser::fuse()].
    Accept,
    /// The operation can't be fused at all; the optimization should close.
    Refuse,
    /// The operation can be fused, but only in a new block.
    NewBlockRequired,
}
|
||||
|
||||
impl<R: Runtime> ReduceBlockFuser<R> {
|
||||
/// Creates a new block.
|
||||
pub fn new(fuser: ReduceFuser<R>) -> Self {
|
||||
Self {
|
||||
fuser: fuser.clone(),
|
||||
ops: Vec::new(),
|
||||
kind: ReduceBlockKind::Elemwise,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this is an elementwise fuser.
|
||||
pub fn is_elemwise(&self) -> bool {
|
||||
matches!(self.kind, ReduceBlockKind::Elemwise)
|
||||
}
|
||||
|
||||
/// Analyzes if fusion is possible within this block.
|
||||
pub fn analyze(
|
||||
&self,
|
||||
op: &OperationIr,
|
||||
status: &ReduceBroadcastedStatus,
|
||||
default_node: &ReduceFuser<R>,
|
||||
) -> ReduceBlockFusionAnalysis {
|
||||
let mut fuser_try = self.fuser.clone();
|
||||
let before = fuser_try.len();
|
||||
fuser_try.fuse(op);
|
||||
let after = fuser_try.len();
|
||||
|
||||
if after > before {
|
||||
return ReduceBlockFusionAnalysis::Accept;
|
||||
}
|
||||
|
||||
// Can't create a new block if the previous one was not a reduction.
|
||||
if self.fuser.reduce.is_none() {
|
||||
return ReduceBlockFusionAnalysis::Refuse;
|
||||
}
|
||||
|
||||
let mut fuser_try = default_node.clone();
|
||||
let before = fuser_try.len();
|
||||
fuser_try.fuse(op);
|
||||
let after = fuser_try.len();
|
||||
|
||||
if after > before {
|
||||
let info = fuser_try.reduce_info();
|
||||
|
||||
return match (info, status) {
|
||||
(
|
||||
ReduceFuserInfo::FusedReduce {
|
||||
shape_input_id,
|
||||
axis,
|
||||
},
|
||||
ReduceBroadcastedStatus::Init {
|
||||
shape_id,
|
||||
axis: axis_init,
|
||||
},
|
||||
) => {
|
||||
if shape_id == &shape_input_id && axis_init == &axis {
|
||||
ReduceBlockFusionAnalysis::NewBlockRequired
|
||||
} else {
|
||||
ReduceBlockFusionAnalysis::Refuse
|
||||
}
|
||||
}
|
||||
(
|
||||
ReduceFuserInfo::FusedElemwise { shape_id },
|
||||
ReduceBroadcastedStatus::Init {
|
||||
shape_id: shape_init,
|
||||
..
|
||||
},
|
||||
) => {
|
||||
if &shape_id == shape_init {
|
||||
ReduceBlockFusionAnalysis::NewBlockRequired
|
||||
} else {
|
||||
ReduceBlockFusionAnalysis::Refuse
|
||||
}
|
||||
}
|
||||
_ => ReduceBlockFusionAnalysis::Refuse,
|
||||
};
|
||||
}
|
||||
|
||||
ReduceBlockFusionAnalysis::Refuse
|
||||
}
|
||||
|
||||
/// Fuses an operation within this block.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// Ensure [Self::analyze()] is called before this function to confirm the operation is accepted.
|
||||
pub fn fuse(&mut self, op: &OperationIr) {
|
||||
self.fuser.fuse(op);
|
||||
self.ops.push(op.clone());
|
||||
|
||||
// Update the kind if a reduction is introduced to an elementwise block.
|
||||
if let (Some(reduce), ReduceBlockKind::Elemwise) = (&self.fuser.reduce, &self.kind) {
|
||||
self.kind = ReduceBlockKind::Reduce {
|
||||
ops_index: self.ops.len() - 1,
|
||||
reduce: Box::new(reduce.clone()),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the fuser properties.
|
||||
pub fn properties(&self) -> FuserProperties {
|
||||
let mut properties = self.fuser.properties();
|
||||
if let ReduceBlockKind::Elemwise = &self.kind {
|
||||
// Elementwise traces are always ready to run.
|
||||
properties.ready = true;
|
||||
}
|
||||
properties
|
||||
}
|
||||
|
||||
pub fn finish(
|
||||
&mut self,
|
||||
num_ops: &mut usize,
|
||||
full: &mut ReduceBroadcastedFullFuser,
|
||||
) -> ReduceBlockOptimInfo<R> {
|
||||
full.register(self);
|
||||
|
||||
match &self.kind {
|
||||
ReduceBlockKind::Elemwise => {
|
||||
let len = self.fuser.fuser_read_fallback.len();
|
||||
let device = self.fuser.device.clone();
|
||||
*num_ops += len;
|
||||
let trace = self.fuser.fuser_read_fallback.finish();
|
||||
let client = R::client(&device);
|
||||
let elementwise = ElemwiseOptimization::new(trace, client, device, len);
|
||||
ReduceBlockOptimInfo::Elemwise(Arc::new(elementwise))
|
||||
}
|
||||
ReduceBlockKind::Reduce { .. } => {
|
||||
*num_ops += self.fuser.len();
|
||||
let optim = self.fuser.finish();
|
||||
let info = match optim {
|
||||
CubeOptimization::Reduce(optim) => optim.info,
|
||||
_ => unreachable!("Expected Reduce optimization"),
|
||||
};
|
||||
ReduceBlockOptimInfo::Reduce(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ReduceBlockKind {
|
||||
Elemwise,
|
||||
Reduce {
|
||||
ops_index: usize,
|
||||
reduce: Box<FusedReduce>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for ReduceBlockFuser<R> {
    /// Clones the fuser, the operation list and the block kind.
    fn clone(&self) -> Self {
        let Self { fuser, ops, kind } = self;

        Self {
            fuser: fuser.clone(),
            ops: ops.clone(),
            kind: kind.clone(),
        }
    }
}
|
||||
@@ -0,0 +1,163 @@
|
||||
use crate::{
|
||||
engine::{
|
||||
codegen::ir::FuseType,
|
||||
fuser::TraceOperationFuser,
|
||||
settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},
|
||||
},
|
||||
optim::{
|
||||
reduce::{FusedReduce, ReduceInstruction},
|
||||
reduce_broadcasted::{
|
||||
ReduceBroadcastedInfo,
|
||||
fuser::{
|
||||
block::{ReduceBlockFuser, ReduceBlockKind},
|
||||
full_analyzer::FullFuserAnalyzer,
|
||||
},
|
||||
launch::ReduceBroadcastedFuseBlock,
|
||||
},
|
||||
},
|
||||
};
|
||||
use burn_fusion::OperationFuser;
|
||||
use cubecl::Runtime;
|
||||
use cubek::reduce::components::instructions::ReduceOperationConfig;
|
||||
|
||||
/// Responsible for fusing a single trace for all operations involved in this optimization.
|
||||
pub struct ReduceBroadcastedFullFuser {
|
||||
pub(crate) fuser: TraceOperationFuser,
|
||||
analyzer: FullFuserAnalyzer,
|
||||
blocks: Vec<ReduceBlockKind>,
|
||||
settings_read: FuseSettings,
|
||||
settings_write: FuseSettings,
|
||||
}
|
||||
|
||||
impl ReduceBroadcastedFullFuser {
|
||||
/// Creates a new fuser with the given settings.
|
||||
pub fn new(max_bindings: u32, bool_precision: FuseType, analyzer: FullFuserAnalyzer) -> Self {
|
||||
let settings_read = FuseSettings {
|
||||
output_shape_updates: true,
|
||||
broadcast: true,
|
||||
inplace: false,
|
||||
ref_layout: RefLayoutSetting::OnlyContiguous,
|
||||
vectorization: VectorizationSetting::Activated,
|
||||
};
|
||||
let settings_write = FuseSettings {
|
||||
output_shape_updates: false,
|
||||
inplace: false,
|
||||
broadcast: false,
|
||||
ref_layout: RefLayoutSetting::OnlyContiguous,
|
||||
// Deactivated for now, but would be cool to support vectorization of the output.
|
||||
vectorization: VectorizationSetting::Deactivated,
|
||||
};
|
||||
let fuser = TraceOperationFuser::new(max_bindings, bool_precision, settings_read);
|
||||
|
||||
Self {
|
||||
fuser,
|
||||
blocks: Vec::new(),
|
||||
settings_write,
|
||||
settings_read,
|
||||
analyzer,
|
||||
}
|
||||
}
|
||||
|
||||
/// Finishes fusing all blocks.
|
||||
pub fn finish(mut self) -> ReduceBroadcastedInfo {
|
||||
let mut reduce_axis = 0;
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
for block in self.blocks.iter() {
|
||||
match block {
|
||||
ReduceBlockKind::Elemwise => {}
|
||||
ReduceBlockKind::Reduce { reduce, .. } => {
|
||||
let config = match reduce.inst {
|
||||
ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
|
||||
ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
|
||||
ReduceInstruction::Prod => ReduceOperationConfig::Prod,
|
||||
ReduceInstruction::Mean => ReduceOperationConfig::Mean,
|
||||
ReduceInstruction::Sum => ReduceOperationConfig::Sum,
|
||||
ReduceInstruction::Max => ReduceOperationConfig::Max,
|
||||
ReduceInstruction::Min => ReduceOperationConfig::Min,
|
||||
ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
|
||||
};
|
||||
|
||||
let block = ReduceBroadcastedFuseBlock {
|
||||
op: config,
|
||||
input: reduce.input.clone(),
|
||||
output: reduce.output.clone(),
|
||||
};
|
||||
reduce_axis = reduce.axis;
|
||||
blocks.push(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let trace = self.fuser.finish();
|
||||
|
||||
ReduceBroadcastedInfo {
|
||||
blocks,
|
||||
trace,
|
||||
reduce_axis,
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a [ReduceBlockFuser] to build the trace.
|
||||
pub fn register<R: Runtime>(&mut self, block: &ReduceBlockFuser<R>) {
|
||||
// Helper to close previous blocks if necessary
|
||||
if !self.fuser.is_empty() {
|
||||
let mut settings = self.settings_read;
|
||||
settings.vectorization = VectorizationSetting::EqualThanPreviousBlock { block_pos: 0 };
|
||||
settings.ref_layout = RefLayoutSetting::SameAsBlock { block_pos: 0 };
|
||||
self.fuser.next_block([], settings, false);
|
||||
|
||||
let analysis = self.analyzer.retrieve_next();
|
||||
|
||||
for (tensor, block_pos) in analysis.inputs {
|
||||
self.fuser.block_local_input(&tensor, block_pos, false);
|
||||
}
|
||||
}
|
||||
|
||||
match &block.kind {
|
||||
ReduceBlockKind::Elemwise => {
|
||||
for op in &block.ops {
|
||||
self.fuser.fuse(op);
|
||||
}
|
||||
self.blocks.push(ReduceBlockKind::Elemwise);
|
||||
}
|
||||
ReduceBlockKind::Reduce { ops_index, reduce } => {
|
||||
for op in &block.ops[0..*ops_index] {
|
||||
self.fuser.fuse(op);
|
||||
}
|
||||
|
||||
let [input] = self
|
||||
.fuser
|
||||
.next_block([&reduce.op.input], self.settings_write, false);
|
||||
|
||||
let output = self.fuser.output_unhandled(&reduce.op.out);
|
||||
let analysis = self.analyzer.retrieve_next();
|
||||
|
||||
// Can be broadcasted so the generated buffer can be global.
|
||||
for (tensor, block_pos) in analysis.inputs {
|
||||
self.fuser.block_local_input(&tensor, block_pos, false);
|
||||
}
|
||||
|
||||
let fused_reduce = FusedReduce {
|
||||
input,
|
||||
output,
|
||||
acc: reduce.acc,
|
||||
axis: reduce.axis,
|
||||
op: reduce.op.clone(),
|
||||
use_planes: reduce.use_planes,
|
||||
shared: reduce.shared,
|
||||
inst: reduce.inst,
|
||||
};
|
||||
|
||||
self.blocks.push(ReduceBlockKind::Reduce {
|
||||
ops_index: *ops_index,
|
||||
reduce: Box::new(fused_reduce),
|
||||
});
|
||||
|
||||
for op in &block.ops[*ops_index + 1..block.ops.len()] {
|
||||
self.fuser.fuse(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
use super::block::ReduceBlockKind;
|
||||
use crate::optim::reduce_broadcasted::fuser::block::ReduceBlockFuser;
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use cubecl::Runtime;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FullFuserAnalyzer {
|
||||
// We need to know the block id of which we can reuse the read local input.
|
||||
analyses: Vec<Vec<(TensorIr, usize)>>,
|
||||
}
|
||||
|
||||
impl FullFuserAnalyzer {
|
||||
pub fn new<R: Runtime>(blocks: &[ReduceBlockFuser<R>]) -> Self {
|
||||
let mut state = AnalysisState::default();
|
||||
|
||||
for block in blocks.iter() {
|
||||
for (pos, op) in block.ops.iter().enumerate() {
|
||||
let potential_from_previous_blocks = op.inputs();
|
||||
let potential_to_next_blocks = op.outputs();
|
||||
|
||||
match &block.kind {
|
||||
ReduceBlockKind::Elemwise => {
|
||||
state.register(
|
||||
potential_from_previous_blocks,
|
||||
potential_to_next_blocks,
|
||||
BlockKind::Full,
|
||||
);
|
||||
}
|
||||
ReduceBlockKind::Reduce { ops_index, .. } => {
|
||||
if pos < *ops_index {
|
||||
state.register(
|
||||
potential_from_previous_blocks,
|
||||
potential_to_next_blocks,
|
||||
BlockKind::Full,
|
||||
);
|
||||
} else if pos > *ops_index {
|
||||
state.register(
|
||||
potential_from_previous_blocks,
|
||||
potential_to_next_blocks,
|
||||
BlockKind::Single,
|
||||
);
|
||||
} else {
|
||||
state.next_block();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
state.next_block();
|
||||
}
|
||||
|
||||
// First one is never called.
|
||||
state.analyses.remove(0);
|
||||
|
||||
Self {
|
||||
analyses: state.analyses,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retrieve_next(&mut self) -> FullFuserAnalysis {
|
||||
let inputs = self.analyses.remove(0);
|
||||
FullFuserAnalysis { inputs }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FullFuserAnalysis {
|
||||
/// The tensor received from a previous block.
|
||||
pub inputs: Vec<(TensorIr, usize)>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct AnalysisState {
|
||||
/// That pool contains tensors that are available in the fuse-on-write part of a reduce, not
|
||||
/// broadcasted.
|
||||
available_from_previous_single: BTreeMap<TensorId, usize>,
|
||||
/// That pool contains tensors that are available in the fuse-on-read of a reduce and the
|
||||
/// element-wise broadcasted part
|
||||
available_from_previous_full: BTreeMap<TensorId, usize>,
|
||||
block_data: Vec<(TensorIr, usize)>,
|
||||
analyses: Vec<Vec<(TensorIr, usize)>>,
|
||||
current_full: Vec<TensorIr>,
|
||||
current_single: Vec<TensorIr>,
|
||||
}
|
||||
|
||||
/// Selects which pool of the analysis state an operation contributes to.
enum BlockKind {
    /// Element-wise operations and the operations placed before the reduce of a reduce block.
    Full,
    /// Operations placed after the reduce of a reduce block.
    Single,
}
|
||||
|
||||
impl AnalysisState {
|
||||
fn next_block(&mut self) {
|
||||
let block_pos = self.analyses.len();
|
||||
let data = core::mem::take(&mut self.block_data);
|
||||
self.analyses.push(data);
|
||||
|
||||
// Makes the current tensor reads available for the next block.
|
||||
for p in self.current_single.drain(..) {
|
||||
// We need to keep the earliest block position.
|
||||
self.available_from_previous_single
|
||||
.entry(p.id)
|
||||
.or_insert(block_pos);
|
||||
}
|
||||
for p in self.current_full.drain(..) {
|
||||
// We need to keep the earliest block position.
|
||||
self.available_from_previous_full
|
||||
.entry(p.id)
|
||||
.or_insert(block_pos);
|
||||
}
|
||||
}
|
||||
|
||||
fn register<'a>(
|
||||
&mut self,
|
||||
potential_from_previous_blocks: impl Iterator<Item = &'a TensorIr>,
|
||||
potential_to_next_blocks: impl Iterator<Item = &'a TensorIr>,
|
||||
kind: BlockKind,
|
||||
) {
|
||||
match kind {
|
||||
BlockKind::Full => {
|
||||
for potential in potential_from_previous_blocks {
|
||||
// We can't since it's not in the same scope.
|
||||
//
|
||||
// TODO: Find a way to merge multiple reduce loops.
|
||||
//
|
||||
// if let Some(block_pos) = self.available_from_previous_full.get(&potential.id) {
|
||||
// self.block_data.push((potential.clone(), *block_pos));
|
||||
// }
|
||||
|
||||
// We can since it's a broadcast.
|
||||
if let Some(block_pos) = self.available_from_previous_single.get(&potential.id)
|
||||
{
|
||||
self.block_data.push((potential.clone(), *block_pos));
|
||||
}
|
||||
|
||||
// Can reuse the read.
|
||||
self.current_full.push(potential.clone());
|
||||
}
|
||||
|
||||
for p in potential_to_next_blocks {
|
||||
self.current_full.push(p.clone());
|
||||
}
|
||||
}
|
||||
BlockKind::Single => {
|
||||
for potential in potential_from_previous_blocks {
|
||||
if let Some(block_pos) = self.available_from_previous_single.get(&potential.id)
|
||||
{
|
||||
self.block_data.push((potential.clone(), *block_pos));
|
||||
}
|
||||
// Can reuse the read.
|
||||
self.current_single.push(potential.clone());
|
||||
}
|
||||
|
||||
for p in potential_to_next_blocks {
|
||||
self.current_single.push(p.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
mod base;
|
||||
mod block;
|
||||
mod full;
|
||||
mod full_analyzer;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,139 @@
|
||||
use crate::{
|
||||
engine::{
|
||||
codegen::ir::{FuseArg, FuseBlockConfig, GlobalArgsLaunch, RefLayout},
|
||||
launch::runner::{TraceRunner, Vectorization},
|
||||
},
|
||||
optim::reduce_broadcasted::unit::{
|
||||
ElemwiseFuseBlockLaunch, ReduceFuseBlockLaunch, reduce_kernel_broadcasted,
|
||||
},
|
||||
};
|
||||
use cubecl::{
|
||||
Runtime,
|
||||
ir::{ElemType, FloatKind, StorageType},
|
||||
prelude::*,
|
||||
server::LaunchError,
|
||||
};
|
||||
use cubek::reduce::{
|
||||
LineMode, ReduceDtypes,
|
||||
components::instructions::ReduceOperationConfig,
|
||||
launch::RoutineStrategy,
|
||||
routines::{
|
||||
BlueprintStrategy, GlobalReduceBlueprint, ReduceLineSettings, ReduceProblem, Routine,
|
||||
unit::{UnitRoutine, UnitStrategy},
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ReduceBroadcastedFuseBlock {
|
||||
pub(crate) op: ReduceOperationConfig,
|
||||
pub(crate) input: FuseArg,
|
||||
pub(crate) output: FuseArg,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct FusedReduceBroadcastedLaunch<'a> {
|
||||
blocks: &'a Vec<ReduceBroadcastedFuseBlock>,
|
||||
reduce_axis: usize,
|
||||
// TODO: Support multiple strategies.
|
||||
_strategy: RoutineStrategy,
|
||||
}
|
||||
|
||||
impl<R: Runtime> Vectorization<R> for FusedReduceBroadcastedLaunch<'_> {}
|
||||
|
||||
impl<R: Runtime> TraceRunner<R> for FusedReduceBroadcastedLaunch<'_> {
|
||||
type Error = LaunchError;
|
||||
|
||||
fn run<'a>(
|
||||
&'a self,
|
||||
client: &'a ComputeClient<R>,
|
||||
inputs: GlobalArgsLaunch<'a, R>,
|
||||
outputs: GlobalArgsLaunch<'a, R>,
|
||||
configs: &'a [FuseBlockConfig],
|
||||
) -> Result<(), Self::Error> {
|
||||
let routine = UnitRoutine;
|
||||
let first_config = &configs[0];
|
||||
|
||||
let shape = match &first_config.ref_layout {
|
||||
RefLayout::Concrete(FuseArg::Output(..)) => {
|
||||
outputs.shape_ref(&first_config.ref_layout, first_config.rank)
|
||||
}
|
||||
_ => inputs.shape_ref(&first_config.ref_layout, first_config.rank),
|
||||
};
|
||||
|
||||
let vector_size = shape[self.reduce_axis];
|
||||
let vector_count = shape.iter().product::<usize>() / vector_size;
|
||||
let address_type = inputs
|
||||
.required_address_type()
|
||||
.max(outputs.required_address_type());
|
||||
|
||||
let (blueprint, settings) = routine
|
||||
.prepare::<R>(
|
||||
client,
|
||||
ReduceProblem {
|
||||
vector_size,
|
||||
vector_count,
|
||||
axis: self.reduce_axis,
|
||||
dtypes: ReduceDtypes {
|
||||
input: StorageType::Scalar(ElemType::Float(FloatKind::F32)),
|
||||
output: StorageType::Scalar(ElemType::Float(FloatKind::F32)),
|
||||
accumulation: StorageType::Scalar(ElemType::Float(FloatKind::F32)),
|
||||
},
|
||||
address_type,
|
||||
},
|
||||
ReduceLineSettings {
|
||||
line_mode: LineMode::Parallel,
|
||||
line_size_input: first_config.width,
|
||||
line_size_output: 1,
|
||||
},
|
||||
BlueprintStrategy::Inferred(UnitStrategy),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(blueprint.line_mode, LineMode::Parallel);
|
||||
|
||||
let mut blocks = SequenceArg::new();
|
||||
let mut index = 0;
|
||||
|
||||
for block in self.blocks {
|
||||
let arg = ReduceFuseBlockLaunch::new(
|
||||
block.op,
|
||||
configs[index].clone(),
|
||||
configs[index + 1].clone(),
|
||||
block.input.clone(),
|
||||
block.output.clone(),
|
||||
match blueprint.global {
|
||||
GlobalReduceBlueprint::Unit(bpt) => bpt,
|
||||
_ => panic!(),
|
||||
},
|
||||
);
|
||||
index += 2;
|
||||
blocks.push(arg);
|
||||
}
|
||||
|
||||
let block_end = match configs.len() > index {
|
||||
true => OptionArgs::Some(ElemwiseFuseBlockLaunch::new(
|
||||
configs.last().cloned().unwrap(),
|
||||
)),
|
||||
false => OptionArgs::None,
|
||||
};
|
||||
|
||||
// TODO: Ensure parallel is selected.
|
||||
|
||||
unsafe {
|
||||
reduce_kernel_broadcasted::launch_unchecked::<R>(
|
||||
client,
|
||||
settings.cube_count,
|
||||
settings.cube_dim,
|
||||
settings.address_type,
|
||||
inputs,
|
||||
outputs,
|
||||
ScalarArg::new(self.reduce_axis),
|
||||
blocks,
|
||||
block_end,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
mod fuser;
|
||||
mod optimization;
|
||||
|
||||
pub(crate) mod launch;
|
||||
pub(crate) mod tune;
|
||||
pub(crate) mod unit;
|
||||
|
||||
pub use fuser::*;
|
||||
pub use optimization::*;
|
||||
@@ -0,0 +1,222 @@
|
||||
#[cfg(feature = "autotune")]
|
||||
use crate::optim::reduce::tune::fused_reduce_autotune;
|
||||
use crate::{
|
||||
CubeFusionHandle, FallbackOperation,
|
||||
engine::{
|
||||
launch::FuseTraceLauncher,
|
||||
trace::{FuseTrace, TraceError, TuneOutput},
|
||||
},
|
||||
optim::{
|
||||
elemwise::{ElemwiseOptimization, ElemwiseOptimizationState},
|
||||
reduce::{ReduceOptimizationInfo, ReduceOptimizationState, ReduceOptimizationTuneArg},
|
||||
reduce_broadcasted::{
|
||||
launch::{FusedReduceBroadcastedLaunch, ReduceBroadcastedFuseBlock},
|
||||
tune::fused_broadcasted_reduce_autotune,
|
||||
},
|
||||
},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use cubecl::{Runtime, prelude::*};
|
||||
use cubek::reduce::launch::RoutineStrategy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ReduceBroadcastedOptimization<R: Runtime> {
|
||||
pub(crate) info: Arc<ReduceBroadcastedOptimizationInfo<R>>,
|
||||
pub(crate) num_ops: usize,
|
||||
}
|
||||
|
||||
pub(crate) struct ReduceBroadcastedOptimizationInfo<R: Runtime> {
|
||||
pub(crate) fallbacks: Vec<ReduceBlockOptimInfo<R>>,
|
||||
pub(crate) broadcasted: Arc<ReduceBroadcastedInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub(crate) struct ReduceBroadcastedInfo {
|
||||
pub(crate) blocks: Vec<ReduceBroadcastedFuseBlock>,
|
||||
pub(crate) trace: FuseTrace,
|
||||
pub(crate) reduce_axis: usize,
|
||||
}
|
||||
|
||||
pub(crate) enum ReduceBlockOptimInfo<R: Runtime> {
|
||||
Reduce(Arc<ReduceOptimizationInfo<R>>),
|
||||
Elemwise(Arc<ElemwiseOptimization<R>>),
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReduceBlockOptimInfo<R> {
|
||||
pub fn from_state(device: &R::Device, state: ReduceBlockState) -> Self {
|
||||
match state {
|
||||
ReduceBlockState::Reduce(state) => {
|
||||
Self::Reduce(Arc::new(ReduceOptimizationInfo::from_state(device, state)))
|
||||
}
|
||||
ReduceBlockState::Elemwise(state) => {
|
||||
Self::Elemwise(Arc::new(ElemwiseOptimization::from_state(device, state)))
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn to_state(&self) -> ReduceBlockState {
|
||||
match self {
|
||||
Self::Reduce(info) => ReduceBlockState::Reduce(info.to_state()),
|
||||
Self::Elemwise(info) => ReduceBlockState::Elemwise(info.to_state()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ReduceBroadcastedOptimizationTuneArg<R: Runtime> {
|
||||
pub(crate) fallbacks: Vec<ReduceBlockOptimArg<R>>,
|
||||
pub(crate) broadcasted: Arc<ReduceBroadcastedInfo>,
|
||||
pub(crate) client: ComputeClient<R>,
|
||||
pub(crate) device: R::Device,
|
||||
}
|
||||
|
||||
pub(crate) enum ReduceBlockOptimArg<R: Runtime> {
|
||||
Reduce(ReduceOptimizationTuneArg<R>),
|
||||
Elemwise(Arc<ElemwiseOptimization<R>>),
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReduceBlockOptimArg<R> {
|
||||
pub fn execute_fallback<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
) -> Option<TuneOutput<R>> {
|
||||
match self {
|
||||
ReduceBlockOptimArg::Reduce(reduce) => {
|
||||
#[cfg(feature = "autotune")]
|
||||
{
|
||||
fused_reduce_autotune::<R, BT>(reduce.clone(), context);
|
||||
None
|
||||
}
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
Some(reduce.execute_fallback::<BT>(context))
|
||||
}
|
||||
ReduceBlockOptimArg::Elemwise(elem) => {
|
||||
elem.execute::<BT>(context);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ReduceBroadcastedOptimizationState {
|
||||
fallbacks: Vec<ReduceBlockState>,
|
||||
broadcasted: ReduceBroadcastedInfo,
|
||||
num_ops: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[allow(clippy::large_enum_variant)] // Only for serialization.
|
||||
pub enum ReduceBlockState {
|
||||
Reduce(ReduceOptimizationState),
|
||||
Elemwise(ElemwiseOptimizationState),
|
||||
}
|
||||
|
||||
impl<R: Runtime> ReduceBroadcastedOptimizationTuneArg<R> {
|
||||
pub fn execute_fused<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
strategy: RoutineStrategy,
|
||||
) -> Result<TuneOutput<R>, TraceError<String>> {
|
||||
let launch = FusedReduceBroadcastedLaunch::new(
|
||||
&self.broadcasted.blocks,
|
||||
self.broadcasted.reduce_axis,
|
||||
strategy,
|
||||
);
|
||||
let launcher = FuseTraceLauncher::new(&self.broadcasted.trace, &launch);
|
||||
|
||||
launcher
|
||||
.launch::<BT>(&self.client, &self.device, context)
|
||||
.map_err(|err| TraceError::RunnerError(format!("{:?}", err)))
|
||||
}
|
||||
|
||||
pub fn execute_fallback<BT: CubeElement>(
|
||||
&self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
) {
|
||||
for fallback in self.fallbacks.iter() {
|
||||
fallback.execute_fallback::<BT>(context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl<R: Runtime> ReduceBroadcastedOptimization<R> {
|
||||
/// Execute the optimization.
|
||||
pub fn execute<BT: CubeElement>(
|
||||
&mut self,
|
||||
context: &mut Context<'_, CubeFusionHandle<R>>,
|
||||
fallback: impl Fn(usize) -> Box<dyn FallbackOperation<R>>,
|
||||
) {
|
||||
let mut current_index = 0;
|
||||
let mut client = None;
|
||||
let mut device = None;
|
||||
|
||||
let fallbacks = self
|
||||
.info
|
||||
.fallbacks
|
||||
.iter()
|
||||
.map(|info| {
|
||||
match info {
|
||||
ReduceBlockOptimInfo::Reduce(info) => {
|
||||
// The index of the fallback reduce is the number of ops fused as read.
|
||||
let fallback = fallback(current_index + info.len_read);
|
||||
client = Some(info.client.clone());
|
||||
device = Some(info.device.clone());
|
||||
let arg = ReduceOptimizationTuneArg {
|
||||
info: info.clone(),
|
||||
fallback: Arc::new(fallback),
|
||||
};
|
||||
current_index += info.len;
|
||||
ReduceBlockOptimArg::Reduce(arg)
|
||||
}
|
||||
ReduceBlockOptimInfo::Elemwise(op) => ReduceBlockOptimArg::Elemwise(op.clone()),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let arg = ReduceBroadcastedOptimizationTuneArg {
|
||||
fallbacks,
|
||||
client: client.unwrap(),
|
||||
device: device.unwrap(),
|
||||
broadcasted: self.info.broadcasted.clone(),
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
fused_broadcasted_reduce_autotune::<R, BT>(arg, context);
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
arg.execute_fallback::<BT>(context);
|
||||
}
|
||||
|
||||
pub fn to_state(&self) -> ReduceBroadcastedOptimizationState {
|
||||
ReduceBroadcastedOptimizationState {
|
||||
fallbacks: self
|
||||
.info
|
||||
.fallbacks
|
||||
.iter()
|
||||
.map(|info| info.to_state())
|
||||
.collect(),
|
||||
broadcasted: self.info.broadcasted.as_ref().clone(),
|
||||
num_ops: self.num_ops,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_state(device: &R::Device, state: ReduceBroadcastedOptimizationState) -> Self {
|
||||
Self {
|
||||
info: Arc::new(ReduceBroadcastedOptimizationInfo {
|
||||
fallbacks: state
|
||||
.fallbacks
|
||||
.into_iter()
|
||||
.map(|state| ReduceBlockOptimInfo::from_state(device, state))
|
||||
.collect(),
|
||||
broadcasted: Arc::new(state.broadcasted),
|
||||
}),
|
||||
num_ops: state.num_ops,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of output buffers added by fusion.
|
||||
pub fn num_ops_fused(&self) -> usize {
|
||||
self.num_ops
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
use super::optimization::ReduceBroadcastedOptimizationTuneArg;
|
||||
use crate::{
|
||||
CubeFusionHandle,
|
||||
engine::trace::TuneOutput,
|
||||
optim::{reduce::ReduceOptimizationInfo, reduce_broadcasted::ReduceBlockOptimArg},
|
||||
tune::{TuneContext, TuneInput},
|
||||
};
|
||||
use burn_fusion::stream::Context;
|
||||
use cubecl::{
|
||||
AutotuneKey, CubeElement, CubeTuneId, Runtime,
|
||||
tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},
|
||||
};
|
||||
use cubek::reduce::{
|
||||
launch::{RoutineStrategy, tune_key::ReduceAutotuneKey},
|
||||
routines::{BlueprintStrategy, unit::UnitStrategy},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Autotune key for fused broadcasted reduction operations.
|
||||
///
|
||||
/// Captures the characteristics of the fusion (reads, writes, ops) to ensure
|
||||
/// the best kernel is selected for specific fused graph shapes.
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
pub struct FusedBroadcastedReduceAutotuneKey {
|
||||
reduce_key: ReduceAutotuneKey,
|
||||
#[autotune(anchor)]
|
||||
fuse_num_reads: usize,
|
||||
#[autotune(anchor)]
|
||||
fuse_num_writes: usize,
|
||||
#[autotune(anchor)]
|
||||
fuse_num_ops: usize,
|
||||
fuse_num_blocks: usize,
|
||||
}
|
||||
|
||||
/// Executes the autotuning process for fused reduction operations.
|
||||
///
|
||||
/// This function initializes a local tuner and attempts multiple strategies
|
||||
/// (fallback vs. unit strategy) to find the most efficient execution path.
|
||||
pub fn fused_broadcasted_reduce_autotune<R: Runtime, BT: CubeElement>(
|
||||
arg: ReduceBroadcastedOptimizationTuneArg<R>,
|
||||
context: &mut Context<CubeFusionHandle<R>>,
|
||||
) {
|
||||
static TUNER: LocalTuner<FusedBroadcastedReduceAutotuneKey, CubeTuneId> = local_tuner!();
|
||||
|
||||
let tunables = TUNER.init(|| {
|
||||
const PRIORITY_MAX: i8 = 2;
|
||||
let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);
|
||||
|
||||
let group = TuneGroup::<FusedBroadcastedReduceAutotuneKey>::new(
|
||||
"fused_reduce_broadcasted",
|
||||
|_key| PRIORITY_MAX,
|
||||
);
|
||||
|
||||
// Standard fallback implementation - guaranteed to work.
|
||||
set = set.with(Tunable::new(
|
||||
"fused_reduce_broadcasted_fallback",
|
||||
tune_fallback::<R, BT>,
|
||||
));
|
||||
|
||||
// Specialized unit strategy for fused reductions.
|
||||
set = set.with(
|
||||
Tunable::new("fused_reduce_broadcasted_unit", move |input| {
|
||||
tune_reduce::<R, BT>(
|
||||
input,
|
||||
&RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
|
||||
)
|
||||
})
|
||||
.group(&group, |_| PRIORITY_MAX),
|
||||
);
|
||||
|
||||
set
|
||||
});
|
||||
|
||||
TUNER.execute(
|
||||
&CubeTuneId::new(&arg.client, &arg.device),
|
||||
&arg.client.clone(),
|
||||
tunables,
|
||||
TuneInput::new(context, arg),
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates the autotune key based on the current optimization context and trace blocks.
|
||||
pub(crate) fn create_key<R: Runtime>(
|
||||
input: &TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
|
||||
) -> FusedBroadcastedReduceAutotuneKey {
|
||||
let opt = input.optimization();
|
||||
let context = match input.context() {
|
||||
TuneContext::Original(context) => context,
|
||||
TuneContext::Fork(_) => unreachable!("Forked context not supported for key generation"),
|
||||
};
|
||||
|
||||
// The fusion must start with a reduction block to be valid here.
|
||||
let info = match &opt.fallbacks[0] {
|
||||
ReduceBlockOptimArg::Reduce(reduce) => &reduce.info,
|
||||
ReduceBlockOptimArg::Elemwise(_) => {
|
||||
unreachable!("Fusion must start with a reduction block")
|
||||
}
|
||||
};
|
||||
|
||||
let key = generate_reduce_autotune_key(info, context);
|
||||
|
||||
// Sum up complexity metrics across all blocks in the fused trace.
|
||||
let (mut num_reads, mut num_writes, mut num_ops) = (0, 0, 0);
|
||||
|
||||
for block in opt.broadcasted.trace.blocks.iter() {
|
||||
num_reads += block.reads.len();
|
||||
num_writes += block.writes.len();
|
||||
num_ops += block.ops.len();
|
||||
}
|
||||
|
||||
FusedBroadcastedReduceAutotuneKey::new(
|
||||
key,
|
||||
num_reads,
|
||||
num_writes,
|
||||
num_ops,
|
||||
info.trace.blocks.len(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper to generate the base reduction key (shapes, types, axes).
|
||||
fn generate_reduce_autotune_key<R: Runtime>(
|
||||
info: &ReduceOptimizationInfo<R>,
|
||||
context: &Context<CubeFusionHandle<R>>,
|
||||
) -> ReduceAutotuneKey {
|
||||
let input = context.tensors.get(&info.reduce.op.input.id).unwrap();
|
||||
let out = context.tensors.get(&info.reduce.op.out.id).unwrap();
|
||||
let acc = info.reduce.acc.into_elem();
|
||||
|
||||
ReduceAutotuneKey::generate(
|
||||
input.dtype.into(),
|
||||
out.dtype.into(),
|
||||
acc,
|
||||
&input.shape,
|
||||
info.reduce.axis == input.shape.rank() - 1, // Is it the last dimension?
|
||||
info.reduce.axis,
|
||||
)
|
||||
}
|
||||
|
||||
/// Simple input generator that clones the input for the tuner.
|
||||
fn input_gen<R: Runtime>(
|
||||
_key: &FusedBroadcastedReduceAutotuneKey,
|
||||
input: &TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
|
||||
) -> TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>> {
|
||||
input.clone()
|
||||
}
|
||||
|
||||
/// Executes a fused reduction using a specific routine strategy.
|
||||
fn tune_reduce<R: Runtime, BT: CubeElement>(
|
||||
input: TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
|
||||
strategy: &RoutineStrategy,
|
||||
) -> Result<TuneOutput<R>, String> {
|
||||
let optimization = input.optimization();
|
||||
|
||||
match input.context() {
|
||||
TuneContext::Original(context) => {
|
||||
optimization.execute_fused::<BT>(context, strategy.clone())
|
||||
}
|
||||
TuneContext::Fork(mut context_owned) => {
|
||||
optimization.execute_fused::<BT>(&mut context_owned.as_context(), strategy.clone())
|
||||
}
|
||||
}
|
||||
.map_err(|e| format!("{e:?}"))
|
||||
}
|
||||
|
||||
/// Executes the fallback implementation for the reduction.
|
||||
fn tune_fallback<R: Runtime, BT: CubeElement>(
|
||||
input: TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,
|
||||
) -> Result<TuneOutput<R>, String> {
|
||||
let optimization = input.optimization();
|
||||
|
||||
match input.context() {
|
||||
TuneContext::Original(context) => optimization.execute_fallback::<BT>(context),
|
||||
TuneContext::Fork(mut context_owned) => {
|
||||
optimization.execute_fallback::<BT>(&mut context_owned.as_context())
|
||||
}
|
||||
};
|
||||
|
||||
// Fallback is often used as a baseline, returning unchecked output.
|
||||
Ok(TuneOutput::UnChecked(std::marker::PhantomData))
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
use crate::{
|
||||
engine::codegen::{
|
||||
ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, multi_block_variables_init},
|
||||
kernel::{fuse_on_write, init_locals},
|
||||
},
|
||||
optim::reduce::args::{FusedReduceArgs, FusedReduceInput, FusedReduceOutput},
|
||||
};
|
||||
use cubecl::{Runtime, prelude::*, std::tensor::r#virtual::VirtualTensor};
|
||||
use cubek::reduce::{
|
||||
LineMode, ReduceInstruction, ReducePrecision,
|
||||
components::{
|
||||
global::unit::GlobalFullUnitReduce,
|
||||
instructions::{ReduceOperation, ReduceOperationConfig},
|
||||
},
|
||||
init_tensors,
|
||||
routines::UnitReduceBlueprint,
|
||||
};
|
||||
|
||||
/// A configuration block for a reduction operation within a fused kernel.
|
||||
///
|
||||
/// This struct holds all the compile-time information needed to perform a
|
||||
/// reduction, including the operation type (Sum, Max, etc.) and the layout
|
||||
/// configuration for both input and output.
|
||||
#[derive(CubeType, CubeLaunch, Clone)]
|
||||
pub struct ReduceFuseBlock {
|
||||
#[cube(comptime)]
|
||||
op: ReduceOperationConfig,
|
||||
#[cube(comptime)]
|
||||
config_input: FuseBlockConfig,
|
||||
#[cube(comptime)]
|
||||
config_output: FuseBlockConfig,
|
||||
#[cube(comptime)]
|
||||
input: FuseArg,
|
||||
#[cube(comptime)]
|
||||
output: FuseArg,
|
||||
#[cube(comptime)]
|
||||
blueprint: UnitReduceBlueprint,
|
||||
}
|
||||
|
||||
/// A configuration block for an elementwise operation that follows a reduction.
|
||||
#[derive(CubeType, CubeLaunch, Clone)]
|
||||
pub struct ElemwiseFuseBlock {
|
||||
#[cube(comptime)]
|
||||
config: FuseBlockConfig,
|
||||
}
|
||||
|
||||
/// The entry point for a broadcasted reduction kernel.
|
||||
///
|
||||
/// This kernel initializes local variables for multiple reduction blocks and then
|
||||
/// executes the reduction sequence.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `inputs` - Global arguments containing input tensor handles.
|
||||
/// * `outputs` - Global arguments containing output tensor handles.
|
||||
/// * `reduce_axis` - The dimension along which the reduction is performed.
|
||||
/// * `blocks` - A sequence of reduction operations to execute.
|
||||
/// * `block_end` - An optional elementwise block to execute after reductions are complete.
|
||||
#[cube(launch_unchecked, address_type = "dynamic")]
|
||||
pub fn reduce_kernel_broadcasted(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
reduce_axis: usize,
|
||||
blocks: Sequence<ReduceFuseBlock>,
|
||||
block_end: Option<ElemwiseFuseBlock>,
|
||||
) {
|
||||
#[unroll]
|
||||
for i in 0..blocks.len() {
|
||||
let block = blocks.index(i);
|
||||
multi_block_variables_init(&block.config_input, &mut outputs.variables);
|
||||
multi_block_variables_init(&block.config_output, &mut outputs.variables);
|
||||
}
|
||||
|
||||
reduce_many(inputs, outputs, reduce_axis, blocks, block_end);
|
||||
}
|
||||
|
||||
const REDUCE_INPUT: u8 = 0;
|
||||
const REDUCE_ACC: u8 = 1;
|
||||
const REDUCE_OUT: u8 = 2;
|
||||
|
||||
type In = NumericExpand<REDUCE_INPUT>;
|
||||
type Acc = NumericExpand<REDUCE_ACC>;
|
||||
type Out = NumericExpand<REDUCE_OUT>;
|
||||
|
||||
/// Configures the precision polyfills for the reduction based on the block's `FuseType`.
|
||||
#[cube]
|
||||
fn set_polyfill_block(block: &ReduceFuseBlock) {
|
||||
let input_precision = comptime!(block.input.precision());
|
||||
let output_precision = comptime!(block.output.precision());
|
||||
let acc_precision = comptime!(match input_precision {
|
||||
FuseType::F64 => FuseType::F64,
|
||||
FuseType::F32 => FuseType::F32,
|
||||
FuseType::Flex32 => FuseType::F32,
|
||||
FuseType::F16 => FuseType::F32,
|
||||
FuseType::BF16 => FuseType::F32,
|
||||
FuseType::I64 => FuseType::I64,
|
||||
FuseType::I32 => FuseType::I32,
|
||||
FuseType::I16 => FuseType::I32,
|
||||
FuseType::I8 => FuseType::I32,
|
||||
FuseType::U64 => FuseType::U64,
|
||||
FuseType::U32 => FuseType::U32,
|
||||
FuseType::U16 => FuseType::U32,
|
||||
FuseType::U8 => FuseType::U32,
|
||||
FuseType::Bool => FuseType::I32,
|
||||
});
|
||||
|
||||
set_polyfill::<In>(comptime!(input_precision.into_type()));
|
||||
set_polyfill::<Out>(comptime!(output_precision.into_type()));
|
||||
set_polyfill::<Acc>(comptime!(acc_precision.into_type()));
|
||||
}
|
||||
|
||||
/// Internal logic for executing a sequence of reduction blocks followed by an optional
|
||||
/// trailing elementwise block.
|
||||
#[cube]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
fn reduce_many(
|
||||
inputs: &GlobalArgs,
|
||||
outputs: &mut GlobalArgs,
|
||||
reduce_axis: usize,
|
||||
blocks: Sequence<ReduceFuseBlock>,
|
||||
block_end: Option<ElemwiseFuseBlock>,
|
||||
) {
|
||||
let mut axis_size = 0;
|
||||
|
||||
#[unroll]
|
||||
for i in 0..blocks.len() {
|
||||
let block = blocks.index(i);
|
||||
let input = FusedReduceInput {
|
||||
global: inputs.clone(),
|
||||
config: comptime!(block.config_input.clone()),
|
||||
arg: comptime!(block.input.clone()),
|
||||
};
|
||||
let global = outputs.clone();
|
||||
let config = comptime!(block.config_output.clone());
|
||||
let arg = comptime!(block.output.clone());
|
||||
let mut output = FusedReduceOutput {
|
||||
global,
|
||||
config,
|
||||
arg,
|
||||
};
|
||||
|
||||
set_polyfill_block(block);
|
||||
let (input, mut output) = init_tensors::<FusedReduceArgs, In, Out>(&input, &mut output);
|
||||
|
||||
axis_size = reduce_step::<(In, Acc), Out, ReduceOperation>(
|
||||
&input,
|
||||
&mut output,
|
||||
reduce_axis,
|
||||
block.op,
|
||||
comptime!(block.blueprint.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(block) = block_end {
|
||||
let global_index = ABSOLUTE_POS;
|
||||
let width = comptime!(block.config.width as u32);
|
||||
let num_iter = axis_size / usize::cast_from(width);
|
||||
|
||||
for i in 0..num_iter {
|
||||
// Register block local inputs.
|
||||
let values = Registry::<FuseArg, Line<f32>>::new();
|
||||
let args = comptime![Vec::<FuseArg>::new()];
|
||||
let index = global_index * num_iter + i;
|
||||
let mut locals = init_locals(inputs, outputs, &block.config);
|
||||
|
||||
fuse_on_write::<f32>(
|
||||
inputs,
|
||||
outputs,
|
||||
&mut locals,
|
||||
index,
|
||||
values,
|
||||
args,
|
||||
&block.config.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Executes a single reduction step using a specified instruction and blueprint.
|
||||
///
|
||||
/// Returns the size of the axis that was reduced.
|
||||
fn reduce_step<P: ReducePrecision, Out: Numeric, I: ReduceInstruction<P>>(
|
||||
input: &VirtualTensor<P::EI>,
|
||||
output: &mut VirtualTensor<Out, ReadWrite>,
|
||||
reduce_axis: usize,
|
||||
#[comptime] config: I::Config,
|
||||
#[comptime] blueprint: UnitReduceBlueprint,
|
||||
) -> usize {
|
||||
let inst = I::from_config(config);
|
||||
let axis_size = input.shape(reduce_axis);
|
||||
|
||||
GlobalFullUnitReduce::execute::<P, Out, I>(
|
||||
input,
|
||||
output,
|
||||
reduce_axis,
|
||||
&inst,
|
||||
LineMode::Parallel,
|
||||
comptime!(blueprint),
|
||||
);
|
||||
axis_size
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
use crate::CubeFusionHandle;
|
||||
use burn_fusion::stream::{Context, ContextOwned};
|
||||
use cubecl::Runtime;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Fusion context used when tuning kernels.
|
||||
///
|
||||
/// Either the original context is returned or a fork of the original.
|
||||
/// The fork is only given when performing autotuning, and not when actually performing the
|
||||
/// operation.
|
||||
pub enum TuneContext<'a, R: Runtime> {
|
||||
Original(&'a mut Context<'a, CubeFusionHandle<R>>),
|
||||
Fork(Box<ContextOwned<CubeFusionHandle<R>>>),
|
||||
}
|
||||
|
||||
/// Fusion input wrapper containing the context and the optimization.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions
|
||||
/// are made based on its behavior.
|
||||
pub struct TuneInput<R: Runtime, O> {
|
||||
context: UnsafeTuneContext<R>,
|
||||
optimization: Arc<O>,
|
||||
}
|
||||
|
||||
/// Unsafe wrapper around the context.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The wrapper removes the context lifetime.
|
||||
///
|
||||
/// For it to be correct, the context must not be used after the invocation of the
|
||||
/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are
|
||||
/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find
|
||||
/// the best kernel to use, which can be async.
|
||||
enum UnsafeTuneContext<R: Runtime> {
|
||||
Original(*mut Context<'static, CubeFusionHandle<R>>),
|
||||
Fork(Box<ContextOwned<CubeFusionHandle<R>>>),
|
||||
}
|
||||
|
||||
unsafe impl<R: Runtime> Send for UnsafeTuneContext<R> {}
|
||||
unsafe impl<R: Runtime, O> Send for TuneInput<R, O> {}
|
||||
|
||||
impl<R: Runtime, O> TuneInput<R, O> {
|
||||
/// Create a new autotune input from the [context](Context) and an optimization.
|
||||
pub fn new(context: &mut Context<CubeFusionHandle<R>>, optimization: O) -> Self {
|
||||
let context = UnsafeTuneContext::new(context);
|
||||
|
||||
Self {
|
||||
context,
|
||||
optimization: Arc::new(optimization),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve the [autotune context](TuneContext) for the current input.
|
||||
pub fn context(&self) -> TuneContext<'static, R> {
|
||||
self.context.get()
|
||||
}
|
||||
|
||||
/// Retrieve the optimization for the current input.
|
||||
pub fn optimization(&self) -> &O {
|
||||
&self.optimization
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> UnsafeTuneContext<R> {
|
||||
fn new(context: &mut Context<'_, CubeFusionHandle<R>>) -> Self {
|
||||
let ptr = core::ptr::from_mut(context);
|
||||
|
||||
// It is necessary for the lifetime.
|
||||
#[allow(clippy::unnecessary_cast)]
|
||||
Self::Original(ptr as *mut Context<'static, _>)
|
||||
}
|
||||
|
||||
fn get(&self) -> TuneContext<'static, R> {
|
||||
match self {
|
||||
UnsafeTuneContext::Original(ptr) => {
|
||||
TuneContext::Original(unsafe { ptr.as_mut().unwrap() })
|
||||
}
|
||||
UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, O> Clone for TuneInput<R, O> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
context: self.context.clone(),
|
||||
optimization: self.optimization.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for UnsafeTuneContext<R> {
|
||||
fn clone(&self) -> Self {
|
||||
let context = match self {
|
||||
UnsafeTuneContext::Original(ptr) => {
|
||||
let context: &mut Context<'static, CubeFusionHandle<R>> =
|
||||
unsafe { ptr.as_mut().unwrap() };
|
||||
context.fork()
|
||||
}
|
||||
UnsafeTuneContext::Fork(context) => context.fork(),
|
||||
};
|
||||
UnsafeTuneContext::Fork(Box::new(context))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user