feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,84 @@
[package]
authors = [
"laggui <lagrange.guillaume.1@gmail.com>",
"nathanielsimard <nathaniel.simard.42@gmail.com>",
]
categories = ["science"]
description = "Backend dispatch for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-dispatch"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dispatch"
documentation = "https://docs.rs/burn-dispatch"
version.workspace = true
[lints]
workspace = true
[features]
default = [
"std",
"ndarray",
"burn-autodiff?/default",
"burn-cpu?/default",
"burn-cuda?/default",
"burn-ndarray?/default",
"burn-rocm?/default",
"burn-tch?/default",
"burn-wgpu?/default",
]
doc = ["default"]
std = [
"burn-backend/std",
"burn-std/std",
"burn-autodiff?/std",
"burn-cpu?/std",
"burn-cuda?/std",
"burn-ndarray?/std",
"burn-rocm?/std",
"burn-tch?/std",
"burn-wgpu?/std",
]
tracing = [
"burn-autodiff?/tracing",
"burn-cpu?/tracing",
"burn-cuda?/tracing",
"burn-ndarray?/tracing",
"burn-rocm?/tracing",
"burn-tch?/tracing",
"burn-wgpu?/tracing",
]
# Backends
cuda = ["burn-cuda"]
rocm = ["burn-rocm"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]
metal = ["wgpu", "burn-wgpu/metal"]
wgpu = ["burn-wgpu"]
cpu = ["burn-cpu"]
autodiff = ["burn-autodiff"]
[dependencies]
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
# Backends
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
# Op macros with `.as_$inner_kind()`
paste = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1,3 @@
# Burn Backend Dispatch
A multi-backend dispatch that forwards the tensor operations to the appropriate backend.

View File

@@ -0,0 +1,35 @@
fn main() {
println!("cargo::rustc-check-cfg=cfg(wgpu_metal)");
println!("cargo::rustc-check-cfg=cfg(wgpu_vulkan)");
println!("cargo::rustc-check-cfg=cfg(wgpu_webgpu)");
// Detect which single wgpu backend is enabled
let metal = cfg!(feature = "metal");
let vulkan = cfg!(feature = "vulkan");
let webgpu = cfg!(feature = "webgpu");
let enabled = [(metal, "metal"), (vulkan, "vulkan"), (webgpu, "webgpu")]
.iter()
.filter(|x| x.0)
.map(|x| x.1)
.collect::<Vec<_>>();
// WGPU features are mutually exclusive, but we don't want to workspace to throw a compile error.
// In workspace builds with multiple features, we emit a warning and disable all WGPU backends.
if enabled.len() > 1 {
println!(
"cargo:warning=Only one WGPU backend can be enabled at once. Detected: [{}]. No WGPU backend will be available in this build. This is expected in workspace builds. For production, enable only one of: metal, vulkan, or webgpu.",
enabled.join(", ")
);
return;
}
if metal {
println!("cargo:rustc-cfg=wgpu_metal");
}
if vulkan {
println!("cargo:rustc-cfg=wgpu_vulkan");
}
if webgpu {
println!("cargo:rustc-cfg=wgpu_webgpu");
}
}

View File

@@ -0,0 +1,392 @@
use alloc::format;
use alloc::string::String;
use burn_backend::Backend;
use burn_backend::ExecutionError;
use burn_std::DType;
#[cfg(feature = "autodiff")]
use burn_autodiff::grads::Gradients;
#[cfg(feature = "autodiff")]
use burn_backend::AutodiffBackend;
use crate::backends::*;
use crate::{DispatchDevice, DispatchTensor};
/// The main execution backend in Burn.
///
/// [`Dispatch`] acts as a global backend that can manage multiple underlying
/// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.).
/// It is responsible for:
/// - Dispatching tensor operations to the appropriate backend.
/// - Managing cross-backend tensor transfers.
///
/// Essentially, [`Dispatch`] is the single entry point for executing tensor operations
/// in a backend-agnostic way. It allows Burn to provide a unified, global backend
/// for users while still leveraging multiple specialized backends under the hood.
///
/// # Example
///
/// ```ignore
/// use burn::Dispatch;
/// use burn::DispatchDevice;
///
/// // Select the device to execute operations on
/// let device = DispatchDevice::Cuda(Default::default());
///
/// // Create a tensor using the global backend
/// let t = Tensor::<Dispatch, 2>::zeros([128, 128], &device);
/// ```
#[derive(Debug, Default, Clone)]
pub struct Dispatch;
impl Backend for Dispatch {
type Device = DispatchDevice;
type FloatTensorPrimitive = DispatchTensor;
// TODO: either allow default dtype generic or remove associated types entirely?
type FloatElem = f32;
type IntTensorPrimitive = DispatchTensor;
type IntElem = i32;
type BoolTensorPrimitive = DispatchTensor;
type BoolElem = u8;
type QuantizedTensorPrimitive = DispatchTensor;
fn name(device: &Self::Device) -> String {
let inner = dispatch_device!(device, |device| B::name(device));
format!("dispatch<{inner}>")
}
fn seed(device: &Self::Device, seed: u64) {
dispatch_device!(device, |device| B::seed(device, seed))
}
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
dispatch_device!(device, |device| B::sync(device))
}
fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
dispatch_device!(device, |device| B::dtype_usage(device, dtype))
}
fn ad_enabled(device: &Self::Device) -> bool {
match device {
#[cfg(feature = "autodiff")]
DispatchDevice::Autodiff(_) => true,
_ => false,
}
}
}
#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
type InnerBackend = Dispatch;
type Gradients = Gradients;
fn backward(tensor: DispatchTensor) -> Self::Gradients {
match tensor {
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => match *tensor {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => tensor.autodiff().backward(),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => tensor.autodiff().backward(),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => tensor.autodiff().backward(),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => tensor.autodiff().backward(),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => tensor.autodiff().backward(),
DispatchTensor::Autodiff(_) => {
panic!("Autodiff should not wrap an autodiff tensor.")
}
},
_ => panic!("Requires autodiff tensor."),
}
}
fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
match &tensor {
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => match &**tensor {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => tensor
.as_autodiff()
.grad(grads)
.map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
DispatchTensor::Autodiff(_) => {
panic!("Autodiff should not wrap an autodiff tensor.")
}
},
_ => panic!("Requires autodiff tensor."),
}
}
fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
match &tensor {
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => match &**tensor {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => tensor
.as_autodiff()
.grad_remove(grads)
.map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
DispatchTensor::Autodiff(_) => {
panic!("Autodiff should not wrap an autodiff tensor.")
}
},
_ => panic!("Requires autodiff tensor."),
}
}
fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
match &tensor {
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => match (&**tensor, grad) {
#[cfg(feature = "cpu")]
(DispatchTensor::Cpu(tensor), DispatchTensor::Cpu(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "cuda")]
(DispatchTensor::Cuda(tensor), DispatchTensor::Cuda(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(wgpu_metal)]
(DispatchTensor::Metal(tensor), DispatchTensor::Metal(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "rocm")]
(DispatchTensor::Rocm(tensor), DispatchTensor::Rocm(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(wgpu_vulkan)]
(DispatchTensor::Vulkan(tensor), DispatchTensor::Vulkan(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(wgpu_webgpu)]
(DispatchTensor::WebGpu(tensor), DispatchTensor::WebGpu(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
#[cfg(feature = "ndarray")]
(DispatchTensor::NdArray(tensor), DispatchTensor::NdArray(grad)) => {
tensor.as_autodiff().grad_replace(grads, grad.float())
}
(DispatchTensor::Autodiff(_), _) => {
panic!("Autodiff should not wrap an autodiff tensor.")
}
(t, g) => panic!(
"The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
),
},
_ => panic!("Requires autodiff tensor."),
}
}
fn inner(tensor: DispatchTensor) -> DispatchTensor {
match tensor {
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => match *tensor {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => {
DispatchTensor::Cpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
}
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => {
DispatchTensor::Cuda(crate::BackendTensor::Float(tensor.autodiff().primitive))
}
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => {
DispatchTensor::Metal(crate::BackendTensor::Float(tensor.autodiff().primitive))
}
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => {
DispatchTensor::Rocm(crate::BackendTensor::Float(tensor.autodiff().primitive))
}
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => {
DispatchTensor::Vulkan(crate::BackendTensor::Float(tensor.autodiff().primitive))
}
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => {
DispatchTensor::WebGpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
}
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => DispatchTensor::NdArray(
crate::BackendTensor::Float(tensor.autodiff().primitive),
),
DispatchTensor::Autodiff(_) => {
panic!("Autodiff should not wrap an autodiff tensor.")
}
},
_ => panic!("Requires autodiff tensor."),
}
}
fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
tensor
}
fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
tensor
}
fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
tensor
}
fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
match tensor {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => DispatchTensor::Autodiff(Box::new(DispatchTensor::Cpu(
crate::BackendTensor::Autodiff(Autodiff::<Cpu<f32>>::from_inner(tensor.float())),
))),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => DispatchTensor::Autodiff(Box::new(
DispatchTensor::Cuda(crate::BackendTensor::Autodiff(
Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
)),
)),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => DispatchTensor::Autodiff(Box::new(
DispatchTensor::Metal(crate::BackendTensor::Autodiff(
Autodiff::<Metal<f32>>::from_inner(tensor.float()),
)),
)),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => DispatchTensor::Autodiff(Box::new(
DispatchTensor::Rocm(crate::BackendTensor::Autodiff(
Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
)),
)),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => DispatchTensor::Autodiff(Box::new(
DispatchTensor::Vulkan(crate::BackendTensor::Autodiff(
Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
)),
)),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => DispatchTensor::Autodiff(Box::new(
DispatchTensor::WebGpu(crate::BackendTensor::Autodiff(
Autodiff::<WebGpu<f32>>::from_inner(tensor.float()),
)),
)),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => DispatchTensor::Autodiff(Box::new(
DispatchTensor::NdArray(crate::BackendTensor::Autodiff(
Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
)),
)),
DispatchTensor::Autodiff(_) => {
panic!("Autodiff should not wrap an autodiff tensor.")
}
}
}
fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
tensor
}
fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
tensor
}
fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
tensor
}
}
impl DispatchTensor {
pub(crate) fn device(&self) -> DispatchDevice {
match self {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
#[cfg(feature = "tch")]
DispatchTensor::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
}
}
}

