use crate::BoolElement; use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor}; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; use burn_backend::{DType, Shape}; use burn_cubecl_fusion::optim::reduce::ReduceSettings; use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser; use burn_cubecl_fusion::{ CubeFusionHandle, FallbackOperation, optim::{ CubeOptimization, CubeOptimizationState, elemwise::{ElementWiseFuser, ElemwiseOptimization}, matmul::{MatmulFuser, MatmulOptimization}, reduce::{ReduceFuser, ReduceOptimization}, reduce_broadcasted::ReduceBroadcastedOptimization, }, }; use burn_fusion::{ FusionBackend, FusionRuntime, stream::{Operation, OrderedExecution}, }; use burn_ir::{BackendIr, TensorHandle}; use burn_std::Metadata; use core::marker::PhantomData; use std::sync::Arc; impl burn_fusion::Optimization> for CubeOptimization where R: CubeRuntime, BT: BoolElement, { fn execute( &mut self, context: &mut burn_fusion::stream::Context< '_, as FusionRuntime>::FusionHandle, >, execution: &OrderedExecution>, ) { match self { Self::ElementWise(op) => op.execute::(context), Self::Matmul(op) => op.execute::(context, |index| { let operation = execution.operation_within_optimization(index); Box::new(FallbackOperationWrapper::new(operation)) }), Self::Reduce(op) => op.execute::(context, |index| { let operation = execution.operation_within_optimization(index); Box::new(FallbackOperationWrapper::new(operation)) }), Self::ReduceBroadcasted(op) => op.execute::(context, |index| { let operation = execution.operation_within_optimization(index); Box::new(FallbackOperationWrapper::new(operation)) }), } } fn to_state(&self) -> CubeOptimizationState { self.to_opt_state() } fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self { match state { CubeOptimizationState::ElementWise(state) => { Self::ElementWise(ElemwiseOptimization::from_state(device, state)) } CubeOptimizationState::Matmul(state) => { Self::Matmul(MatmulOptimization::from_state(device, state)) } CubeOptimizationState::Reduce(state) => { Self::Reduce(ReduceOptimization::from_state(device, state)) } CubeOptimizationState::ReduceBroadcasted(state) => { Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state)) } } } } struct FallbackOperationWrapper { operation: O, } impl FallbackOperationWrapper { fn new(op: O) -> Self { Self { operation: op } } } impl FallbackOperation for FallbackOperationWrapper>>> { fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle>) { self.operation.as_ref().execute(context.handles); } } impl BackendIr for CubeBackend { type Handle = CubeFusionHandle; fn float_tensor(handle: TensorHandle) -> FloatTensor { into_tensor(handle.handle, handle.shape) } fn int_tensor(handle: TensorHandle) -> IntTensor { into_tensor(handle.handle, handle.shape) } fn bool_tensor(handle: TensorHandle) -> BoolTensor { into_tensor(handle.handle, handle.shape) } fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { into_tensor(handle.handle, handle.shape) } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { tensor.into() } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { tensor.into() } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { tensor.into() } fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { tensor.into() } } impl FusionRuntime for FusionCubeRuntime { type OptimizationState = CubeOptimizationState; type Optimization = CubeOptimization; type FusionHandle = CubeFusionHandle; type FusionDevice = R::CubeDevice; type BoolRepr = BT; fn fusers(device: R::Device) -> Vec>> { vec![ Box::new(ElementWiseFuser::new( device.clone(), BT::as_type_native_unchecked().into(), )), Box::new(MatmulFuser::new( device.clone(), BT::as_type_native_unchecked().into(), )), Box::new(ReduceFuser::new( device.clone(), BT::as_type_native_unchecked().into(), ReduceSettings::Always, )), Box::new(ReduceBroadcastedFuser::new( device.clone(), BT::as_type_native_unchecked().into(), )), ] } } /// Fusion runtime for JIT runtimes. #[derive(Debug)] pub struct FusionCubeRuntime { _b: PhantomData, _bool: PhantomData, } impl FusionBackend for CubeBackend { type FusionRuntime = FusionCubeRuntime; type FullPrecisionBackend = CubeBackend; fn cast_float(tensor: FloatTensor, dtype: DType) -> Self::Handle { kernel::cast(tensor, dtype).into() } } fn into_tensor(handle: CubeFusionHandle, shape: Shape) -> CubeTensor { CubeTensor { client: handle.client, handle: handle.handle, device: handle.device, meta: Box::new(Metadata::new(shape, handle.strides)), dtype: handle.dtype, qparams: handle.qparams, } } impl From> for CubeFusionHandle { fn from(value: CubeTensor) -> Self { Self { client: value.client, handle: value.handle, device: value.device, strides: value.meta.strides, dtype: value.dtype, qparams: value.qparams, } } }