- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
143 lines
3.9 KiB
Rust
143 lines
3.9 KiB
Rust
use core::sync::atomic::{AtomicU32, Ordering};
|
|
|
|
use alloc::format;
|
|
use alloc::{sync::Arc, vec::Vec};
|
|
|
|
use super::RunnerClient;
|
|
use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};
|
|
use burn_ir::{TensorId, TensorIr, TensorStatus};
|
|
|
|
/// Tensor primitive for the [router backend](crate::BackendRouter).
|
|
pub struct RouterTensor<C: RunnerClient> {
|
|
pub(crate) id: TensorId,
|
|
pub(crate) shape: Shape,
|
|
pub(crate) dtype: DType,
|
|
/// The client that has this tensor
|
|
pub client: C,
|
|
pub(crate) count: Arc<AtomicU32>,
|
|
}
|
|
|
|
impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
|
|
fn dtype(&self) -> DType {
|
|
self.dtype
|
|
}
|
|
|
|
fn shape(&self) -> Shape {
|
|
self.shape.clone()
|
|
}
|
|
|
|
fn rank(&self) -> usize {
|
|
self.shape.num_dims()
|
|
}
|
|
}
|
|
|
|
impl<C: RunnerClient> RouterTensor<C> {
|
|
/// Create a new router tensor.
|
|
pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {
|
|
Self {
|
|
id,
|
|
shape,
|
|
dtype,
|
|
client,
|
|
count: Arc::new(AtomicU32::new(1)),
|
|
}
|
|
}
|
|
|
|
pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {
|
|
self.client.clone().read_tensor_async(self.into_ir()).await
|
|
}
|
|
|
|
/// Get the ir for this tensor
|
|
pub fn into_ir(mut self) -> TensorIr {
|
|
let count = self.count.load(Ordering::Relaxed);
|
|
let status = self.status(count);
|
|
let mut shape_out = Shape::from(Vec::<usize>::new());
|
|
core::mem::swap(&mut self.shape, &mut shape_out);
|
|
|
|
if let TensorStatus::ReadWrite = status {
|
|
// Avoids an unwanted drop on the same thread.
|
|
//
|
|
// Since `drop` is called after `into_ir`, we must not register a drop if the tensor
|
|
// was consumed with a `ReadWrite` status.
|
|
self.count.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
|
|
TensorIr {
|
|
status,
|
|
shape: shape_out,
|
|
id: self.id,
|
|
dtype: self.dtype,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn to_ir_out(&self) -> TensorIr {
|
|
TensorIr {
|
|
status: TensorStatus::NotInit,
|
|
shape: self.shape.clone(),
|
|
id: self.id,
|
|
dtype: self.dtype,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn status(&self, count: u32) -> TensorStatus {
|
|
if count <= 1 {
|
|
TensorStatus::ReadWrite
|
|
} else {
|
|
TensorStatus::ReadOnly
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
    /// Write the tensor as `{ id: .., shape: .., dtype: .., device: .. }`,
    /// where the device is the one reported by the client.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `{:?}` only borrows its arguments, so the device returned by the
        // client is formatted directly instead of being cloned first.
        let rendered = format!(
            "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}",
            self.id,
            self.shape,
            self.dtype,
            self.client.device(),
        );

        f.write_str(&rendered)
    }
}
|
|
|
|
impl<C: RunnerClient> Clone for RouterTensor<C> {
|
|
fn clone(&self) -> Self {
|
|
self.count.fetch_add(1, Ordering::Relaxed);
|
|
|
|
Self {
|
|
id: self.id,
|
|
shape: self.shape.clone(),
|
|
client: self.client.clone(),
|
|
dtype: self.dtype,
|
|
count: self.count.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<C: RunnerClient> Drop for RouterTensor<C> {
|
|
fn drop(&mut self) {
|
|
let count = self.count.fetch_sub(1, Ordering::Relaxed);
|
|
|
|
match self.status(count) {
|
|
TensorStatus::ReadWrite => {
|
|
let id = self.id;
|
|
let mut shape = Shape::from(Vec::<usize>::new());
|
|
core::mem::swap(&mut shape, &mut self.shape);
|
|
|
|
let ir = TensorIr {
|
|
id,
|
|
shape,
|
|
status: TensorStatus::ReadWrite,
|
|
dtype: self.dtype,
|
|
};
|
|
self.client.register_op(burn_ir::OperationIr::Drop(ir));
|
|
}
|
|
TensorStatus::ReadOnly => {}
|
|
TensorStatus::NotInit => {}
|
|
}
|
|
}
|
|
}
|