View File

@@ -0,0 +1,415 @@
use burn_backend::{DeviceId, DeviceOps};
use crate::backends::*;
/// Represents a device for the [`Dispatch`](crate::Dispatch).
///
/// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to.
///
/// # Example
///
/// ```ignore
/// use burn::DispatchDevice;
///
/// #[cfg(feature = "cpu")]
/// let cpu_device = DispatchDevice::Cpu(Default::default());
///
/// #[cfg(feature = "cuda")]
/// let cuda_device = DispatchDevice::Cuda(Default::default());
/// ```
#[derive(Clone, Eq)]
pub enum DispatchDevice {
/// The [CPU backend](Cpu) device.
#[cfg(feature = "cpu")]
Cpu(CpuDevice),
/// The [CUDA backend](Cuda) device.
#[cfg(feature = "cuda")]
Cuda(CudaDevice),
/// The [Metal backend](Metal) device (via WGPU runtime).
#[cfg(wgpu_metal)]
Metal(WgpuDevice),
/// The [ROCm backend](Rocm) device.
#[cfg(feature = "rocm")]
Rocm(RocmDevice),
/// The [Vulkan backend](Vulkan) device.
#[cfg(wgpu_vulkan)]
Vulkan(WgpuDevice),
/// The [WebGPU backend](WebGpu) device (via WGPU runtime).
#[cfg(wgpu_webgpu)]
WebGpu(WgpuDevice),
/// The [NdArray backend](NdArray) device (CPU-only).
#[cfg(feature = "ndarray")]
NdArray(NdArrayDevice),
/// The [LibTorch backend](LibTorch) device.
#[cfg(feature = "tch")]
LibTorch(LibTorchDevice),
/// The [autodiff enabled backend](Autodiff) device.
#[cfg(feature = "autodiff")]
Autodiff(AutodiffDevice),
}
#[cfg(feature = "autodiff")]
// This tuple struct mainly restricts users from creating Autodiff(Autodiff) devices.
/// A wrapper that enables automatic differentiation for a [`DispatchDevice`].
///
/// Use [`DispatchDevice::autodiff`] to construct this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice(pub(crate) Box<DispatchDevice>);
// Useful for match in dispatch macros
impl core::ops::Deref for AutodiffDevice {
type Target = DispatchDevice;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl core::fmt::Debug for DispatchDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(feature = "cpu")]
Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
#[cfg(feature = "cuda")]
Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
#[cfg(wgpu_metal)]
Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
#[cfg(feature = "rocm")]
Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
#[cfg(wgpu_vulkan)]
Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
#[cfg(wgpu_webgpu)]
Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(),
#[cfg(feature = "ndarray")]
Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
#[cfg(feature = "tch")]
Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
#[cfg(feature = "autodiff")]
// Format without `AutodiffDevice` wrapper
Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.0).finish(),
}
}
}
impl Default for DispatchDevice {
#[allow(unreachable_code)]
fn default() -> Self {
// TODO: which priority?
#[cfg(feature = "cpu")]
return Self::Cpu(CpuDevice);
#[cfg(feature = "cuda")]
return Self::Cuda(CudaDevice::default());
#[cfg(wgpu_metal)]
return Self::Metal(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "rocm")]
return Self::Rocm(RocmDevice::default());
#[cfg(wgpu_vulkan)]
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
#[cfg(wgpu_webgpu)]
return Self::WebGpu(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "ndarray")]
return Self::NdArray(NdArrayDevice::default());
#[cfg(feature = "tch")]
return Self::LibTorch(LibTorchDevice::default());
}
}
impl PartialEq for DispatchDevice {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
// If both are Autodiff, compare the inner devices
#[cfg(feature = "autodiff")]
(DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
// If one is Autodiff, compare it to the raw device
#[cfg(feature = "autodiff")]
(DispatchDevice::Autodiff(a), b) => a.0.as_ref() == b,
#[cfg(feature = "autodiff")]
(a, DispatchDevice::Autodiff(b)) => a == b.0.as_ref(),
#[cfg(feature = "cpu")]
(Self::Cpu(a), Self::Cpu(b)) => a == b,
#[cfg(feature = "cuda")]
(Self::Cuda(a), Self::Cuda(b)) => a == b,
#[cfg(wgpu_metal)]
(Self::Metal(a), Self::Metal(b)) => a == b,
#[cfg(feature = "rocm")]
(Self::Rocm(a), Self::Rocm(b)) => a == b,
#[cfg(wgpu_vulkan)]
(Self::Vulkan(a), Self::Vulkan(b)) => a == b,
#[cfg(wgpu_webgpu)]
(Self::WebGpu(a), Self::WebGpu(b)) => a == b,
#[cfg(feature = "ndarray")]
(Self::NdArray(a), Self::NdArray(b)) => a == b,
#[cfg(feature = "tch")]
(Self::LibTorch(a), Self::LibTorch(b)) => a == b,
#[allow(unreachable_patterns)]
(_, _) => false,
}
}
}
/// Base multiplier to avoid type_id clashes between backends.
/// Limits the number of device types per backend, but this is a sensible limit.
const TYPE_ID_BASE: u16 = 10;
impl DispatchDevice {
#[cfg(feature = "autodiff")]
/// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
let device = device.into();
DispatchDevice::Autodiff(AutodiffDevice(Box::new(device)))
}
/// Returns a unique number per variant to encode into type_id.
fn backend_id(&self) -> BackendId {
match self {
#[cfg(feature = "cpu")]
Self::Cpu(_) => BackendId::Cpu,
#[cfg(feature = "cuda")]
Self::Cuda(_) => BackendId::Cuda,
#[cfg(wgpu_metal)]
Self::Metal(_) => BackendId::Metal,
#[cfg(feature = "rocm")]
Self::Rocm(_) => BackendId::Rocm,
#[cfg(wgpu_vulkan)]
Self::Vulkan(_) => BackendId::Vulkan,
#[cfg(wgpu_webgpu)]
Self::WebGpu(_) => BackendId::WebGpu,
#[cfg(feature = "ndarray")]
Self::NdArray(_) => BackendId::NdArray,
#[cfg(feature = "tch")]
Self::LibTorch(_) => BackendId::LibTorch,
#[cfg(feature = "autodiff")]
Self::Autodiff(device) => device.0.backend_id(),
}
}
/// Encode variant ID and backend type ID into a unique `type_id`.
fn encode_type_id(&self, backend_type_id: u16) -> u16 {
u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
}
/// Decode an encoded `type_id` into variant ID and backend type ID.
fn decode_type_id(type_id: u16) -> (BackendId, u16) {
let variant = type_id / TYPE_ID_BASE;
let backend_type_id = type_id % TYPE_ID_BASE;
(
BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
backend_type_id,
)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
enum BackendId {
#[cfg(feature = "cpu")]
Cpu = 0,
#[cfg(feature = "cuda")]
Cuda = 1,
#[cfg(wgpu_metal)]
Metal = 2,
#[cfg(feature = "rocm")]
Rocm = 3,
#[cfg(wgpu_vulkan)]
Vulkan = 4,
#[cfg(wgpu_webgpu)]
WebGpu = 5,
#[cfg(feature = "ndarray")]
NdArray = 6,
#[cfg(feature = "tch")]
LibTorch = 7,
}
impl From<BackendId> for u16 {
fn from(variant: BackendId) -> Self {
variant as u16
}
}
impl TryFrom<u16> for BackendId {
type Error = ();
fn try_from(value: u16) -> Result<Self, Self::Error> {
match value {
#[cfg(feature = "cpu")]
0 => Ok(Self::Cpu),
#[cfg(feature = "cuda")]
1 => Ok(Self::Cuda),
#[cfg(wgpu_metal)]
2 => Ok(Self::Metal),
#[cfg(feature = "rocm")]
3 => Ok(Self::Rocm),
#[cfg(wgpu_vulkan)]
4 => Ok(Self::Vulkan),
#[cfg(wgpu_webgpu)]
5 => Ok(Self::WebGpu),
#[cfg(feature = "ndarray")]
6 => Ok(Self::NdArray),
#[cfg(feature = "tch")]
7 => Ok(Self::LibTorch),
_ => Err(()),
}
}
}
impl DeviceOps for DispatchDevice {
fn inner(&self) -> &Self {
match self {
#[cfg(feature = "autodiff")]
DispatchDevice::Autodiff(device) => &device.0,
device => device,
}
}
}
impl burn_std::device::Device for DispatchDevice {
fn from_id(mut device_id: DeviceId) -> Self {
let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
device_id.type_id = backend_type_id;
match dispatch_id {
#[cfg(feature = "cpu")]
BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
#[cfg(feature = "cuda")]
BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
#[cfg(wgpu_metal)]
BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
#[cfg(feature = "rocm")]
BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
#[cfg(wgpu_vulkan)]
BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
#[cfg(wgpu_webgpu)]
BackendId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),
#[cfg(feature = "ndarray")]
BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
#[cfg(feature = "tch")]
BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
}
}
fn to_id(&self) -> DeviceId {
let mut device_id = match self {
#[cfg(feature = "cpu")]
Self::Cpu(device) => device.to_id(),
#[cfg(feature = "cuda")]
Self::Cuda(device) => device.to_id(),
#[cfg(wgpu_metal)]
Self::Metal(device) => device.to_id(),
#[cfg(feature = "rocm")]
Self::Rocm(device) => device.to_id(),
#[cfg(wgpu_vulkan)]
Self::Vulkan(device) => device.to_id(),
#[cfg(wgpu_webgpu)]
Self::WebGpu(device) => device.to_id(),
#[cfg(feature = "ndarray")]
Self::NdArray(device) => device.to_id(),
#[cfg(feature = "tch")]
Self::LibTorch(device) => device.to_id(),
#[cfg(feature = "autodiff")]
Self::Autodiff(device) => device.0.to_id(),
};
device_id.type_id = self.encode_type_id(device_id.type_id);
device_id
}
fn device_count(type_id: u16) -> usize {
let (dispatch_id, backend_type_id) = Self::decode_type_id(type_id);
match dispatch_id {
#[cfg(feature = "cpu")]
BackendId::Cpu => CpuDevice::device_count(backend_type_id),
#[cfg(feature = "cuda")]
BackendId::Cuda => CudaDevice::device_count(backend_type_id),
#[cfg(wgpu_metal)]
BackendId::Metal => WgpuDevice::device_count(backend_type_id),
#[cfg(feature = "rocm")]
BackendId::Rocm => RocmDevice::device_count(backend_type_id),
#[cfg(wgpu_vulkan)]
BackendId::Vulkan => WgpuDevice::device_count(backend_type_id),
#[cfg(wgpu_webgpu)]
BackendId::WebGpu => WgpuDevice::device_count(backend_type_id),
#[cfg(feature = "ndarray")]
BackendId::NdArray => NdArrayDevice::device_count(backend_type_id),
#[cfg(feature = "tch")]
BackendId::LibTorch => LibTorchDevice::device_count(backend_type_id),
}
}
}
#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
fn from(device: CpuDevice) -> Self {
DispatchDevice::Cpu(device)
}
}
#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
fn from(device: CudaDevice) -> Self {
DispatchDevice::Cuda(device)
}
}
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
fn from(device: WgpuDevice) -> Self {
DispatchDevice::Metal(device)
}
}
#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
fn from(device: RocmDevice) -> Self {
DispatchDevice::Rocm(device)
}
}
#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
fn from(device: WgpuDevice) -> Self {
DispatchDevice::Vulkan(device)
}
}
#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
fn from(device: WgpuDevice) -> Self {
DispatchDevice::WebGpu(device)
}
}
#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
fn from(device: NdArrayDevice) -> Self {
DispatchDevice::NdArray(device)
}
}
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
fn from(device: LibTorchDevice) -> Self {
DispatchDevice::LibTorch(device)
}
}
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
fn from(device: LibTorchDevice) -> Self {
DispatchDevice::LibTorch(device)
}
}

