use core::sync::atomic::{AtomicU32, Ordering}; use alloc::format; use alloc::{sync::Arc, vec::Vec}; use super::RunnerClient; use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError}; use burn_ir::{TensorId, TensorIr, TensorStatus}; /// Tensor primitive for the [router backend](crate::BackendRouter). pub struct RouterTensor { pub(crate) id: TensorId, pub(crate) shape: Shape, pub(crate) dtype: DType, /// The client that has this tensor pub client: C, pub(crate) count: Arc, } impl TensorMetadata for RouterTensor { fn dtype(&self) -> DType { self.dtype } fn shape(&self) -> Shape { self.shape.clone() } fn rank(&self) -> usize { self.shape.num_dims() } } impl RouterTensor { /// Create a new router tensor. pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self { Self { id, shape, dtype, client, count: Arc::new(AtomicU32::new(1)), } } pub(crate) async fn into_data(self) -> Result { self.client.clone().read_tensor_async(self.into_ir()).await } /// Get the ir for this tensor pub fn into_ir(mut self) -> TensorIr { let count = self.count.load(Ordering::Relaxed); let status = self.status(count); let mut shape_out = Shape::from(Vec::::new()); core::mem::swap(&mut self.shape, &mut shape_out); if let TensorStatus::ReadWrite = status { // Avoids an unwanted drop on the same thread. // // Since `drop` is called after `into_ir`, we must not register a drop if the tensor // was consumed with a `ReadWrite` status. self.count.fetch_add(1, Ordering::Relaxed); } TensorIr { status, shape: shape_out, id: self.id, dtype: self.dtype, } } pub(crate) fn to_ir_out(&self) -> TensorIr { TensorIr { status: TensorStatus::NotInit, shape: self.shape.clone(), id: self.id, dtype: self.dtype, } } pub(crate) fn status(&self, count: u32) -> TensorStatus { if count <= 1 { TensorStatus::ReadWrite } else { TensorStatus::ReadOnly } } } impl core::fmt::Debug for RouterTensor { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str( format!( "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}", self.id, self.shape, self.dtype, self.client.device().clone(), ) .as_str(), ) } } impl Clone for RouterTensor { fn clone(&self) -> Self { self.count.fetch_add(1, Ordering::Relaxed); Self { id: self.id, shape: self.shape.clone(), client: self.client.clone(), dtype: self.dtype, count: self.count.clone(), } } } impl Drop for RouterTensor { fn drop(&mut self) { let count = self.count.fetch_sub(1, Ordering::Relaxed); match self.status(count) { TensorStatus::ReadWrite => { let id = self.id; let mut shape = Shape::from(Vec::::new()); core::mem::swap(&mut shape, &mut self.shape); let ir = TensorIr { id, shape, status: TensorStatus::ReadWrite, dtype: self.dtype, }; self.client.register_op(burn_ir::OperationIr::Drop(ir)); } TensorStatus::ReadOnly => {} TensorStatus::NotInit => {} } } }