use burn_backend::{
    Backend, Shape,
    tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
};

/// A tensor representation containing a reference to a tensor resource with a given shape.
///
/// `H` is the handle type. The [`BackendIr`] conversion methods use it as
/// `TensorHandle<Self::Handle>`.
#[derive(Clone)]
pub struct TensorHandle<H> {
    /// The handle pointing to the tensor resource.
    pub handle: H,
    /// The shape associated with the tensor.
    pub shape: Shape,
}

/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor
/// intermediate representation, for compilation purposes or other uses.
///
/// The `*_tensor` methods build a backend tensor primitive from a [`TensorHandle`]
/// (handle plus shape). The `*_tensor_handle` methods go the other way, but they return only the
/// bare [handle](BackendIr::Handle). To convert back, the shape has to be tracked separately and
/// paired with the handle in a [`TensorHandle`].
pub trait BackendIr: Backend {
    /// The type that can be used to point to a tensor of any kind.
    type Handle: Sync + Send + Clone;

    /// Convert a [handle](BackendIr::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
    /// Convert a [handle](BackendIr::Handle) to an [int tensor](Backend::IntTensorPrimitive).
    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
    /// Convert a [handle](BackendIr::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
    /// Convert a [handle](BackendIr::Handle) to a
    /// [quantized tensor](Backend::QuantizedTensorPrimitive).
    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;

    /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](BackendIr::Handle).
    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
    /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](BackendIr::Handle).
    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
    /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](BackendIr::Handle).
    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
    /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a
    /// [handle](BackendIr::Handle).
    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;
}

/// Handle which points to a backend tensor primitive of a given kind.
#[derive(Clone, Debug)] pub enum HandleKind { /// Float tensor handle. Float(B::FloatTensorPrimitive), /// Int tensor handle. Int(B::IntTensorPrimitive), /// Bool tensor handle. Bool(B::BoolTensorPrimitive), /// Quantized tensor handle. Quantized(B::QuantizedTensorPrimitive), } impl HandleKind { /// Returns the handle kind name. pub fn name(&self) -> &str { match self { HandleKind::Float(_) => "float", HandleKind::Int(_) => "int", HandleKind::Bool(_) => "bool", HandleKind::Quantized(_) => "quantized", } } }