View File

@@ -0,0 +1,90 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "138"]
//! Burn multi-backend dispatch.
//!
//! # Available Backends
//!
//! The dispatch backend supports the following variants, each enabled via cargo features:
//!
//! | Backend | Feature | Description |
//! |------------|------------|-------------|
//! | `Cpu` | `cpu` | Rust CPU backend (MLIR + LLVM) |
//! | `Cuda` | `cuda` | NVIDIA CUDA backend |
//! | `Metal` | `metal` | Apple Metal backend via `wgpu` (MSL) |
//! | `Rocm` | `rocm` | AMD ROCm backend |
//! | `Vulkan` | `vulkan` | Vulkan backend via `wgpu` (SPIR-V) |
//! | `WebGpu` | `webgpu` | WebGPU backend via `wgpu` (WGSL) |
//! | `NdArray` | `ndarray` | Pure Rust CPU backend using `ndarray` |
//! | `LibTorch` | `tch` | Libtorch backend via `tch` |
//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
//!
//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
//! All other backends can be combined freely.
//!
//! ## WGPU Backend Exclusivity
//!
//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to
//! the current automatic compile, which can only select one target at a time.
//!
//! Enable only **one** of these features in your `Cargo.toml`:
//! - `metal`
//! - `vulkan`
//! - `webgpu`
//!
//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU
//! backends** to prevent unintended behavior.
#[cfg(not(any(
feature = "cpu",
feature = "cuda",
wgpu_metal,
feature = "rocm",
wgpu_vulkan,
wgpu_webgpu,
feature = "ndarray",
feature = "tch",
)))]
compile_error!("At least one backend feature must be enabled.");
#[macro_use]
mod macros;
mod backend;
mod device;
mod ops;
mod tensor;
pub use backend::*;
pub use device::*;
pub use tensor::*;
extern crate alloc;
/// Backends and devices used.
pub(crate) mod backends {
#[cfg(feature = "autodiff")]
pub use burn_autodiff::Autodiff;
#[cfg(feature = "cpu")]
pub use burn_cpu::{Cpu, CpuDevice};
#[cfg(feature = "cuda")]
pub use burn_cuda::{Cuda, CudaDevice};
#[cfg(feature = "rocm")]
pub use burn_rocm::{Rocm, RocmDevice};
#[cfg(wgpu_metal)]
pub use burn_wgpu::Metal;
#[cfg(wgpu_vulkan)]
pub use burn_wgpu::Vulkan;
#[cfg(wgpu_webgpu)]
pub use burn_wgpu::WebGpu;
#[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]
pub use burn_wgpu::WgpuDevice;
#[cfg(feature = "ndarray")]
pub use burn_ndarray::{NdArray, NdArrayDevice};
#[cfg(feature = "tch")]
pub use burn_tch::{LibTorch, LibTorchDevice};
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
use burn_backend::{Scalar, ops::ActivationOps, tensor::FloatTensor};
use crate::Dispatch;
use crate::backends::*;
impl ActivationOps<Self> for Dispatch {
fn leaky_relu(tensor: FloatTensor<Self>, negative_slope: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::leaky_relu(tensor, negative_slope) => Float)
}
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::relu(tensor) => Float)
}
fn relu_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((output, float), (grad, float), |output, grad| B::relu_backward(output, grad) => Float)
}
fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::gelu(tensor) => Float)
}
fn prelu(tensor: FloatTensor<Self>, alpha: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((tensor, float), (alpha, float), |tensor, alpha| B::prelu(tensor, alpha) => Float)
}
fn gelu_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((x, float), (grad, float), |x, grad| B::gelu_backward(x, grad) => Float)
}
fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::sigmoid(tensor) => Float)
}
fn sigmoid_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((output, float), (grad, float), |output, grad| B::sigmoid_backward(output, grad) => Float)
}
fn hard_sigmoid(tensor: FloatTensor<Self>, alpha: Scalar, beta: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::hard_sigmoid(tensor, alpha, beta) => Float)
}
fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::log_sigmoid(tensor) => Float)
}
fn log_sigmoid_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((x, float), (grad, float), |x, grad| B::log_sigmoid_backward(x, grad) => Float)
}
}

