Files
RustyUI/crates/stable-diffusion-burn/burn-crates/burn-dispatch/src/tensor.rs
Ben_Kosytorz 3a67c0979c feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
2026-03-05 19:39:14 +01:00

275 lines
10 KiB
Rust

use burn_backend::{Backend, QTensorPrimitive, TensorMetadata};
use crate::backends::*;
#[cfg(feature = "autodiff")]
use burn_backend::tensor::FloatTensor;
// TODO: collapsing the separate float/int/bool/quantized associated primitive types into a
// single `B::TensorPrimitive` would let this enum (and its accessors) be simplified.
/// A handle to one of backend `B`'s tensor primitives, tagged by the primitive kind.
#[derive(Debug, Clone)]
pub enum BackendTensor<B: Backend> {
    /// Wraps a `B::FloatTensorPrimitive`.
    Float(B::FloatTensorPrimitive),
    /// Wraps a `B::IntTensorPrimitive`.
    Int(B::IntTensorPrimitive),
    /// Wraps a `B::BoolTensorPrimitive`.
    Bool(B::BoolTensorPrimitive),
    /// Wraps a `B::QuantizedTensorPrimitive`.
    Quantized(B::QuantizedTensorPrimitive),
    /// Wraps a float tensor of the `Autodiff<B>` backend (only with the `autodiff` feature).
    #[cfg(feature = "autodiff")]
    Autodiff(FloatTensor<Autodiff<B>>),
}
impl<B: Backend> BackendTensor<B> {
    /// Short lowercase name of the held variant, used in the accessor panic messages.
    ///
    /// The match is exhaustive on purpose: adding a variant forces this to be updated.
    fn variant_name(&self) -> &'static str {
        match self {
            BackendTensor::Float(_) => "float",
            BackendTensor::Int(_) => "int",
            BackendTensor::Bool(_) => "bool",
            BackendTensor::Quantized(_) => "quantized",
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => "autodiff",
        }
    }

    /// Returns the inner float tensor primitive.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not `BackendTensor::Float`. Because of `#[track_caller]`, the
    /// reported panic location is the caller's, not this accessor's.
    #[track_caller]
    pub(crate) fn float(self) -> B::FloatTensorPrimitive {
        match self {
            BackendTensor::Float(tensor) => tensor,
            other => panic!("Should be float, got {}", other.variant_name()),
        }
    }

    /// Returns a reference to the inner float tensor primitive.
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Float`.
    #[track_caller]
    pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
        match self {
            BackendTensor::Float(tensor) => tensor,
            other => panic!("Should be float, got {}", other.variant_name()),
        }
    }

    /// Returns the inner int tensor primitive.
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Int`.
    #[track_caller]
    pub(crate) fn int(self) -> B::IntTensorPrimitive {
        match self {
            BackendTensor::Int(tensor) => tensor,
            other => panic!("Should be int, got {}", other.variant_name()),
        }
    }

    /// Returns the inner bool tensor primitive.
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Bool`.
    #[track_caller]
    pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
        match self {
            BackendTensor::Bool(tensor) => tensor,
            other => panic!("Should be bool, got {}", other.variant_name()),
        }
    }

    /// Returns the inner quantized tensor primitive.
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Quantized`.
    #[track_caller]
    pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
        match self {
            BackendTensor::Quantized(tensor) => tensor,
            other => panic!("Should be quantized, got {}", other.variant_name()),
        }
    }

    /// Returns the inner autodiff float tensor.
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Autodiff`.
    /// This path is reachable in practice, e.g. when a non-autodiff tensor is routed to an
    /// autodiff op, so it reports the actual variant instead of using `unreachable!()`.
    #[cfg(feature = "autodiff")]
    #[track_caller]
    pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
        match self {
            BackendTensor::Autodiff(tensor) => tensor,
            other => panic!("Should be autodiff, got {}", other.variant_name()),
        }
    }

    /// Returns a reference to the inner autodiff float tensor.
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Autodiff`.
    #[cfg(feature = "autodiff")]
    #[track_caller]
    pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
        match self {
            BackendTensor::Autodiff(tensor) => tensor,
            other => panic!("Should be autodiff, got {}", other.variant_name()),
        }
    }

    /// Returns the backend float primitive stored inside the autodiff tensor
    /// (its `primitive` field).
    ///
    /// # Panics
    ///
    /// Panics (at the caller's location) if `self` is not `BackendTensor::Autodiff`.
    #[cfg(feature = "autodiff")]
    #[track_caller]
    pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
        match self {
            BackendTensor::Autodiff(tensor) => tensor.primitive,
            other => panic!("Should be autodiff, got {}", other.variant_name()),
        }
    }

    /// Returns the backend device the tensor lives on, whatever its kind.
    ///
    /// Autodiff tensors report the device of their inner float primitive.
    pub(crate) fn device(&self) -> B::Device {
        match self {
            BackendTensor::Float(tensor) => B::float_device(tensor),
            BackendTensor::Int(tensor) => B::int_device(tensor),
            BackendTensor::Bool(tensor) => B::bool_device(tensor),
            BackendTensor::Quantized(tensor) => B::q_device(tensor),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
        }
    }
}
impl<B: Backend> TensorMetadata for BackendTensor<B> {
    /// Data type of the wrapped primitive, delegated to that primitive's own metadata.
    fn dtype(&self) -> burn_std::DType {
        match self {
            Self::Float(inner) => inner.dtype(),
            Self::Int(inner) => inner.dtype(),
            Self::Bool(inner) => inner.dtype(),
            Self::Quantized(inner) => inner.dtype(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(inner) => inner.dtype(),
        }
    }

    /// Shape of the wrapped primitive, delegated to that primitive's own metadata.
    fn shape(&self) -> burn_std::Shape {
        match self {
            Self::Float(inner) => inner.shape(),
            Self::Int(inner) => inner.shape(),
            Self::Bool(inner) => inner.shape(),
            Self::Quantized(inner) => inner.shape(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(inner) => inner.shape(),
        }
    }
}
impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
    /// Quantization scheme of the wrapped quantized primitive.
    ///
    /// # Panics
    ///
    /// Panics for every non-quantized variant. The message includes the tensor's dtype.
    fn scheme(&self) -> &burn_std::QuantScheme {
        if let Self::Quantized(quantized) = self {
            return quantized.scheme();
        }
        panic!(
            "Quantization scheme is not valid for dtype {:?}",
            self.dtype()
        )
    }
}
/// A tensor belonging to any one of the backends enabled at compile time.
///
/// Each variant wraps the [`BackendTensor`] of a specific backend. This lets the backend
/// that executes an operation be chosen at runtime.
#[derive(Debug, Clone)]
pub enum DispatchTensor {
    #[cfg(feature = "cpu")]
    /// Tensor owned by the [CPU backend](Cpu).
    Cpu(BackendTensor<Cpu>),
    #[cfg(feature = "cuda")]
    /// Tensor owned by the [CUDA backend](Cuda).
    Cuda(BackendTensor<Cuda>),
    #[cfg(wgpu_metal)]
    /// Tensor owned by the [Metal backend](Metal).
    Metal(BackendTensor<Metal>),
    #[cfg(feature = "rocm")]
    /// Tensor owned by the [ROCm backend](Rocm).
    Rocm(BackendTensor<Rocm>),
    #[cfg(wgpu_vulkan)]
    /// Tensor owned by the [Vulkan backend](Vulkan).
    Vulkan(BackendTensor<Vulkan>),
    #[cfg(wgpu_webgpu)]
    /// Tensor owned by the [WebGPU backend](WebGpu).
    WebGpu(BackendTensor<WebGpu>),
    #[cfg(feature = "ndarray")]
    /// Tensor owned by the [NdArray backend](NdArray).
    NdArray(BackendTensor<NdArray>),
    #[cfg(feature = "tch")]
    /// Tensor owned by the [LibTorch backend](LibTorch).
    LibTorch(BackendTensor<LibTorch>),
    #[cfg(feature = "autodiff")]
    /// Tensor of the [autodiff-enabled backend](Autodiff), boxing the dispatch tensor it wraps.
    Autodiff(Box<DispatchTensor>),
}
impl TensorMetadata for DispatchTensor {
    /// Data type of the wrapped tensor, forwarded to whichever backend variant is held.
    fn dtype(&self) -> burn_std::DType {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(inner) => inner.dtype(),
            #[cfg(feature = "cuda")]
            Self::Cuda(inner) => inner.dtype(),
            #[cfg(wgpu_metal)]
            Self::Metal(inner) => inner.dtype(),
            #[cfg(feature = "rocm")]
            Self::Rocm(inner) => inner.dtype(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(inner) => inner.dtype(),
            #[cfg(wgpu_webgpu)]
            Self::WebGpu(inner) => inner.dtype(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(inner) => inner.dtype(),
            #[cfg(feature = "tch")]
            Self::LibTorch(inner) => inner.dtype(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(boxed) => boxed.dtype(),
        }
    }

    /// Shape of the wrapped tensor, forwarded to whichever backend variant is held.
    fn shape(&self) -> burn_std::Shape {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(inner) => inner.shape(),
            #[cfg(feature = "cuda")]
            Self::Cuda(inner) => inner.shape(),
            #[cfg(wgpu_metal)]
            Self::Metal(inner) => inner.shape(),
            #[cfg(feature = "rocm")]
            Self::Rocm(inner) => inner.shape(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(inner) => inner.shape(),
            #[cfg(wgpu_webgpu)]
            Self::WebGpu(inner) => inner.shape(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(inner) => inner.shape(),
            #[cfg(feature = "tch")]
            Self::LibTorch(inner) => inner.shape(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(boxed) => boxed.shape(),
        }
    }
}
impl QTensorPrimitive for DispatchTensor {
    /// Quantization scheme of the wrapped tensor, forwarded to the held backend variant.
    ///
    /// # Panics
    ///
    /// Propagates the panic raised by `BackendTensor::scheme` when the wrapped tensor is
    /// not quantized.
    fn scheme(&self) -> &burn_std::QuantScheme {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(inner) => inner.scheme(),
            #[cfg(feature = "cuda")]
            Self::Cuda(inner) => inner.scheme(),
            #[cfg(wgpu_metal)]
            Self::Metal(inner) => inner.scheme(),
            #[cfg(feature = "rocm")]
            Self::Rocm(inner) => inner.scheme(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(inner) => inner.scheme(),
            #[cfg(wgpu_webgpu)]
            Self::WebGpu(inner) => inner.scheme(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(inner) => inner.scheme(),
            #[cfg(feature = "tch")]
            Self::LibTorch(inner) => inner.scheme(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(boxed) => boxed.scheme(),
        }
    }
}