View File

@@ -0,0 +1,222 @@
use burn_backend::{
ExecutionError, Scalar, TensorData,
ops::BoolTensorOps,
tensor::{BoolTensor, FloatTensor, IntTensor},
};
use burn_std::{Shape, Slice};
use crate::backends::*;
use crate::{Dispatch, DispatchDevice};
impl BoolTensorOps<Self> for Dispatch {
fn bool_empty(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
creation_op!(Bool, device, |device| B::bool_empty(shape, device))
}
fn bool_zeros(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
creation_op!(Bool, device, |device| B::bool_zeros(shape, device))
}
fn bool_ones(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
creation_op!(Bool, device, |device| B::bool_ones(shape, device))
}
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
unary_op!(tensor, bool, |tensor| B::bool_into_data(tensor).await)
}
fn bool_from_data(data: TensorData, device: &DispatchDevice) -> BoolTensor<Self> {
creation_op!(Bool, device, |device| B::bool_from_data(data, device))
}
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_into_int(tensor) => Int)
}
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_into_float(tensor) => Float)
}
fn bool_device(tensor: &BoolTensor<Self>) -> DispatchDevice {
tensor.device()
}
fn bool_to_device(tensor: BoolTensor<Self>, device: &DispatchDevice) -> BoolTensor<Self> {
to_device!(
Bool,
bool,
tensor,
device,
bool_to_device,
|inner, device| {
let data =
burn_backend::read_sync(B1::bool_into_data(inner)).expect("Should read data");
B2::bool_from_data(data, device)
}
)
}
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_reshape(tensor, shape) => Bool)
}
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_slice(tensor, slices) => Bool)
}
fn bool_slice_assign(
tensor: BoolTensor<Self>,
slices: &[Slice],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
binary_op!((tensor, bool), (value, bool), |tensor, value| B::bool_slice_assign(tensor, slices, value) => Bool)
}
fn bool_mask_where(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
multi_op!(
inputs[(tensor, bool), (mask, bool), (value, bool)], => Bool,
B::bool_mask_where(tensor, mask, value)
)
}
fn bool_mask_fill(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> BoolTensor<Self> {
binary_op!((tensor, bool), (mask, bool), |tensor, mask| B::bool_mask_fill(tensor, mask, value) => Bool)
}
fn bool_gather(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
) -> BoolTensor<Self> {
binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_gather(dim, tensor, indices) => Bool)
}
fn bool_scatter_or(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
multi_op!(
inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
B::bool_scatter_or(dim, tensor, indices, value)
)
}
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_equal(lhs, rhs) => Bool)
}
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, bool, |lhs| B::bool_equal_elem(lhs, rhs) => Bool)
}
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_not(tensor) => Bool)
}
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_and(lhs, rhs) => Bool)
}
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_or(lhs, rhs) => Bool)
}
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_swap_dims(tensor, dim1, dim2) => Bool)
}
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_permute(tensor, axes) => Bool)
}
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_flip(tensor, axes) => Bool)
}
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_expand(tensor, shape) => Bool)
}
fn bool_unfold(
tensor: BoolTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_unfold(tensor, dim, size, step) => Bool)
}
fn bool_select(
tensor: BoolTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
) -> BoolTensor<Self> {
binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_select(tensor, dim, indices) => Bool)
}
fn bool_select_or(
tensor: BoolTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
multi_op!(
inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
B::bool_select_or(tensor, dim, indices, value)
)
}
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_repeat_dim(tensor, dim, times) => Bool)
}
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
vec_op!(tensors, bool, |tensors| B::bool_cat(tensors, dim) => Bool)
}
fn bool_not_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_not_equal(lhs, rhs) => Bool)
}
fn bool_not_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, bool, |lhs| B::bool_not_equal_elem(lhs, rhs) => Bool)
}
fn bool_xor(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_xor(lhs, rhs) => Bool)
}
fn bool_transpose(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_transpose(tensor) => Bool)
}
fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_any(tensor) => Bool)
}
fn bool_any_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_any_dim(tensor, dim) => Bool)
}
fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_all(tensor) => Bool)
}
fn bool_all_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_all_dim(tensor, dim) => Bool)
}
async fn bool_argwhere(tensor: BoolTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, bool, |tensor| B::bool_argwhere(tensor).await => Int)
}
}

View File

@@ -0,0 +1,503 @@
use burn_backend::{
ExecutionError, Scalar, TensorData,
ops::IntTensorOps,
tensor::{BoolTensor, FloatTensor, IntTensor},
};
use burn_std::{IntDType, Shape, Slice};
use crate::backends::*;
use crate::{Dispatch, DispatchDevice};
impl IntTensorOps<Self> for Dispatch {
fn int_empty(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_empty(shape, device, dtype))
}
async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
unary_op!(tensor, int, |tensor| B::int_into_data(tensor).await)
}
fn int_from_data(data: TensorData, device: &DispatchDevice) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_from_data(data, device))
}
fn int_device(tensor: &IntTensor<Self>) -> DispatchDevice {
tensor.device()
}
fn int_to_device(tensor: IntTensor<Self>, device: &DispatchDevice) -> IntTensor<Self> {
to_device!(Int, int, tensor, device, int_to_device, |inner, device| {
let data = burn_backend::read_sync(B1::int_into_data(inner)).expect("Should read data");
B2::int_from_data(data, device)
})
}
fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_reshape(tensor, shape) => Int)
}
fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_slice(tensor, slices) => Int)
}
fn int_slice_assign(
tensor: IntTensor<Self>,
slices: &[Slice],
value: IntTensor<Self>,
) -> IntTensor<Self> {
binary_op!((tensor, int), (value, int), |tensor, value| B::int_slice_assign(tensor, slices, value) => Int)
}
fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_into_float(tensor) => Float)
}
fn int_mask_where(
tensor: IntTensor<Self>,
mask: BoolTensor<Self>,
value: IntTensor<Self>,
) -> IntTensor<Self> {
multi_op!(
inputs[(tensor, int), (mask, bool), (value, int)], => Int,
B::int_mask_where(tensor, mask, value)
)
}
fn int_mask_fill(
tensor: IntTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> IntTensor<Self> {
binary_op!((tensor, int), (mask, bool), |tensor, mask| B::int_mask_fill(tensor, mask, value) => Int)
}
fn int_gather(
dim: usize,
tensor: IntTensor<Self>,
indices: IntTensor<Self>,
) -> IntTensor<Self> {
binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_gather(dim, tensor, indices) => Int)
}
fn int_scatter_add(
dim: usize,
tensor: IntTensor<Self>,
indices: IntTensor<Self>,
value: IntTensor<Self>,
) -> IntTensor<Self> {
multi_op!(
inputs[(tensor, int), (indices, int), (value, int)], => Int,
B::int_scatter_add(dim, tensor, indices, value)
)
}
fn int_select(
tensor: IntTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
) -> IntTensor<Self> {
binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_select(tensor, dim, indices) => Int)
}
fn int_select_add(
tensor: IntTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
value: IntTensor<Self>,
) -> IntTensor<Self> {
multi_op!(
inputs[(tensor, int), (indices, int), (value, int)], => Int,
B::int_select_add(tensor, dim, indices, value)
)
}
fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_equal(lhs, rhs) => Bool)
}
fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_equal_elem(lhs, rhs) => Bool)
}
fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater(lhs, rhs) => Bool)
}
fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_greater_elem(lhs, rhs) => Bool)
}
fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater_equal(lhs, rhs) => Bool)
}
fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_greater_equal_elem(lhs, rhs) => Bool)
}
fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower(lhs, rhs) => Bool)
}
fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_lower_elem(lhs, rhs) => Bool)
}
fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower_equal(lhs, rhs) => Bool)
}
fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_lower_equal_elem(lhs, rhs) => Bool)
}
fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_add(lhs, rhs) => Int)
}
fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_add_scalar(lhs, rhs) => Int)
}
fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_sub(lhs, rhs) => Int)
}
fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_sub_scalar(lhs, rhs) => Int)
}
fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_mul(lhs, rhs) => Int)
}
fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_mul_scalar(lhs, rhs) => Int)
}
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_div(lhs, rhs) => Int)
}
fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_div_scalar(lhs, rhs) => Int)
}
fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_remainder(lhs, rhs) => Int)
}
fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_remainder_scalar(lhs, rhs) => Int)
}
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_matmul(lhs, rhs) => Int)
}
fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_sum(tensor) => Int)
}
fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_sum_dim(tensor, dim) => Int)
}
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_prod(tensor) => Int)
}
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_prod_dim(tensor, dim) => Int)
}
fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_mean_dim(tensor, dim) => Int)
}
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_cumsum(tensor, dim) => Int)
}
fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_cumprod(tensor, dim) => Int)
}
fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_cummin(tensor, dim) => Int)
}
fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_cummax(tensor, dim) => Int)
}
fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_argmax(tensor, dim) => Int)
}
fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_argmin(tensor, dim) => Int)
}
fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_abs(tensor) => Int)
}
fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_swap_dims(tensor, dim1, dim2) => Int)
}
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_permute(tensor, axes) => Int)
}
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_flip(tensor, axes) => Int)
}
fn int_random(
shape: Shape,
distribution: burn_backend::Distribution,
device: &DispatchDevice,
) -> IntTensor<Self> {
creation_op!(Int, device, |device| {
B::int_random(shape, distribution, device)
})
}
fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_expand(tensor, shape) => Int)
}
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_and(lhs, rhs) => Int)
}
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::bitwise_and_scalar(lhs, rhs) => Int)
}
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_or(lhs, rhs) => Int)
}
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::bitwise_or_scalar(lhs, rhs) => Int)
}
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_xor(lhs, rhs) => Int)
}
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::bitwise_xor_scalar(lhs, rhs) => Int)
}
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::bitwise_not(tensor) => Int)
}
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_left_shift(lhs, rhs) => Int)
}
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::bitwise_left_shift_scalar(lhs, rhs) => Int)
}
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_right_shift(lhs, rhs) => Int)
}
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::bitwise_right_shift_scalar(lhs, rhs) => Int)
}
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_cast(tensor, dtype) => Int)
}
fn int_unfold(
tensor: IntTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_unfold(tensor, dim, size, step) => Int)
}
fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_repeat_dim(tensor, dim, times) => Int)
}
fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
vec_op!(tensors, int, |tensors| B::int_cat(tensors, dim) => Int)
}
fn int_not_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_not_equal(lhs, rhs) => Bool)
}
fn int_not_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_not_equal_elem(lhs, rhs) => Bool)
}
fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powi(lhs, rhs) => Int)
}
fn int_powf(lhs: IntTensor<Self>, rhs: FloatTensor<Self>) -> IntTensor<Self> {
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powf(lhs, rhs) => Int)
}
fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_powi_scalar_impl(lhs, rhs) => Int)
}
fn int_powf_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
unary_op!(lhs, int, |lhs| B::int_powf_scalar_impl(lhs, rhs) => Int)
}
fn int_clamp_min(tensor: IntTensor<Self>, min: Scalar) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_clamp_min(tensor, min) => Int)
}
fn int_clamp_max(tensor: IntTensor<Self>, max: Scalar) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_clamp_max(tensor, max) => Int)
}
fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_clamp(tensor, min, max) => Int)
}
fn int_neg(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_neg(tensor) => Int)
}
fn int_zeros(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_zeros(shape, device, dtype))
}
fn int_ones(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_ones(shape, device, dtype))
}
fn int_full(
shape: Shape,
fill_value: Scalar,
device: &DispatchDevice,
dtype: IntDType,
) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_full(
shape, fill_value, device, dtype
))
}
fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_mean(tensor) => Int)
}
fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_max(tensor) => Int)
}
fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_max_dim(tensor, dim) => Int)
}
fn int_max_dim_with_indices(
tensor: IntTensor<Self>,
dim: usize,
) -> (IntTensor<Self>, IntTensor<Self>) {
multi_op!(
inputs[(tensor, int)],
outputs[(out, Int), (indices, Int)],
B::int_max_dim_with_indices(tensor, dim)
)
}
fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_max_abs(tensor) => Int)
}
fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_max_abs_dim(tensor, dim) => Int)
}
fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_min(tensor) => Int)
}
fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_min_dim(tensor, dim) => Int)
}
fn int_min_dim_with_indices(
tensor: IntTensor<Self>,
dim: usize,
) -> (IntTensor<Self>, IntTensor<Self>) {
multi_op!(
inputs[(tensor, int)],
outputs[(out, Int), (indices, Int)],
B::int_min_dim_with_indices(tensor, dim)
)
}
fn int_transpose(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_transpose(tensor) => Int)
}
fn int_arange_step(
range: std::ops::Range<i64>,
step: usize,
device: &DispatchDevice,
) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_arange_step(
range, step, device
))
}
fn int_arange(range: std::ops::Range<i64>, device: &DispatchDevice) -> IntTensor<Self> {
creation_op!(Int, device, |device| B::int_arange(range, device))
}
fn int_any(tensor: IntTensor<Self>) -> BoolTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_any(tensor) => Bool)
}
fn int_any_dim(tensor: IntTensor<Self>, dim: usize) -> BoolTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_any_dim(tensor, dim) => Bool)
}
fn int_all(tensor: IntTensor<Self>) -> BoolTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_all(tensor) => Bool)
}
fn int_all_dim(tensor: IntTensor<Self>, dim: usize) -> BoolTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_all_dim(tensor, dim) => Bool)
}
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_sign(tensor) => Int)
}
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_sort(tensor, dim, descending) => Int)
}
fn int_sort_with_indices(
tensor: IntTensor<Self>,
dim: usize,
descending: bool,
) -> (IntTensor<Self>, IntTensor<Self>) {
multi_op!(
inputs[(tensor, int)],
outputs[(out, Int), (indices, Int)],
B::int_sort_with_indices(tensor, dim, descending)
)
}
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
unary_op!(tensor, int, |tensor| B::int_argsort(tensor, dim, descending) => Int)
}
}

View File

@@ -0,0 +1,7 @@
mod activation;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;

View File

@@ -0,0 +1,628 @@
use burn_backend::{
ops::{
DeformConv2dBackward, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
MaxPool2dWithIndices, ModuleOps,
},
tensor::{FloatTensor, IntTensor},
};
use crate::Dispatch;
use crate::backends::*;
impl ModuleOps<Self> for Dispatch {
fn conv2d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::ConvOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float)],
opt_inputs[(bias, float)],
=> Float,
B::conv2d(x, weight, bias, options)
)
}
fn deform_conv2d(
x: FloatTensor<Self>,
offset: FloatTensor<Self>,
weight: FloatTensor<Self>,
mask: Option<FloatTensor<Self>>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::DeformConvOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (offset, float), (weight, float)],
opt_inputs[(mask, float), (bias, float)],
=> Float,
B::deform_conv2d(x, offset, weight, mask, bias, options)
)
}
fn deform_conv2d_backward(
x: FloatTensor<Self>,
offset: FloatTensor<Self>,
weight: FloatTensor<Self>,
mask: Option<FloatTensor<Self>>,
bias: Option<FloatTensor<Self>>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = multi_op!(
inputs[(x, float), (offset, float), (weight, float), (output_grad, float)],
opt_inputs[(mask, float), (bias, float)],
outputs[(x_grad, Float), (offset_grad, Float), (weight_grad, Float)],
opt_outputs[mask_grad, bias_grad],
{
let res = B::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options);
(res.x_grad, res.offset_grad, res.weight_grad, res.mask_grad, res.bias_grad)
}
);
DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
}
fn conv3d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::ConvOptions<3>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float)],
opt_inputs[(bias, float)],
=> Float,
B::conv3d(x, weight, bias, options)
)
}
fn conv_transpose2d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::ConvTransposeOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float)],
opt_inputs[(bias, float)],
=> Float,
B::conv_transpose2d(x, weight, bias, options)
)
}
fn conv_transpose3d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::ConvTransposeOptions<3>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float)],
opt_inputs[(bias, float)],
=> Float,
B::conv_transpose3d(x, weight, bias, options)
)
}
fn avg_pool2d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> FloatTensor<Self> {
multi_op!(inputs[(x, float)],
=> Float,
B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
)
}
fn avg_pool2d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (grad, float)],
=> Float,
B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
)
}
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float)],
=> Float,
B::adaptive_avg_pool2d(x, output_size)
)
}
fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (grad, float)],
=> Float,
B::adaptive_avg_pool2d_backward(x, grad)
)
}
fn max_pool2d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float)],
=> Float,
B::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
)
}
fn max_pool2d_with_indices(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> MaxPool2dWithIndices<Self> {
let (out, indices) = multi_op!(
inputs[(x, float)],
outputs[(out, Float), (indices, Int)],
{
let res = B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
(res.output, res.indices)
}
);
MaxPool2dWithIndices::new(out, indices)
}
fn max_pool2d_with_indices_backward(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
output_grad: FloatTensor<Self>,
indices: IntTensor<Self>,
) -> MaxPool2dBackward<Self> {
let x_grad = multi_op!(
inputs[(x, float), (output_grad, float), (indices, int)],
=> Float,
{
let res = B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
res.x_grad
}
);
MaxPool2dBackward::new(x_grad)
}
fn interpolate(
x: FloatTensor<Self>,
output_size: [usize; 2],
options: burn_backend::ops::InterpolateOptions,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float)],
=> Float,
B::interpolate(x, output_size, options)
)
}
fn interpolate_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
output_size: [usize; 2],
options: burn_backend::ops::InterpolateOptions,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (grad, float)],
=> Float,
B::interpolate_backward(x, grad, output_size, options)
)
}
fn embedding(weights: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
multi_op!(
inputs[(weights, float), (indices, int)],
=> Float,
B::embedding(weights, indices)
)
}
fn embedding_backward(
weights: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
indices: IntTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(weights, float), (output_grad, float), (indices, int)],
=> Float,
B::embedding_backward(weights, output_grad, indices)
)
}
fn conv1d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::ConvOptions<1>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float)],
opt_inputs[(bias, float)],
=> Float,
B::conv1d(x, weight, bias, options)
)
}
fn conv1d_x_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvOptions<1>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv1d_x_backward(x, weight, output_grad, options)
)
}
fn conv1d_weight_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvOptions<1>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv1d_weight_backward(x, weight, output_grad, options)
)
}
fn conv1d_bias_backward(
x: FloatTensor<Self>,
bias: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (bias, float), (output_grad, float)],
=> Float,
B::conv1d_bias_backward(x, bias, output_grad)
)
}
fn conv2d_x_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv2d_x_backward(x, weight, output_grad, options)
)
}
fn conv2d_weight_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv2d_weight_backward(x, weight, output_grad, options)
)
}
fn conv2d_bias_backward(
x: FloatTensor<Self>,
bias: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (bias, float), (output_grad, float)],
=> Float,
B::conv2d_bias_backward(x, bias, output_grad)
)
}
fn conv3d_x_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvOptions<3>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv3d_x_backward(x, weight, output_grad, options)
)
}
fn conv3d_weight_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvOptions<3>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv3d_weight_backward(x, weight, output_grad, options)
)
}
fn conv3d_bias_backward(
x: FloatTensor<Self>,
bias: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (bias, float), (output_grad, float)],
=> Float,
B::conv3d_bias_backward(x, bias, output_grad)
)
}
fn conv_transpose1d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::ConvTransposeOptions<1>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float)],
opt_inputs[(bias, float)],
=> Float,
B::conv_transpose1d(x, weight, bias, options)
)
}
fn conv_transpose1d_x_backward(
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvTransposeOptions<1>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(weight, float), (output_grad, float)],
=> Float,
B::conv_transpose1d_x_backward(weight, output_grad, options)
)
}
fn conv_transpose1d_weight_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvTransposeOptions<1>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv_transpose1d_weight_backward(x, weight, output_grad, options)
)
}
fn conv_transpose1d_bias_backward(
x: FloatTensor<Self>,
bias: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (bias, float), (output_grad, float)],
=> Float,
B::conv_transpose1d_bias_backward(x, bias, output_grad)
)
}
fn conv_transpose2d_x_backward(
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvTransposeOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(weight, float), (output_grad, float)],
=> Float,
B::conv_transpose2d_x_backward(weight, output_grad, options)
)
}
fn conv_transpose2d_weight_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvTransposeOptions<2>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv_transpose2d_weight_backward(x, weight, output_grad, options)
)
}
fn conv_transpose2d_bias_backward(
x: FloatTensor<Self>,
bias: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (bias, float), (output_grad, float)],
=> Float,
B::conv_transpose2d_bias_backward(x, bias, output_grad)
)
}
fn conv_transpose3d_x_backward(
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvTransposeOptions<3>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(weight, float), (output_grad, float)],
=> Float,
B::conv_transpose3d_x_backward(weight, output_grad, options)
)
}
fn conv_transpose3d_weight_backward(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
options: burn_backend::ops::ConvTransposeOptions<3>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (weight, float), (output_grad, float)],
=> Float,
B::conv_transpose3d_weight_backward(x, weight, output_grad, options)
)
}
fn conv_transpose3d_bias_backward(
x: FloatTensor<Self>,
bias: FloatTensor<Self>,
output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (bias, float), (output_grad, float)],
=> Float,
B::conv_transpose3d_bias_backward(x, bias, output_grad)
)
}
fn unfold4d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
options: burn_backend::ops::UnfoldOptions,
) -> FloatTensor<Self> {
multi_op!(inputs[(x, float)], => Float, B::unfold4d(x, kernel_size, options))
}
fn avg_pool1d(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
ceil_mode: bool,
) -> FloatTensor<Self> {
multi_op!(inputs[(x, float)], => Float,
B::avg_pool1d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
)
}
fn avg_pool1d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
ceil_mode: bool,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (grad, float)],
=> Float,
B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
)
}
fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
multi_op!(inputs[(x, float)], => Float, B::adaptive_avg_pool1d(x, output_size))
}
fn adaptive_avg_pool1d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(x, float), (grad, float)],
=> Float,
B::adaptive_avg_pool1d_backward(x, grad)
)
}
fn max_pool1d(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
) -> FloatTensor<Self> {
multi_op!(inputs[(x, float)], => Float,
B::max_pool1d(x, kernel_size, stride, padding, dilation, ceil_mode))
}
fn max_pool1d_with_indices(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
) -> MaxPool1dWithIndices<Self> {
let (out, indices) = multi_op!(
inputs[(x, float)],
outputs[(out, Float), (indices, Int)],
{
let res = B::max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
(res.output, res.indices)
}
);
MaxPool1dWithIndices::new(out, indices)
}
fn max_pool1d_with_indices_backward(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
output_grad: FloatTensor<Self>,
indices: IntTensor<Self>,
) -> MaxPool1dBackward<Self> {
let x_grad = multi_op!(
inputs[(x, float), (output_grad, float), (indices, int)],
=> Float,
{
let res = B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
res.x_grad
}
);
MaxPool1dBackward::new(x_grad)
}
fn attention(
query: FloatTensor<Self>,
key: FloatTensor<Self>,
value: FloatTensor<Self>,
mask: Option<burn_backend::tensor::BoolTensor<Self>>,
attn_bias: Option<FloatTensor<Self>>,
options: burn_backend::ops::AttentionModuleOptions,
) -> FloatTensor<Self> {
multi_op!(
inputs[(query, float), (key, float), (value, float)],
opt_inputs[(mask, bool), (attn_bias, float)],
=> Float,
B::attention(query, key, value, mask, attn_bias, options)
)
}
}

View File

@@ -0,0 +1,212 @@
use burn_backend::{
ExecutionError, QTensorPrimitive, TensorData, TensorPrimitive,
ops::QTensorOps,
quantization::QuantizationParametersPrimitive,
tensor::{FloatTensor, IntTensor, QuantizedTensor},
};
use burn_std::{QuantPropagation, Shape, Slice};
use crate::backends::*;
use crate::{Dispatch, DispatchDevice};
impl QTensorOps<Self> for Dispatch {
fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {
creation_op!(Quantized, device, |device| B::q_from_data(data, device))
}
fn quantize(
tensor: FloatTensor<Self>,
scheme: &burn_std::QuantScheme,
qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
binary_op!(
(tensor, float),
(qparams.scales, float),
|tensor, scales| {
B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })
} => Quantized
)
}
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::dequantize(tensor) => Float)
}
fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {
tensor.device()
}
fn q_to_device(
tensor: QuantizedTensor<Self>,
device: &DispatchDevice,
) -> QuantizedTensor<Self> {
to_device!(
Quantized,
quantized,
tensor,
device,
q_to_device,
|inner, device| {
let data =
burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data");
B2::q_from_data(data, device)
}
)
}
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)
}
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)
}
fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)
}
fn q_swap_dims(
tensor: QuantizedTensor<Self>,
dim1: usize,
dim2: usize,
) -> QuantizedTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)
}
fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)
}
fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)
}
fn q_select(
tensor: QuantizedTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
binary_op!(
(tensor, quantized),
(indices, int),
|tensor, indices| B::q_select(tensor, dim, indices) => Quantized
)
}
fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)
}
fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
// TODO: this would be much cleaner if we consolidated tensor primitive types
match (lhs, rhs) {
(TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
if matches!(lhs.propagation(), QuantPropagation::Propagate) {
let out = binary_op!(
(lhs, quantized),
(rhs, quantized),
|lhs, rhs| {
if let TensorPrimitive::QFloat(out) = B::q_matmul(
TensorPrimitive::QFloat(lhs),
TensorPrimitive::QFloat(rhs),
) {
out
} else {
unreachable!()
}
} => Quantized
);
TensorPrimitive::QFloat(out)
} else {
let out = binary_op!(
(lhs, quantized),
(rhs, quantized),
|lhs, rhs| {
if let TensorPrimitive::Float(out) = B::q_matmul(
TensorPrimitive::QFloat(lhs),
TensorPrimitive::QFloat(rhs),
) {
out
} else {
unreachable!()
}
} => Float
);
TensorPrimitive::Float(out)
}
}
(TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
if matches!(rhs.propagation(), QuantPropagation::Propagate) {
let out = binary_op!(
(lhs, float),
(rhs, quantized),
|lhs, rhs| {
if let TensorPrimitive::QFloat(out) = B::q_matmul(
TensorPrimitive::Float(lhs),
TensorPrimitive::QFloat(rhs),
) {
out
} else {
unreachable!()
}
} => Quantized
);
TensorPrimitive::QFloat(out)
} else {
let out = binary_op!(
(lhs, float),
(rhs, quantized),
|lhs, rhs| {
if let TensorPrimitive::Float(out) = B::q_matmul(
TensorPrimitive::Float(lhs),
TensorPrimitive::QFloat(rhs),
) {
out
} else {
unreachable!()
}
} => Float
);
TensorPrimitive::Float(out)
}
}
(TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
if matches!(lhs.propagation(), QuantPropagation::Propagate) {
let out = binary_op!(
(lhs, quantized),
(rhs, float),
|lhs, rhs| {
if let TensorPrimitive::QFloat(out) = B::q_matmul(
TensorPrimitive::QFloat(lhs),
TensorPrimitive::Float(rhs),
) {
out
} else {
unreachable!()
}
} => Quantized
);
TensorPrimitive::QFloat(out)
} else {
let out = binary_op!(
(lhs, quantized),
(rhs, float),
|lhs, rhs| {
if let TensorPrimitive::Float(out) = B::q_matmul(
TensorPrimitive::QFloat(lhs),
TensorPrimitive::Float(rhs),
) {
out
} else {
unreachable!()
}
} => Float
);
TensorPrimitive::Float(out)
}
}
_ => unreachable!(),
}
}
}

View File

@@ -0,0 +1,594 @@
use burn_backend::{
ExecutionError, Scalar, TensorData,
ops::FloatTensorOps,
tensor::{BoolTensor, FloatTensor, IntTensor},
};
use burn_std::{FloatDType, Shape, Slice};
use crate::backends::*;
use crate::{Dispatch, DispatchDevice};
// TODO: remove backend default elem type genericsnow that we have per-device defaults
// https://github.com/tracel-ai/burn/issues/3642
impl FloatTensorOps<Self> for Dispatch {
fn float_from_data(
data: burn_backend::TensorData,
device: &DispatchDevice,
) -> FloatTensor<Self> {
creation_op!(Float, device, |device| B::float_from_data(data, device))
}
fn float_random(
shape: Shape,
distribution: burn_backend::Distribution,
device: &DispatchDevice,
) -> FloatTensor<Self> {
creation_op!(Float, device, |device| {
B::float_random(shape, distribution, device)
})
}
async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
}
fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
tensor.device()
}
fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
float_to_device!(
Float,
float,
tensor,
device,
float_to_device,
|inner, device| {
let data =
burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
B2::float_from_data(data, device)
}
)
}
fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_into_int(tensor) => Int)
}
fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
}
fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
}
fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
}
fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
}
fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
}
fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
}
fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
}
fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
}
fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
}
fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
}
fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
}
fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
}
fn float_cross(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
dim: usize,
) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
}
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
}
fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
}
fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
}
fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
}
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
}
fn float_gather(
dim: usize,
tensor: FloatTensor<Self>,
indices: IntTensor<Self>,
) -> FloatTensor<Self> {
binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
}
fn float_scatter_add(
dim: usize,
tensor: FloatTensor<Self>,
indices: IntTensor<Self>,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(tensor, float), (indices, int), (value, float)], => Float,
B::float_scatter_add(dim, tensor, indices, value)
)
}
fn float_select(
tensor: FloatTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
) -> FloatTensor<Self> {
binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
}
fn float_select_add(
tensor: FloatTensor<Self>,
dim: usize,
indices: IntTensor<Self>,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(tensor, float), (indices, int), (value, float)], => Float,
B::float_select_add(tensor, dim, indices, value)
)
}
fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
}
fn float_slice_assign(
tensor: FloatTensor<Self>,
slices: &[Slice],
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
}
fn float_mask_where(
tensor: FloatTensor<Self>,
mask: BoolTensor<Self>,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
multi_op!(
inputs[(tensor, float), (mask, bool), (value, float)], => Float,
B::float_mask_where(tensor, mask, value)
)
}
fn float_mask_fill(
tensor: FloatTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> FloatTensor<Self> {
binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
}
fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs) => Bool)
}
fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs) => Bool)
}
fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs) => Bool)
}
fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs) => Bool)
}
fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs) => Bool)
}
fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs) => Bool)
}
fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs) => Bool)
}
fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs) => Bool)
}
fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs) => Bool)
}
fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs) => Bool)
}
fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
}
fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
}
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
}
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
}
fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
}
fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
}
fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
}
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
}
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
}
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
}
fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
}
fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
}
fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
}
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
}
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
}
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
}
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
}
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
}
fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
}
fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
}
fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
}
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
}
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
}
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
}
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
}
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
}
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
}
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
}
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
}
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
}
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
}
fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
}
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
}
fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim) => Int)
}
fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim) => Int)
}
fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
}
fn float_unfold(
tensor: FloatTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| {
B::float_unfold(tensor, dim, size, step)
} => Float)
}
fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
}
fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
}
fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
}
// Default implementation
fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
}
fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
}
fn float_full(
shape: Shape,
fill_value: Scalar,
device: &DispatchDevice,
dtype: FloatDType,
) -> FloatTensor<Self> {
creation_op!(Float, device, |device| B::float_full(
shape, fill_value, device, dtype
))
}
fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
}
fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
}
fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
}
fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
}
fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
}
fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
}
fn float_not_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs) => Bool)
}
fn float_not_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs) => Bool)
}
fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
}
fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
}
fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
}
fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
}
fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
}
fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
}
fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
}
fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
}
fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
}
fn float_max_dim_with_indices(
tensor: FloatTensor<Self>,
dim: usize,
) -> (FloatTensor<Self>, IntTensor<Self>) {
multi_op!(
inputs[(tensor, float)],
outputs[(out, Float), (indices, Int)],
B::float_max_dim_with_indices(tensor, dim)
)
}
fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
}
fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
}
fn float_min_dim_with_indices(
tensor: FloatTensor<Self>,
dim: usize,
) -> (FloatTensor<Self>, IntTensor<Self>) {
multi_op!(
inputs[(tensor, float)],
outputs[(out, Float), (indices, Int)],
B::float_min_dim_with_indices(tensor, dim)
)
}
fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
}
fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
}
fn float_any(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_any(tensor) => Bool)
}
fn float_any_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim) => Bool)
}
fn float_all(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_all(tensor) => Bool)
}
fn float_all_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim) => Bool)
}
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
}
fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
}
fn float_sort_with_indices(
tensor: FloatTensor<Self>,
dim: usize,
descending: bool,
) -> (FloatTensor<Self>, IntTensor<Self>) {
multi_op!(
inputs[(tensor, float)],
outputs[(out, Float), (indices, Int)],
B::float_sort_with_indices(tensor, dim, descending)
)
}
fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending) => Int)
}
fn float_grid_sample_2d(
tensor: FloatTensor<Self>,
grid: FloatTensor<Self>,
options: burn_backend::ops::GridSampleOptions,
) -> FloatTensor<Self> {
binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
}
fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_is_nan(tensor) => Bool)
}
fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
unary_float!(tensor, float, |tensor| B::float_is_inf(tensor) => Bool)
}
}

View File

@@ -0,0 +1,26 @@
use burn_backend::{
ExecutionError,
ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},
};
use crate::Dispatch;
use crate::backends::*;
impl TransactionOps<Self> for Dispatch {
async fn tr_execute(
transaction: TransactionPrimitive<Self>,
) -> Result<TransactionPrimitiveData, ExecutionError> {
let first_tensor = transaction
.read_floats
.first()
.or(transaction.read_ints.first())
.or(transaction.read_bools.first());
match first_tensor {
Some(tensor) => {
transaction_op!(transaction, tensor)
}
None => Ok(TransactionPrimitiveData::default()),
}
}
}

View File

@@ -0,0 +1,274 @@
use burn_backend::{Backend, QTensorPrimitive, TensorMetadata};
use crate::backends::*;
#[cfg(feature = "autodiff")]
use burn_backend::tensor::FloatTensor;
// TODO: if we reduce the different associated types for float/int/bool/quantized tensor primitives down to a single
// `B::TensorPrimitive` we can simplify this.
/// Tensor which points to a backend tensor primitive kind.
#[derive(Clone, Debug)]
pub enum BackendTensor<B: Backend> {
/// Float tensor handle.
Float(B::FloatTensorPrimitive),
/// Int tensor handle.
Int(B::IntTensorPrimitive),
/// Bool tensor handle.
Bool(B::BoolTensorPrimitive),
/// Quantized tensor handle.
Quantized(B::QuantizedTensorPrimitive),
#[cfg(feature = "autodiff")]
/// Autodiff float tensor handle.
Autodiff(FloatTensor<Autodiff<B>>),
}
impl<B: Backend> BackendTensor<B> {
/// Returns the inner float tensor primitive.
pub(crate) fn float(self) -> B::FloatTensorPrimitive {
match self {
BackendTensor::Float(tensor) => tensor,
BackendTensor::Int(_) => panic!("Should be float, got int"),
BackendTensor::Bool(_) => panic!("Should be float, got bool"),
BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
}
}
/// Returns the inner float tensor primitive.
pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
match self {
BackendTensor::Float(tensor) => tensor,
BackendTensor::Int(_) => panic!("Should be float, got int"),
BackendTensor::Bool(_) => panic!("Should be float, got bool"),
BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
}
}
/// Returns the inner int tensor primitive.
pub(crate) fn int(self) -> B::IntTensorPrimitive {
match self {
BackendTensor::Int(tensor) => tensor,
BackendTensor::Float(_) => panic!("Should be int, got float"),
BackendTensor::Bool(_) => panic!("Should be int, got bool"),
BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
}
}
/// Returns the inner bool tensor primitive.
pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
match self {
BackendTensor::Bool(tensor) => tensor,
BackendTensor::Float(_) => panic!("Should be bool, got float"),
BackendTensor::Int(_) => panic!("Should be bool, got int"),
BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
}
}
/// Returns the inner quantized tensor primitive.
pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
match self {
BackendTensor::Quantized(tensor) => tensor,
_ => unreachable!(),
}
}
#[cfg(feature = "autodiff")]
/// Returns the inner autodiff tensor primitive.
pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
match self {
BackendTensor::Autodiff(tensor) => tensor,
// NOTE: this is the panicking code reached in tensor.rs:74:18:
_ => unreachable!(),
}
}
#[cfg(feature = "autodiff")]
/// Returns the inner autodiff tensor primitive.
pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
match self {
BackendTensor::Autodiff(tensor) => tensor,
_ => unreachable!(),
}
}
#[cfg(feature = "autodiff")]
/// Returns the inner autodiff tensor primitive.
pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
match self {
BackendTensor::Autodiff(tensor) => tensor.primitive,
_ => unreachable!(),
}
}
/// Returns the backend device.
pub(crate) fn device(&self) -> B::Device {
match self {
BackendTensor::Float(tensor) => B::float_device(tensor),
BackendTensor::Int(tensor) => B::int_device(tensor),
BackendTensor::Bool(tensor) => B::bool_device(tensor),
BackendTensor::Quantized(tensor) => B::q_device(tensor),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
}
}
}
impl<B: Backend> TensorMetadata for BackendTensor<B> {
fn dtype(&self) -> burn_std::DType {
match self {
BackendTensor::Float(tensor) => tensor.dtype(),
BackendTensor::Int(tensor) => tensor.dtype(),
BackendTensor::Bool(tensor) => tensor.dtype(),
BackendTensor::Quantized(tensor) => tensor.dtype(),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(tensor) => tensor.dtype(),
}
}
fn shape(&self) -> burn_std::Shape {
match self {
BackendTensor::Float(tensor) => tensor.shape(),
BackendTensor::Int(tensor) => tensor.shape(),
BackendTensor::Bool(tensor) => tensor.shape(),
BackendTensor::Quantized(tensor) => tensor.shape(),
#[cfg(feature = "autodiff")]
BackendTensor::Autodiff(tensor) => tensor.shape(),
}
}
}
impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
fn scheme(&self) -> &burn_std::QuantScheme {
match self {
BackendTensor::Quantized(tensor) => tensor.scheme(),
_ => panic!(
"Quantization scheme is not valid for dtype {:?}",
self.dtype(),
),
}
}
}
/// Dispatch tensor that can hold tensors from any enabled backend.
///
/// This enum wraps backend-specific tensor types, allowing runtime selection
/// of the backend to execute operations on.
#[derive(Clone, Debug)]
pub enum DispatchTensor {
/// The [CPU backend](Cpu) tensor.
#[cfg(feature = "cpu")]
Cpu(BackendTensor<Cpu>),
/// The [CUDA backend](Cuda) tensor.
#[cfg(feature = "cuda")]
Cuda(BackendTensor<Cuda>),
/// The [Metal backend](Metal) tensor.
#[cfg(wgpu_metal)]
Metal(BackendTensor<Metal>),
/// The [ROCm backend](Rocm) tensor.
#[cfg(feature = "rocm")]
Rocm(BackendTensor<Rocm>),
/// The [Vulkan backend](Vulkan) tensor.
#[cfg(wgpu_vulkan)]
Vulkan(BackendTensor<Vulkan>),
/// The [WebGPU backend](WebGpu) tensor.
#[cfg(wgpu_webgpu)]
WebGpu(BackendTensor<WebGpu>),
/// The [NdArray backend](NdArray) tensor.
#[cfg(feature = "ndarray")]
NdArray(BackendTensor<NdArray>),
/// The [LibTorch backend](LibTorch) tensor.
#[cfg(feature = "tch")]
LibTorch(BackendTensor<LibTorch>),
/// The [autodiff enabled backend](Autodiff) tensor.
#[cfg(feature = "autodiff")]
Autodiff(Box<DispatchTensor>),
}
impl TensorMetadata for DispatchTensor {
fn dtype(&self) -> burn_std::DType {
match self {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => tensor.dtype(),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => tensor.dtype(),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => tensor.dtype(),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => tensor.dtype(),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => tensor.dtype(),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => tensor.dtype(),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => tensor.dtype(),
#[cfg(feature = "tch")]
DispatchTensor::LibTorch(tensor) => tensor.dtype(),
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => tensor.dtype(),
}
}
fn shape(&self) -> burn_std::Shape {
match self {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => tensor.shape(),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => tensor.shape(),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => tensor.shape(),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => tensor.shape(),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => tensor.shape(),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => tensor.shape(),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => tensor.shape(),
#[cfg(feature = "tch")]
DispatchTensor::LibTorch(tensor) => tensor.shape(),
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => tensor.shape(),
}
}
}
impl QTensorPrimitive for DispatchTensor {
fn scheme(&self) -> &burn_std::QuantScheme {
match self {
#[cfg(feature = "cpu")]
DispatchTensor::Cpu(tensor) => tensor.scheme(),
#[cfg(feature = "cuda")]
DispatchTensor::Cuda(tensor) => tensor.scheme(),
#[cfg(wgpu_metal)]
DispatchTensor::Metal(tensor) => tensor.scheme(),
#[cfg(feature = "rocm")]
DispatchTensor::Rocm(tensor) => tensor.scheme(),
#[cfg(wgpu_vulkan)]
DispatchTensor::Vulkan(tensor) => tensor.scheme(),
#[cfg(wgpu_webgpu)]
DispatchTensor::WebGpu(tensor) => tensor.scheme(),
#[cfg(feature = "ndarray")]
DispatchTensor::NdArray(tensor) => tensor.scheme(),
#[cfg(feature = "tch")]
DispatchTensor::LibTorch(tensor) => tensor.scheme(),
#[cfg(feature = "autodiff")]
DispatchTensor::Autodiff(tensor) => tensor.scheme(),
}
}
}