feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
[package]
|
||||
authors = [
|
||||
"laggui <lagrange.guillaume.1@gmail.com>",
|
||||
"nathanielsimard <nathaniel.simard.42@gmail.com>",
|
||||
]
|
||||
categories = ["science"]
|
||||
description = "Backend dispatch for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-dispatch"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dispatch"
|
||||
documentation = "https://docs.rs/burn-dispatch"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"std",
|
||||
"ndarray",
|
||||
"burn-autodiff?/default",
|
||||
"burn-cpu?/default",
|
||||
"burn-cuda?/default",
|
||||
"burn-ndarray?/default",
|
||||
"burn-rocm?/default",
|
||||
"burn-tch?/default",
|
||||
"burn-wgpu?/default",
|
||||
]
|
||||
doc = ["default"]
|
||||
std = [
|
||||
"burn-backend/std",
|
||||
"burn-std/std",
|
||||
"burn-autodiff?/std",
|
||||
"burn-cpu?/std",
|
||||
"burn-cuda?/std",
|
||||
"burn-ndarray?/std",
|
||||
"burn-rocm?/std",
|
||||
"burn-tch?/std",
|
||||
"burn-wgpu?/std",
|
||||
]
|
||||
tracing = [
|
||||
"burn-autodiff?/tracing",
|
||||
"burn-cpu?/tracing",
|
||||
"burn-cuda?/tracing",
|
||||
"burn-ndarray?/tracing",
|
||||
"burn-rocm?/tracing",
|
||||
"burn-tch?/tracing",
|
||||
"burn-wgpu?/tracing",
|
||||
]
|
||||
|
||||
# Backends
|
||||
cuda = ["burn-cuda"]
|
||||
rocm = ["burn-rocm"]
|
||||
ndarray = ["burn-ndarray"]
|
||||
tch = ["burn-tch"]
|
||||
vulkan = ["wgpu", "burn-wgpu/vulkan"]
|
||||
webgpu = ["wgpu", "burn-wgpu/webgpu"]
|
||||
metal = ["wgpu", "burn-wgpu/metal"]
|
||||
wgpu = ["burn-wgpu"]
|
||||
cpu = ["burn-cpu"]
|
||||
autodiff = ["burn-autodiff"]
|
||||
|
||||
[dependencies]
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
# Backends
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
|
||||
|
||||
# Op macros with `.as_$inner_kind()`
|
||||
paste = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Burn Backend Dispatch
|
||||
|
||||
A multi-backend dispatch that forwards the tensor operations to the appropriate backend.
|
||||
@@ -0,0 +1,35 @@
|
||||
fn main() {
|
||||
println!("cargo::rustc-check-cfg=cfg(wgpu_metal)");
|
||||
println!("cargo::rustc-check-cfg=cfg(wgpu_vulkan)");
|
||||
println!("cargo::rustc-check-cfg=cfg(wgpu_webgpu)");
|
||||
|
||||
// Detect which single wgpu backend is enabled
|
||||
let metal = cfg!(feature = "metal");
|
||||
let vulkan = cfg!(feature = "vulkan");
|
||||
let webgpu = cfg!(feature = "webgpu");
|
||||
let enabled = [(metal, "metal"), (vulkan, "vulkan"), (webgpu, "webgpu")]
|
||||
.iter()
|
||||
.filter(|x| x.0)
|
||||
.map(|x| x.1)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// WGPU features are mutually exclusive, but we don't want to workspace to throw a compile error.
|
||||
// In workspace builds with multiple features, we emit a warning and disable all WGPU backends.
|
||||
if enabled.len() > 1 {
|
||||
println!(
|
||||
"cargo:warning=Only one WGPU backend can be enabled at once. Detected: [{}]. No WGPU backend will be available in this build. This is expected in workspace builds. For production, enable only one of: metal, vulkan, or webgpu.",
|
||||
enabled.join(", ")
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if metal {
|
||||
println!("cargo:rustc-cfg=wgpu_metal");
|
||||
}
|
||||
if vulkan {
|
||||
println!("cargo:rustc-cfg=wgpu_vulkan");
|
||||
}
|
||||
if webgpu {
|
||||
println!("cargo:rustc-cfg=wgpu_webgpu");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,392 @@
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
|
||||
use burn_backend::Backend;
|
||||
use burn_backend::ExecutionError;
|
||||
use burn_std::DType;
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
use burn_autodiff::grads::Gradients;
|
||||
#[cfg(feature = "autodiff")]
|
||||
use burn_backend::AutodiffBackend;
|
||||
|
||||
use crate::backends::*;
|
||||
use crate::{DispatchDevice, DispatchTensor};
|
||||
|
||||
/// The main execution backend in Burn.
|
||||
///
|
||||
/// [`Dispatch`] acts as a global backend that can manage multiple underlying
|
||||
/// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.).
|
||||
/// It is responsible for:
|
||||
/// - Dispatching tensor operations to the appropriate backend.
|
||||
/// - Managing cross-backend tensor transfers.
|
||||
///
|
||||
/// Essentially, [`Dispatch`] is the single entry point for executing tensor operations
|
||||
/// in a backend-agnostic way. It allows Burn to provide a unified, global backend
|
||||
/// for users while still leveraging multiple specialized backends under the hood.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn::Dispatch;
|
||||
/// use burn::DispatchDevice;
|
||||
///
|
||||
/// // Select the device to execute operations on
|
||||
/// let device = DispatchDevice::Cuda(Default::default());
|
||||
///
|
||||
/// // Create a tensor using the global backend
|
||||
/// let t = Tensor::<Dispatch, 2>::zeros([128, 128], &device);
|
||||
/// ```
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct Dispatch;
|
||||
|
||||
impl Backend for Dispatch {
|
||||
type Device = DispatchDevice;
|
||||
|
||||
type FloatTensorPrimitive = DispatchTensor;
|
||||
|
||||
// TODO: either allow default dtype generic or remove associated types entirely?
|
||||
type FloatElem = f32;
|
||||
|
||||
type IntTensorPrimitive = DispatchTensor;
|
||||
|
||||
type IntElem = i32;
|
||||
|
||||
type BoolTensorPrimitive = DispatchTensor;
|
||||
|
||||
type BoolElem = u8;
|
||||
|
||||
type QuantizedTensorPrimitive = DispatchTensor;
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
let inner = dispatch_device!(device, |device| B::name(device));
|
||||
format!("dispatch<{inner}>")
|
||||
}
|
||||
|
||||
fn seed(device: &Self::Device, seed: u64) {
|
||||
dispatch_device!(device, |device| B::seed(device, seed))
|
||||
}
|
||||
|
||||
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
|
||||
dispatch_device!(device, |device| B::sync(device))
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
|
||||
dispatch_device!(device, |device| B::dtype_usage(device, dtype))
|
||||
}
|
||||
|
||||
fn ad_enabled(device: &Self::Device) -> bool {
|
||||
match device {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchDevice::Autodiff(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
impl AutodiffBackend for Dispatch {
|
||||
type InnerBackend = Dispatch;
|
||||
|
||||
type Gradients = Gradients;
|
||||
|
||||
fn backward(tensor: DispatchTensor) -> Self::Gradients {
|
||||
match tensor {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => match *tensor {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => tensor.autodiff().backward(),
|
||||
DispatchTensor::Autodiff(_) => {
|
||||
panic!("Autodiff should not wrap an autodiff tensor.")
|
||||
}
|
||||
},
|
||||
_ => panic!("Requires autodiff tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
|
||||
match &tensor {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => match &**tensor {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
|
||||
DispatchTensor::Autodiff(_) => {
|
||||
panic!("Autodiff should not wrap an autodiff tensor.")
|
||||
}
|
||||
},
|
||||
_ => panic!("Requires autodiff tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
|
||||
match &tensor {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => match &**tensor {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
|
||||
DispatchTensor::Autodiff(_) => {
|
||||
panic!("Autodiff should not wrap an autodiff tensor.")
|
||||
}
|
||||
},
|
||||
_ => panic!("Requires autodiff tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
|
||||
match &tensor {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => match (&**tensor, grad) {
|
||||
#[cfg(feature = "cpu")]
|
||||
(DispatchTensor::Cpu(tensor), DispatchTensor::Cpu(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
(DispatchTensor::Cuda(tensor), DispatchTensor::Cuda(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(wgpu_metal)]
|
||||
(DispatchTensor::Metal(tensor), DispatchTensor::Metal(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
(DispatchTensor::Rocm(tensor), DispatchTensor::Rocm(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(wgpu_vulkan)]
|
||||
(DispatchTensor::Vulkan(tensor), DispatchTensor::Vulkan(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(wgpu_webgpu)]
|
||||
(DispatchTensor::WebGpu(tensor), DispatchTensor::WebGpu(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "ndarray")]
|
||||
(DispatchTensor::NdArray(tensor), DispatchTensor::NdArray(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
(DispatchTensor::Autodiff(_), _) => {
|
||||
panic!("Autodiff should not wrap an autodiff tensor.")
|
||||
}
|
||||
(t, g) => panic!(
|
||||
"The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
|
||||
),
|
||||
},
|
||||
_ => panic!("Requires autodiff tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
fn inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
match tensor {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => match *tensor {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => {
|
||||
DispatchTensor::Cpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => {
|
||||
DispatchTensor::Cuda(crate::BackendTensor::Float(tensor.autodiff().primitive))
|
||||
}
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => {
|
||||
DispatchTensor::Metal(crate::BackendTensor::Float(tensor.autodiff().primitive))
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => {
|
||||
DispatchTensor::Rocm(crate::BackendTensor::Float(tensor.autodiff().primitive))
|
||||
}
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => {
|
||||
DispatchTensor::Vulkan(crate::BackendTensor::Float(tensor.autodiff().primitive))
|
||||
}
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => {
|
||||
DispatchTensor::WebGpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
|
||||
}
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => DispatchTensor::NdArray(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
DispatchTensor::Autodiff(_) => {
|
||||
panic!("Autodiff should not wrap an autodiff tensor.")
|
||||
}
|
||||
},
|
||||
_ => panic!("Requires autodiff tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
match tensor {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => DispatchTensor::Autodiff(Box::new(DispatchTensor::Cpu(
|
||||
crate::BackendTensor::Autodiff(Autodiff::<Cpu<f32>>::from_inner(tensor.float())),
|
||||
))),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => DispatchTensor::Autodiff(Box::new(
|
||||
DispatchTensor::Cuda(crate::BackendTensor::Autodiff(
|
||||
Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
|
||||
)),
|
||||
)),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => DispatchTensor::Autodiff(Box::new(
|
||||
DispatchTensor::Metal(crate::BackendTensor::Autodiff(
|
||||
Autodiff::<Metal<f32>>::from_inner(tensor.float()),
|
||||
)),
|
||||
)),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => DispatchTensor::Autodiff(Box::new(
|
||||
DispatchTensor::Rocm(crate::BackendTensor::Autodiff(
|
||||
Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
|
||||
)),
|
||||
)),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => DispatchTensor::Autodiff(Box::new(
|
||||
DispatchTensor::Vulkan(crate::BackendTensor::Autodiff(
|
||||
Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
|
||||
)),
|
||||
)),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => DispatchTensor::Autodiff(Box::new(
|
||||
DispatchTensor::WebGpu(crate::BackendTensor::Autodiff(
|
||||
Autodiff::<WebGpu<f32>>::from_inner(tensor.float()),
|
||||
)),
|
||||
)),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => DispatchTensor::Autodiff(Box::new(
|
||||
DispatchTensor::NdArray(crate::BackendTensor::Autodiff(
|
||||
Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
|
||||
)),
|
||||
)),
|
||||
DispatchTensor::Autodiff(_) => {
|
||||
panic!("Autodiff should not wrap an autodiff tensor.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl DispatchTensor {
|
||||
pub(crate) fn device(&self) -> DispatchDevice {
|
||||
match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
|
||||
#[cfg(feature = "tch")]
|
||||
DispatchTensor::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,415 @@
|
||||
use burn_backend::{DeviceId, DeviceOps};
|
||||
|
||||
use crate::backends::*;
|
||||
|
||||
/// Represents a device for the [`Dispatch`](crate::Dispatch).
|
||||
///
|
||||
/// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn::DispatchDevice;
|
||||
///
|
||||
/// #[cfg(feature = "cpu")]
|
||||
/// let cpu_device = DispatchDevice::Cpu(Default::default());
|
||||
///
|
||||
/// #[cfg(feature = "cuda")]
|
||||
/// let cuda_device = DispatchDevice::Cuda(Default::default());
|
||||
/// ```
|
||||
#[derive(Clone, Eq)]
|
||||
pub enum DispatchDevice {
|
||||
/// The [CPU backend](Cpu) device.
|
||||
#[cfg(feature = "cpu")]
|
||||
Cpu(CpuDevice),
|
||||
|
||||
/// The [CUDA backend](Cuda) device.
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(CudaDevice),
|
||||
|
||||
/// The [Metal backend](Metal) device (via WGPU runtime).
|
||||
#[cfg(wgpu_metal)]
|
||||
Metal(WgpuDevice),
|
||||
|
||||
/// The [ROCm backend](Rocm) device.
|
||||
#[cfg(feature = "rocm")]
|
||||
Rocm(RocmDevice),
|
||||
|
||||
/// The [Vulkan backend](Vulkan) device.
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Vulkan(WgpuDevice),
|
||||
|
||||
/// The [WebGPU backend](WebGpu) device (via WGPU runtime).
|
||||
#[cfg(wgpu_webgpu)]
|
||||
WebGpu(WgpuDevice),
|
||||
|
||||
/// The [NdArray backend](NdArray) device (CPU-only).
|
||||
#[cfg(feature = "ndarray")]
|
||||
NdArray(NdArrayDevice),
|
||||
|
||||
/// The [LibTorch backend](LibTorch) device.
|
||||
#[cfg(feature = "tch")]
|
||||
LibTorch(LibTorchDevice),
|
||||
|
||||
/// The [autodiff enabled backend](Autodiff) device.
|
||||
#[cfg(feature = "autodiff")]
|
||||
Autodiff(AutodiffDevice),
|
||||
}
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
// This tuple struct mainly restricts users from creating Autodiff(Autodiff) devices.
|
||||
/// A wrapper that enables automatic differentiation for a [`DispatchDevice`].
|
||||
///
|
||||
/// Use [`DispatchDevice::autodiff`] to construct this type.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AutodiffDevice(pub(crate) Box<DispatchDevice>);
|
||||
|
||||
// Useful for match in dispatch macros
|
||||
impl core::ops::Deref for AutodiffDevice {
|
||||
type Target = DispatchDevice;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for DispatchDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
|
||||
#[cfg(wgpu_metal)]
|
||||
Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(),
|
||||
#[cfg(feature = "ndarray")]
|
||||
Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
|
||||
#[cfg(feature = "tch")]
|
||||
Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
// Format without `AutodiffDevice` wrapper
|
||||
Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.0).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DispatchDevice {
|
||||
#[allow(unreachable_code)]
|
||||
fn default() -> Self {
|
||||
// TODO: which priority?
|
||||
|
||||
#[cfg(feature = "cpu")]
|
||||
return Self::Cpu(CpuDevice);
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
return Self::Cuda(CudaDevice::default());
|
||||
|
||||
#[cfg(wgpu_metal)]
|
||||
return Self::Metal(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(feature = "rocm")]
|
||||
return Self::Rocm(RocmDevice::default());
|
||||
|
||||
#[cfg(wgpu_vulkan)]
|
||||
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(wgpu_webgpu)]
|
||||
return Self::WebGpu(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
return Self::NdArray(NdArrayDevice::default());
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
return Self::LibTorch(LibTorchDevice::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for DispatchDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
// If both are Autodiff, compare the inner devices
|
||||
#[cfg(feature = "autodiff")]
|
||||
(DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
|
||||
// If one is Autodiff, compare it to the raw device
|
||||
#[cfg(feature = "autodiff")]
|
||||
(DispatchDevice::Autodiff(a), b) => a.0.as_ref() == b,
|
||||
#[cfg(feature = "autodiff")]
|
||||
(a, DispatchDevice::Autodiff(b)) => a == b.0.as_ref(),
|
||||
#[cfg(feature = "cpu")]
|
||||
(Self::Cpu(a), Self::Cpu(b)) => a == b,
|
||||
#[cfg(feature = "cuda")]
|
||||
(Self::Cuda(a), Self::Cuda(b)) => a == b,
|
||||
#[cfg(wgpu_metal)]
|
||||
(Self::Metal(a), Self::Metal(b)) => a == b,
|
||||
#[cfg(feature = "rocm")]
|
||||
(Self::Rocm(a), Self::Rocm(b)) => a == b,
|
||||
#[cfg(wgpu_vulkan)]
|
||||
(Self::Vulkan(a), Self::Vulkan(b)) => a == b,
|
||||
#[cfg(wgpu_webgpu)]
|
||||
(Self::WebGpu(a), Self::WebGpu(b)) => a == b,
|
||||
#[cfg(feature = "ndarray")]
|
||||
(Self::NdArray(a), Self::NdArray(b)) => a == b,
|
||||
#[cfg(feature = "tch")]
|
||||
(Self::LibTorch(a), Self::LibTorch(b)) => a == b,
|
||||
#[allow(unreachable_patterns)]
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Base multiplier to avoid type_id clashes between backends.
|
||||
/// Limits the number of device types per backend, but this is a sensible limit.
|
||||
const TYPE_ID_BASE: u16 = 10;
|
||||
|
||||
impl DispatchDevice {
|
||||
#[cfg(feature = "autodiff")]
|
||||
/// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
|
||||
pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
|
||||
let device = device.into();
|
||||
DispatchDevice::Autodiff(AutodiffDevice(Box::new(device)))
|
||||
}
|
||||
|
||||
/// Returns a unique number per variant to encode into type_id.
|
||||
fn backend_id(&self) -> BackendId {
|
||||
match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
Self::Cpu(_) => BackendId::Cpu,
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(_) => BackendId::Cuda,
|
||||
#[cfg(wgpu_metal)]
|
||||
Self::Metal(_) => BackendId::Metal,
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(_) => BackendId::Rocm,
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Self::Vulkan(_) => BackendId::Vulkan,
|
||||
#[cfg(wgpu_webgpu)]
|
||||
Self::WebGpu(_) => BackendId::WebGpu,
|
||||
#[cfg(feature = "ndarray")]
|
||||
Self::NdArray(_) => BackendId::NdArray,
|
||||
#[cfg(feature = "tch")]
|
||||
Self::LibTorch(_) => BackendId::LibTorch,
|
||||
#[cfg(feature = "autodiff")]
|
||||
Self::Autodiff(device) => device.0.backend_id(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode variant ID and backend type ID into a unique `type_id`.
|
||||
fn encode_type_id(&self, backend_type_id: u16) -> u16 {
|
||||
u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
|
||||
}
|
||||
|
||||
/// Decode an encoded `type_id` into variant ID and backend type ID.
|
||||
fn decode_type_id(type_id: u16) -> (BackendId, u16) {
|
||||
let variant = type_id / TYPE_ID_BASE;
|
||||
let backend_type_id = type_id % TYPE_ID_BASE;
|
||||
(
|
||||
BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
|
||||
backend_type_id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u16)]
|
||||
enum BackendId {
|
||||
#[cfg(feature = "cpu")]
|
||||
Cpu = 0,
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda = 1,
|
||||
#[cfg(wgpu_metal)]
|
||||
Metal = 2,
|
||||
#[cfg(feature = "rocm")]
|
||||
Rocm = 3,
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Vulkan = 4,
|
||||
#[cfg(wgpu_webgpu)]
|
||||
WebGpu = 5,
|
||||
#[cfg(feature = "ndarray")]
|
||||
NdArray = 6,
|
||||
#[cfg(feature = "tch")]
|
||||
LibTorch = 7,
|
||||
}
|
||||
|
||||
impl From<BackendId> for u16 {
|
||||
fn from(variant: BackendId) -> Self {
|
||||
variant as u16
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u16> for BackendId {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: u16) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
#[cfg(feature = "cpu")]
|
||||
0 => Ok(Self::Cpu),
|
||||
#[cfg(feature = "cuda")]
|
||||
1 => Ok(Self::Cuda),
|
||||
#[cfg(wgpu_metal)]
|
||||
2 => Ok(Self::Metal),
|
||||
#[cfg(feature = "rocm")]
|
||||
3 => Ok(Self::Rocm),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
4 => Ok(Self::Vulkan),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
5 => Ok(Self::WebGpu),
|
||||
#[cfg(feature = "ndarray")]
|
||||
6 => Ok(Self::NdArray),
|
||||
#[cfg(feature = "tch")]
|
||||
7 => Ok(Self::LibTorch),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for DispatchDevice {
|
||||
fn inner(&self) -> &Self {
|
||||
match self {
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchDevice::Autodiff(device) => &device.0,
|
||||
device => device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl burn_std::device::Device for DispatchDevice {
|
||||
fn from_id(mut device_id: DeviceId) -> Self {
|
||||
let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
|
||||
device_id.type_id = backend_type_id;
|
||||
|
||||
match dispatch_id {
|
||||
#[cfg(feature = "cpu")]
|
||||
BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "cuda")]
|
||||
BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
|
||||
#[cfg(wgpu_metal)]
|
||||
BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "rocm")]
|
||||
BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
BackendId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "ndarray")]
|
||||
BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
|
||||
#[cfg(feature = "tch")]
|
||||
BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_id(&self) -> DeviceId {
|
||||
let mut device_id = match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
Self::Cpu(device) => device.to_id(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(device) => device.to_id(),
|
||||
#[cfg(wgpu_metal)]
|
||||
Self::Metal(device) => device.to_id(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(device) => device.to_id(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Self::Vulkan(device) => device.to_id(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
Self::WebGpu(device) => device.to_id(),
|
||||
#[cfg(feature = "ndarray")]
|
||||
Self::NdArray(device) => device.to_id(),
|
||||
#[cfg(feature = "tch")]
|
||||
Self::LibTorch(device) => device.to_id(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
Self::Autodiff(device) => device.0.to_id(),
|
||||
};
|
||||
device_id.type_id = self.encode_type_id(device_id.type_id);
|
||||
device_id
|
||||
}
|
||||
|
||||
fn device_count(type_id: u16) -> usize {
|
||||
let (dispatch_id, backend_type_id) = Self::decode_type_id(type_id);
|
||||
match dispatch_id {
|
||||
#[cfg(feature = "cpu")]
|
||||
BackendId::Cpu => CpuDevice::device_count(backend_type_id),
|
||||
#[cfg(feature = "cuda")]
|
||||
BackendId::Cuda => CudaDevice::device_count(backend_type_id),
|
||||
#[cfg(wgpu_metal)]
|
||||
BackendId::Metal => WgpuDevice::device_count(backend_type_id),
|
||||
#[cfg(feature = "rocm")]
|
||||
BackendId::Rocm => RocmDevice::device_count(backend_type_id),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
BackendId::Vulkan => WgpuDevice::device_count(backend_type_id),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
BackendId::WebGpu => WgpuDevice::device_count(backend_type_id),
|
||||
#[cfg(feature = "ndarray")]
|
||||
BackendId::NdArray => NdArrayDevice::device_count(backend_type_id),
|
||||
#[cfg(feature = "tch")]
|
||||
BackendId::LibTorch => LibTorchDevice::device_count(backend_type_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cpu")]
|
||||
impl From<CpuDevice> for DispatchDevice {
|
||||
fn from(device: CpuDevice) -> Self {
|
||||
DispatchDevice::Cpu(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
impl From<CudaDevice> for DispatchDevice {
|
||||
fn from(device: CudaDevice) -> Self {
|
||||
DispatchDevice::Cuda(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(wgpu_metal)]
|
||||
impl From<WgpuDevice> for DispatchDevice {
|
||||
fn from(device: WgpuDevice) -> Self {
|
||||
DispatchDevice::Metal(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rocm")]
|
||||
impl From<RocmDevice> for DispatchDevice {
|
||||
fn from(device: RocmDevice) -> Self {
|
||||
DispatchDevice::Rocm(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(wgpu_vulkan)]
|
||||
impl From<WgpuDevice> for DispatchDevice {
|
||||
fn from(device: WgpuDevice) -> Self {
|
||||
DispatchDevice::Vulkan(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(wgpu_webgpu)]
|
||||
impl From<WgpuDevice> for DispatchDevice {
|
||||
fn from(device: WgpuDevice) -> Self {
|
||||
DispatchDevice::WebGpu(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
impl From<NdArrayDevice> for DispatchDevice {
|
||||
fn from(device: NdArrayDevice) -> Self {
|
||||
DispatchDevice::NdArray(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
impl From<LibTorchDevice> for DispatchDevice {
|
||||
fn from(device: LibTorchDevice) -> Self {
|
||||
DispatchDevice::LibTorch(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
impl From<LibTorchDevice> for DispatchDevice {
|
||||
fn from(device: LibTorchDevice) -> Self {
|
||||
DispatchDevice::LibTorch(device)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![recursion_limit = "138"]
|
||||
|
||||
//! Burn multi-backend dispatch.
|
||||
//!
|
||||
//! # Available Backends
|
||||
//!
|
||||
//! The dispatch backend supports the following variants, each enabled via cargo features:
|
||||
//!
|
||||
//! | Backend | Feature | Description |
|
||||
//! |------------|------------|-------------|
|
||||
//! | `Cpu` | `cpu` | Rust CPU backend (MLIR + LLVM) |
|
||||
//! | `Cuda` | `cuda` | NVIDIA CUDA backend |
|
||||
//! | `Metal` | `metal` | Apple Metal backend via `wgpu` (MSL) |
|
||||
//! | `Rocm` | `rocm` | AMD ROCm backend |
|
||||
//! | `Vulkan` | `vulkan` | Vulkan backend via `wgpu` (SPIR-V) |
|
||||
//! | `WebGpu` | `webgpu` | WebGPU backend via `wgpu` (WGSL) |
|
||||
//! | `NdArray` | `ndarray` | Pure Rust CPU backend using `ndarray` |
|
||||
//! | `LibTorch` | `tch` | Libtorch backend via `tch` |
|
||||
//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
|
||||
//!
|
||||
//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
|
||||
//! All other backends can be combined freely.
|
||||
//!
|
||||
//! ## WGPU Backend Exclusivity
|
||||
//!
|
||||
//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to
|
||||
//! the current automatic compile, which can only select one target at a time.
|
||||
//!
|
||||
//! Enable only **one** of these features in your `Cargo.toml`:
|
||||
//! - `metal`
|
||||
//! - `vulkan`
|
||||
//! - `webgpu`
|
||||
//!
|
||||
//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU
|
||||
//! backends** to prevent unintended behavior.
|
||||
|
||||
#[cfg(not(any(
|
||||
feature = "cpu",
|
||||
feature = "cuda",
|
||||
wgpu_metal,
|
||||
feature = "rocm",
|
||||
wgpu_vulkan,
|
||||
wgpu_webgpu,
|
||||
feature = "ndarray",
|
||||
feature = "tch",
|
||||
)))]
|
||||
compile_error!("At least one backend feature must be enabled.");
|
||||
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
|
||||
mod backend;
|
||||
mod device;
|
||||
mod ops;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use device::*;
|
||||
pub use tensor::*;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Backends and devices used.
|
||||
pub(crate) mod backends {
|
||||
#[cfg(feature = "autodiff")]
|
||||
pub use burn_autodiff::Autodiff;
|
||||
|
||||
#[cfg(feature = "cpu")]
|
||||
pub use burn_cpu::{Cpu, CpuDevice};
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use burn_cuda::{Cuda, CudaDevice};
|
||||
#[cfg(feature = "rocm")]
|
||||
pub use burn_rocm::{Rocm, RocmDevice};
|
||||
#[cfg(wgpu_metal)]
|
||||
pub use burn_wgpu::Metal;
|
||||
#[cfg(wgpu_vulkan)]
|
||||
pub use burn_wgpu::Vulkan;
|
||||
#[cfg(wgpu_webgpu)]
|
||||
pub use burn_wgpu::WebGpu;
|
||||
#[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]
|
||||
pub use burn_wgpu::WgpuDevice;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub use burn_ndarray::{NdArray, NdArrayDevice};
|
||||
#[cfg(feature = "tch")]
|
||||
pub use burn_tch::{LibTorch, LibTorchDevice};
|
||||
}
|
||||
1197
crates/stable-diffusion-burn/burn-crates/burn-dispatch/src/macros.rs
Normal file
1197
crates/stable-diffusion-burn/burn-crates/burn-dispatch/src/macros.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,50 @@
|
||||
use burn_backend::{Scalar, ops::ActivationOps, tensor::FloatTensor};
|
||||
|
||||
use crate::Dispatch;
|
||||
use crate::backends::*;
|
||||
|
||||
impl ActivationOps<Self> for Dispatch {
|
||||
fn leaky_relu(tensor: FloatTensor<Self>, negative_slope: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::leaky_relu(tensor, negative_slope) => Float)
|
||||
}
|
||||
|
||||
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::relu(tensor) => Float)
|
||||
}
|
||||
|
||||
fn relu_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((output, float), (grad, float), |output, grad| B::relu_backward(output, grad) => Float)
|
||||
}
|
||||
|
||||
fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::gelu(tensor) => Float)
|
||||
}
|
||||
|
||||
fn prelu(tensor: FloatTensor<Self>, alpha: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((tensor, float), (alpha, float), |tensor, alpha| B::prelu(tensor, alpha) => Float)
|
||||
}
|
||||
|
||||
fn gelu_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((x, float), (grad, float), |x, grad| B::gelu_backward(x, grad) => Float)
|
||||
}
|
||||
|
||||
fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::sigmoid(tensor) => Float)
|
||||
}
|
||||
|
||||
fn sigmoid_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((output, float), (grad, float), |output, grad| B::sigmoid_backward(output, grad) => Float)
|
||||
}
|
||||
|
||||
fn hard_sigmoid(tensor: FloatTensor<Self>, alpha: Scalar, beta: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::hard_sigmoid(tensor, alpha, beta) => Float)
|
||||
}
|
||||
|
||||
fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::log_sigmoid(tensor) => Float)
|
||||
}
|
||||
|
||||
fn log_sigmoid_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((x, float), (grad, float), |x, grad| B::log_sigmoid_backward(x, grad) => Float)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
use burn_backend::{
|
||||
ExecutionError, Scalar, TensorData,
|
||||
ops::BoolTensorOps,
|
||||
tensor::{BoolTensor, FloatTensor, IntTensor},
|
||||
};
|
||||
use burn_std::{Shape, Slice};
|
||||
|
||||
use crate::backends::*;
|
||||
use crate::{Dispatch, DispatchDevice};
|
||||
|
||||
impl BoolTensorOps<Self> for Dispatch {
|
||||
fn bool_empty(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
|
||||
creation_op!(Bool, device, |device| B::bool_empty(shape, device))
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
|
||||
creation_op!(Bool, device, |device| B::bool_zeros(shape, device))
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {
|
||||
creation_op!(Bool, device, |device| B::bool_ones(shape, device))
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_into_data(tensor).await)
|
||||
}
|
||||
|
||||
fn bool_from_data(data: TensorData, device: &DispatchDevice) -> BoolTensor<Self> {
|
||||
creation_op!(Bool, device, |device| B::bool_from_data(data, device))
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_into_int(tensor) => Int)
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_into_float(tensor) => Float)
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &BoolTensor<Self>) -> DispatchDevice {
|
||||
tensor.device()
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<Self>, device: &DispatchDevice) -> BoolTensor<Self> {
|
||||
to_device!(
|
||||
Bool,
|
||||
bool,
|
||||
tensor,
|
||||
device,
|
||||
bool_to_device,
|
||||
|inner, device| {
|
||||
let data =
|
||||
burn_backend::read_sync(B1::bool_into_data(inner)).expect("Should read data");
|
||||
B2::bool_from_data(data, device)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_reshape(tensor, shape) => Bool)
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_slice(tensor, slices) => Bool)
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
binary_op!((tensor, bool), (value, bool), |tensor, value| B::bool_slice_assign(tensor, slices, value) => Bool)
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, bool), (mask, bool), (value, bool)], => Bool,
|
||||
B::bool_mask_where(tensor, mask, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
binary_op!((tensor, bool), (mask, bool), |tensor, mask| B::bool_mask_fill(tensor, mask, value) => Bool)
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_gather(dim, tensor, indices) => Bool)
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
|
||||
B::bool_scatter_or(dim, tensor, indices, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, bool, |lhs| B::bool_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_not(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_and(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_or(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_swap_dims(tensor, dim1, dim2) => Bool)
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_permute(tensor, axes) => Bool)
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_flip(tensor, axes) => Bool)
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_expand(tensor, shape) => Bool)
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_unfold(tensor, dim, size, step) => Bool)
|
||||
}
|
||||
|
||||
fn bool_select(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_select(tensor, dim, indices) => Bool)
|
||||
}
|
||||
|
||||
fn bool_select_or(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,
|
||||
B::bool_select_or(tensor, dim, indices, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_repeat_dim(tensor, dim, times) => Bool)
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
|
||||
vec_op!(tensors, bool, |tensors| B::bool_cat(tensors, dim) => Bool)
|
||||
}
|
||||
|
||||
fn bool_not_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_not_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_not_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, bool, |lhs| B::bool_not_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_xor(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_xor(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn bool_transpose(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_transpose(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_any(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn bool_any_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_any_dim(tensor, dim) => Bool)
|
||||
}
|
||||
|
||||
fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_all(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn bool_all_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_all_dim(tensor, dim) => Bool)
|
||||
}
|
||||
|
||||
async fn bool_argwhere(tensor: BoolTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, bool, |tensor| B::bool_argwhere(tensor).await => Int)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,503 @@
|
||||
use burn_backend::{
|
||||
ExecutionError, Scalar, TensorData,
|
||||
ops::IntTensorOps,
|
||||
tensor::{BoolTensor, FloatTensor, IntTensor},
|
||||
};
|
||||
use burn_std::{IntDType, Shape, Slice};
|
||||
|
||||
use crate::backends::*;
|
||||
use crate::{Dispatch, DispatchDevice};
|
||||
|
||||
impl IntTensorOps<Self> for Dispatch {
|
||||
fn int_empty(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_empty(shape, device, dtype))
|
||||
}
|
||||
|
||||
async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
unary_op!(tensor, int, |tensor| B::int_into_data(tensor).await)
|
||||
}
|
||||
|
||||
fn int_from_data(data: TensorData, device: &DispatchDevice) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_from_data(data, device))
|
||||
}
|
||||
|
||||
fn int_device(tensor: &IntTensor<Self>) -> DispatchDevice {
|
||||
tensor.device()
|
||||
}
|
||||
|
||||
fn int_to_device(tensor: IntTensor<Self>, device: &DispatchDevice) -> IntTensor<Self> {
|
||||
to_device!(Int, int, tensor, device, int_to_device, |inner, device| {
|
||||
let data = burn_backend::read_sync(B1::int_into_data(inner)).expect("Should read data");
|
||||
B2::int_from_data(data, device)
|
||||
})
|
||||
}
|
||||
|
||||
fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_reshape(tensor, shape) => Int)
|
||||
}
|
||||
|
||||
fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_slice(tensor, slices) => Int)
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: IntTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
binary_op!((tensor, int), (value, int), |tensor, value| B::int_slice_assign(tensor, slices, value) => Int)
|
||||
}
|
||||
|
||||
fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_into_float(tensor) => Float)
|
||||
}
|
||||
|
||||
fn int_mask_where(
|
||||
tensor: IntTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, int), (mask, bool), (value, int)], => Int,
|
||||
B::int_mask_where(tensor, mask, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn int_mask_fill(
|
||||
tensor: IntTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> IntTensor<Self> {
|
||||
binary_op!((tensor, int), (mask, bool), |tensor, mask| B::int_mask_fill(tensor, mask, value) => Int)
|
||||
}
|
||||
|
||||
fn int_gather(
|
||||
dim: usize,
|
||||
tensor: IntTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_gather(dim, tensor, indices) => Int)
|
||||
}
|
||||
|
||||
fn int_scatter_add(
|
||||
dim: usize,
|
||||
tensor: IntTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, int), (indices, int), (value, int)], => Int,
|
||||
B::int_scatter_add(dim, tensor, indices, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn int_select(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_select(tensor, dim, indices) => Int)
|
||||
}
|
||||
|
||||
fn int_select_add(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, int), (indices, int), (value, int)], => Int,
|
||||
B::int_select_add(tensor, dim, indices, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_greater_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_greater_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_lower_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_lower_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_add(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_add_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_sub(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_sub_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_mul(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_mul_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_div(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_div_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_remainder(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_remainder_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_matmul(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_sum(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_sum_dim(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_prod(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_prod_dim(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_mean_dim(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_cumsum(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_cumprod(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_cummin(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_cummax(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_argmax(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_argmin(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_abs(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_swap_dims(tensor, dim1, dim2) => Int)
|
||||
}
|
||||
|
||||
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_permute(tensor, axes) => Int)
|
||||
}
|
||||
|
||||
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_flip(tensor, axes) => Int)
|
||||
}
|
||||
|
||||
fn int_random(
|
||||
shape: Shape,
|
||||
distribution: burn_backend::Distribution,
|
||||
device: &DispatchDevice,
|
||||
) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| {
|
||||
B::int_random(shape, distribution, device)
|
||||
})
|
||||
}
|
||||
|
||||
fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_expand(tensor, shape) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_and(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::bitwise_and_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_or(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::bitwise_or_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_xor(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::bitwise_xor_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::bitwise_not(tensor) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_left_shift(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::bitwise_left_shift_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_right_shift(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::bitwise_right_shift_scalar(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_cast(tensor, dtype) => Int)
|
||||
}
|
||||
|
||||
fn int_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_unfold(tensor, dim, size, step) => Int)
|
||||
}
|
||||
|
||||
fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_repeat_dim(tensor, dim, times) => Int)
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
|
||||
vec_op!(tensors, int, |tensors| B::int_cat(tensors, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_not_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_not_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_not_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_not_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powi(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_powf(lhs: IntTensor<Self>, rhs: FloatTensor<Self>) -> IntTensor<Self> {
|
||||
binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powf(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_powi_scalar_impl(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_powf_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(lhs, int, |lhs| B::int_powf_scalar_impl(lhs, rhs) => Int)
|
||||
}
|
||||
|
||||
fn int_clamp_min(tensor: IntTensor<Self>, min: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_clamp_min(tensor, min) => Int)
|
||||
}
|
||||
|
||||
fn int_clamp_max(tensor: IntTensor<Self>, max: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_clamp_max(tensor, max) => Int)
|
||||
}
|
||||
|
||||
fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_clamp(tensor, min, max) => Int)
|
||||
}
|
||||
|
||||
fn int_neg(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_neg(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_zeros(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_zeros(shape, device, dtype))
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_ones(shape, device, dtype))
|
||||
}
|
||||
|
||||
fn int_full(
|
||||
shape: Shape,
|
||||
fill_value: Scalar,
|
||||
device: &DispatchDevice,
|
||||
dtype: IntDType,
|
||||
) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_full(
|
||||
shape, fill_value, device, dtype
|
||||
))
|
||||
}
|
||||
|
||||
fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_mean(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_max(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_max_dim(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_max_dim_with_indices(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
) -> (IntTensor<Self>, IntTensor<Self>) {
|
||||
multi_op!(
|
||||
inputs[(tensor, int)],
|
||||
outputs[(out, Int), (indices, Int)],
|
||||
B::int_max_dim_with_indices(tensor, dim)
|
||||
)
|
||||
}
|
||||
|
||||
fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_max_abs(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_max_abs_dim(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_min(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_min_dim(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn int_min_dim_with_indices(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
) -> (IntTensor<Self>, IntTensor<Self>) {
|
||||
multi_op!(
|
||||
inputs[(tensor, int)],
|
||||
outputs[(out, Int), (indices, Int)],
|
||||
B::int_min_dim_with_indices(tensor, dim)
|
||||
)
|
||||
}
|
||||
|
||||
fn int_transpose(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_transpose(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_arange_step(
|
||||
range: std::ops::Range<i64>,
|
||||
step: usize,
|
||||
device: &DispatchDevice,
|
||||
) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_arange_step(
|
||||
range, step, device
|
||||
))
|
||||
}
|
||||
|
||||
fn int_arange(range: std::ops::Range<i64>, device: &DispatchDevice) -> IntTensor<Self> {
|
||||
creation_op!(Int, device, |device| B::int_arange(range, device))
|
||||
}
|
||||
|
||||
fn int_any(tensor: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_any(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn int_any_dim(tensor: IntTensor<Self>, dim: usize) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_any_dim(tensor, dim) => Bool)
|
||||
}
|
||||
|
||||
fn int_all(tensor: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_all(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn int_all_dim(tensor: IntTensor<Self>, dim: usize) -> BoolTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_all_dim(tensor, dim) => Bool)
|
||||
}
|
||||
|
||||
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_sign(tensor) => Int)
|
||||
}
|
||||
|
||||
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_sort(tensor, dim, descending) => Int)
|
||||
}
|
||||
|
||||
fn int_sort_with_indices(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (IntTensor<Self>, IntTensor<Self>) {
|
||||
multi_op!(
|
||||
inputs[(tensor, int)],
|
||||
outputs[(out, Int), (indices, Int)],
|
||||
B::int_sort_with_indices(tensor, dim, descending)
|
||||
)
|
||||
}
|
||||
|
||||
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
unary_op!(tensor, int, |tensor| B::int_argsort(tensor, dim, descending) => Int)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod activation;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
@@ -0,0 +1,628 @@
|
||||
use burn_backend::{
|
||||
ops::{
|
||||
DeformConv2dBackward, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
|
||||
MaxPool2dWithIndices, ModuleOps,
|
||||
},
|
||||
tensor::{FloatTensor, IntTensor},
|
||||
};
|
||||
|
||||
use crate::Dispatch;
|
||||
use crate::backends::*;
|
||||
|
||||
impl ModuleOps<Self> for Dispatch {
|
||||
fn conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float)],
|
||||
opt_inputs[(bias, float)],
|
||||
=> Float,
|
||||
B::conv2d(x, weight, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (offset, float), (weight, float)],
|
||||
opt_inputs[(mask, float), (bias, float)],
|
||||
=> Float,
|
||||
B::deform_conv2d(x, offset, weight, mask, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = multi_op!(
|
||||
inputs[(x, float), (offset, float), (weight, float), (output_grad, float)],
|
||||
opt_inputs[(mask, float), (bias, float)],
|
||||
outputs[(x_grad, Float), (offset_grad, Float), (weight_grad, Float)],
|
||||
opt_outputs[mask_grad, bias_grad],
|
||||
{
|
||||
let res = B::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options);
|
||||
(res.x_grad, res.offset_grad, res.weight_grad, res.mask_grad, res.bias_grad)
|
||||
}
|
||||
);
|
||||
DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float)],
|
||||
opt_inputs[(bias, float)],
|
||||
=> Float,
|
||||
B::conv3d(x, weight, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float)],
|
||||
opt_inputs[(bias, float)],
|
||||
=> Float,
|
||||
B::conv_transpose2d(x, weight, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float)],
|
||||
opt_inputs[(bias, float)],
|
||||
=> Float,
|
||||
B::conv_transpose3d(x, weight, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(inputs[(x, float)],
|
||||
=> Float,
|
||||
B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
|
||||
)
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (grad, float)],
|
||||
=> Float,
|
||||
B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
|
||||
)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float)],
|
||||
=> Float,
|
||||
B::adaptive_avg_pool2d(x, output_size)
|
||||
)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (grad, float)],
|
||||
=> Float,
|
||||
B::adaptive_avg_pool2d_backward(x, grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float)],
|
||||
=> Float,
|
||||
B::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool2dWithIndices<Self> {
|
||||
let (out, indices) = multi_op!(
|
||||
inputs[(x, float)],
|
||||
outputs[(out, Float), (indices, Int)],
|
||||
{
|
||||
let res = B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
|
||||
(res.output, res.indices)
|
||||
}
|
||||
);
|
||||
MaxPool2dWithIndices::new(out, indices)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool2dBackward<Self> {
|
||||
let x_grad = multi_op!(
|
||||
inputs[(x, float), (output_grad, float), (indices, int)],
|
||||
=> Float,
|
||||
{
|
||||
let res = B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
|
||||
res.x_grad
|
||||
}
|
||||
);
|
||||
MaxPool2dBackward::new(x_grad)
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
x: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: burn_backend::ops::InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float)],
|
||||
=> Float,
|
||||
B::interpolate(x, output_size, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn interpolate_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: burn_backend::ops::InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (grad, float)],
|
||||
=> Float,
|
||||
B::interpolate_backward(x, grad, output_size, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn embedding(weights: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(weights, float), (indices, int)],
|
||||
=> Float,
|
||||
B::embedding(weights, indices)
|
||||
)
|
||||
}
|
||||
|
||||
fn embedding_backward(
|
||||
weights: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(weights, float), (output_grad, float), (indices, int)],
|
||||
=> Float,
|
||||
B::embedding_backward(weights, output_grad, indices)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float)],
|
||||
opt_inputs[(bias, float)],
|
||||
=> Float,
|
||||
B::conv1d(x, weight, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv1d_x_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv1d_x_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv1d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv1d_weight_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv1d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (bias, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv1d_bias_backward(x, bias, output_grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv2d_x_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv2d_x_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv2d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv2d_weight_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv2d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (bias, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv2d_bias_backward(x, bias, output_grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv3d_x_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv3d_x_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv3d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv3d_weight_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv3d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (bias, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv3d_bias_backward(x, bias, output_grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float)],
|
||||
opt_inputs[(bias, float)],
|
||||
=> Float,
|
||||
B::conv_transpose1d(x, weight, bias, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose1d_x_backward(
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose1d_x_backward(weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose1d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose1d_weight_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose1d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (bias, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose1d_bias_backward(x, bias, output_grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose2d_x_backward(
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose2d_x_backward(weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose2d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose2d_weight_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose2d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (bias, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose2d_bias_backward(x, bias, output_grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose3d_x_backward(
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose3d_x_backward(weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose3d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: burn_backend::ops::ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (weight, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose3d_weight_backward(x, weight, output_grad, options)
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose3d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (bias, float), (output_grad, float)],
|
||||
=> Float,
|
||||
B::conv_transpose3d_bias_backward(x, bias, output_grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn unfold4d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
options: burn_backend::ops::UnfoldOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(inputs[(x, float)], => Float, B::unfold4d(x, kernel_size, options))
|
||||
}
|
||||
|
||||
fn avg_pool1d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(inputs[(x, float)], => Float,
|
||||
B::avg_pool1d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
|
||||
)
|
||||
}
|
||||
|
||||
fn avg_pool1d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (grad, float)],
|
||||
=> Float,
|
||||
B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
|
||||
)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
|
||||
multi_op!(inputs[(x, float)], => Float, B::adaptive_avg_pool1d(x, output_size))
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(x, float), (grad, float)],
|
||||
=> Float,
|
||||
B::adaptive_avg_pool1d_backward(x, grad)
|
||||
)
|
||||
}
|
||||
|
||||
fn max_pool1d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(inputs[(x, float)], => Float,
|
||||
B::max_pool1d(x, kernel_size, stride, padding, dilation, ceil_mode))
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool1dWithIndices<Self> {
|
||||
let (out, indices) = multi_op!(
|
||||
inputs[(x, float)],
|
||||
outputs[(out, Float), (indices, Int)],
|
||||
{
|
||||
let res = B::max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
|
||||
(res.output, res.indices)
|
||||
}
|
||||
);
|
||||
MaxPool1dWithIndices::new(out, indices)
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool1dBackward<Self> {
|
||||
let x_grad = multi_op!(
|
||||
inputs[(x, float), (output_grad, float), (indices, int)],
|
||||
=> Float,
|
||||
{
|
||||
let res = B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
|
||||
res.x_grad
|
||||
}
|
||||
);
|
||||
MaxPool1dBackward::new(x_grad)
|
||||
}
|
||||
|
||||
fn attention(
|
||||
query: FloatTensor<Self>,
|
||||
key: FloatTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
mask: Option<burn_backend::tensor::BoolTensor<Self>>,
|
||||
attn_bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::AttentionModuleOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(query, float), (key, float), (value, float)],
|
||||
opt_inputs[(mask, bool), (attn_bias, float)],
|
||||
=> Float,
|
||||
B::attention(query, key, value, mask, attn_bias, options)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
use burn_backend::{
|
||||
ExecutionError, QTensorPrimitive, TensorData, TensorPrimitive,
|
||||
ops::QTensorOps,
|
||||
quantization::QuantizationParametersPrimitive,
|
||||
tensor::{FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
use burn_std::{QuantPropagation, Shape, Slice};
|
||||
|
||||
use crate::backends::*;
|
||||
use crate::{Dispatch, DispatchDevice};
|
||||
|
||||
impl QTensorOps<Self> for Dispatch {
|
||||
fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {
|
||||
creation_op!(Quantized, device, |device| B::q_from_data(data, device))
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
tensor: FloatTensor<Self>,
|
||||
scheme: &burn_std::QuantScheme,
|
||||
qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
binary_op!(
|
||||
(tensor, float),
|
||||
(qparams.scales, float),
|
||||
|tensor, scales| {
|
||||
B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })
|
||||
} => Quantized
|
||||
)
|
||||
}
|
||||
|
||||
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::dequantize(tensor) => Float)
|
||||
}
|
||||
|
||||
fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {
|
||||
tensor.device()
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
device: &DispatchDevice,
|
||||
) -> QuantizedTensor<Self> {
|
||||
to_device!(
|
||||
Quantized,
|
||||
quantized,
|
||||
tensor,
|
||||
device,
|
||||
q_to_device,
|
||||
|inner, device| {
|
||||
let data =
|
||||
burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data");
|
||||
B2::q_from_data(data, device)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)
|
||||
}
|
||||
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)
|
||||
}
|
||||
|
||||
fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)
|
||||
}
|
||||
|
||||
fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)
|
||||
}
|
||||
|
||||
fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
binary_op!(
|
||||
(tensor, quantized),
|
||||
(indices, int),
|
||||
|tensor, indices| B::q_select(tensor, dim, indices) => Quantized
|
||||
)
|
||||
}
|
||||
|
||||
fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
|
||||
unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)
|
||||
}
|
||||
|
||||
fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
|
||||
// TODO: this would be much cleaner if we consolidated tensor primitive types
|
||||
match (lhs, rhs) {
|
||||
(TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
|
||||
if matches!(lhs.propagation(), QuantPropagation::Propagate) {
|
||||
let out = binary_op!(
|
||||
(lhs, quantized),
|
||||
(rhs, quantized),
|
||||
|lhs, rhs| {
|
||||
if let TensorPrimitive::QFloat(out) = B::q_matmul(
|
||||
TensorPrimitive::QFloat(lhs),
|
||||
TensorPrimitive::QFloat(rhs),
|
||||
) {
|
||||
out
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} => Quantized
|
||||
);
|
||||
TensorPrimitive::QFloat(out)
|
||||
} else {
|
||||
let out = binary_op!(
|
||||
(lhs, quantized),
|
||||
(rhs, quantized),
|
||||
|lhs, rhs| {
|
||||
if let TensorPrimitive::Float(out) = B::q_matmul(
|
||||
TensorPrimitive::QFloat(lhs),
|
||||
TensorPrimitive::QFloat(rhs),
|
||||
) {
|
||||
out
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} => Float
|
||||
);
|
||||
TensorPrimitive::Float(out)
|
||||
}
|
||||
}
|
||||
(TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
|
||||
if matches!(rhs.propagation(), QuantPropagation::Propagate) {
|
||||
let out = binary_op!(
|
||||
(lhs, float),
|
||||
(rhs, quantized),
|
||||
|lhs, rhs| {
|
||||
if let TensorPrimitive::QFloat(out) = B::q_matmul(
|
||||
TensorPrimitive::Float(lhs),
|
||||
TensorPrimitive::QFloat(rhs),
|
||||
) {
|
||||
out
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} => Quantized
|
||||
);
|
||||
TensorPrimitive::QFloat(out)
|
||||
} else {
|
||||
let out = binary_op!(
|
||||
(lhs, float),
|
||||
(rhs, quantized),
|
||||
|lhs, rhs| {
|
||||
if let TensorPrimitive::Float(out) = B::q_matmul(
|
||||
TensorPrimitive::Float(lhs),
|
||||
TensorPrimitive::QFloat(rhs),
|
||||
) {
|
||||
out
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} => Float
|
||||
);
|
||||
TensorPrimitive::Float(out)
|
||||
}
|
||||
}
|
||||
(TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
|
||||
if matches!(lhs.propagation(), QuantPropagation::Propagate) {
|
||||
let out = binary_op!(
|
||||
(lhs, quantized),
|
||||
(rhs, float),
|
||||
|lhs, rhs| {
|
||||
if let TensorPrimitive::QFloat(out) = B::q_matmul(
|
||||
TensorPrimitive::QFloat(lhs),
|
||||
TensorPrimitive::Float(rhs),
|
||||
) {
|
||||
out
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} => Quantized
|
||||
);
|
||||
TensorPrimitive::QFloat(out)
|
||||
} else {
|
||||
let out = binary_op!(
|
||||
(lhs, quantized),
|
||||
(rhs, float),
|
||||
|lhs, rhs| {
|
||||
if let TensorPrimitive::Float(out) = B::q_matmul(
|
||||
TensorPrimitive::QFloat(lhs),
|
||||
TensorPrimitive::Float(rhs),
|
||||
) {
|
||||
out
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
} => Float
|
||||
);
|
||||
TensorPrimitive::Float(out)
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,594 @@
|
||||
use burn_backend::{
|
||||
ExecutionError, Scalar, TensorData,
|
||||
ops::FloatTensorOps,
|
||||
tensor::{BoolTensor, FloatTensor, IntTensor},
|
||||
};
|
||||
use burn_std::{FloatDType, Shape, Slice};
|
||||
|
||||
use crate::backends::*;
|
||||
use crate::{Dispatch, DispatchDevice};
|
||||
|
||||
// TODO: remove backend default elem type genericsnow that we have per-device defaults
|
||||
// https://github.com/tracel-ai/burn/issues/3642
|
||||
|
||||
impl FloatTensorOps<Self> for Dispatch {
|
||||
fn float_from_data(
|
||||
data: burn_backend::TensorData,
|
||||
device: &DispatchDevice,
|
||||
) -> FloatTensor<Self> {
|
||||
creation_op!(Float, device, |device| B::float_from_data(data, device))
|
||||
}
|
||||
|
||||
fn float_random(
|
||||
shape: Shape,
|
||||
distribution: burn_backend::Distribution,
|
||||
device: &DispatchDevice,
|
||||
) -> FloatTensor<Self> {
|
||||
creation_op!(Float, device, |device| {
|
||||
B::float_random(shape, distribution, device)
|
||||
})
|
||||
}
|
||||
|
||||
async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
|
||||
}
|
||||
|
||||
fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
|
||||
tensor.device()
|
||||
}
|
||||
|
||||
fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
|
||||
float_to_device!(
|
||||
Float,
|
||||
float,
|
||||
tensor,
|
||||
device,
|
||||
float_to_device,
|
||||
|inner, device| {
|
||||
let data =
|
||||
burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
|
||||
B2::float_from_data(data, device)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_into_int(tensor) => Int)
|
||||
}
|
||||
|
||||
fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
|
||||
}
|
||||
|
||||
fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_cross(
|
||||
lhs: FloatTensor<Self>,
|
||||
rhs: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
|
||||
}
|
||||
|
||||
fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
|
||||
}
|
||||
|
||||
fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
|
||||
}
|
||||
|
||||
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
|
||||
}
|
||||
|
||||
fn float_gather(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
|
||||
}
|
||||
|
||||
fn float_scatter_add(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, float), (indices, int), (value, float)], => Float,
|
||||
B::float_scatter_add(dim, tensor, indices, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn float_select(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
|
||||
}
|
||||
|
||||
fn float_select_add(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, float), (indices, int), (value, float)], => Float,
|
||||
B::float_select_add(tensor, dim, indices, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
|
||||
}
|
||||
|
||||
fn float_slice_assign(
|
||||
tensor: FloatTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
|
||||
}
|
||||
|
||||
fn float_mask_where(
|
||||
tensor: FloatTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
multi_op!(
|
||||
inputs[(tensor, float), (mask, bool), (value, float)], => Float,
|
||||
B::float_mask_where(tensor, mask, value)
|
||||
)
|
||||
}
|
||||
|
||||
fn float_mask_fill(
|
||||
tensor: FloatTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> FloatTensor<Self> {
|
||||
binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
|
||||
}
|
||||
|
||||
fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
|
||||
}
|
||||
|
||||
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
|
||||
}
|
||||
|
||||
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim) => Int)
|
||||
}
|
||||
|
||||
fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
|
||||
}
|
||||
|
||||
fn float_unfold(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| {
|
||||
B::float_unfold(tensor, dim, size, step)
|
||||
} => Float)
|
||||
}
|
||||
|
||||
fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
|
||||
}
|
||||
|
||||
fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
|
||||
unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
|
||||
}
|
||||
|
||||
// Default implementation
|
||||
fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
|
||||
}
|
||||
|
||||
fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
|
||||
}
|
||||
|
||||
fn float_full(
|
||||
shape: Shape,
|
||||
fill_value: Scalar,
|
||||
device: &DispatchDevice,
|
||||
dtype: FloatDType,
|
||||
) -> FloatTensor<Self> {
|
||||
creation_op!(Float, device, |device| B::float_full(
|
||||
shape, fill_value, device, dtype
|
||||
))
|
||||
}
|
||||
|
||||
fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
|
||||
}
|
||||
|
||||
fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
|
||||
}
|
||||
|
||||
fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
|
||||
}
|
||||
|
||||
fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
|
||||
}
|
||||
|
||||
fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_not_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_not_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs) => Bool)
|
||||
}
|
||||
|
||||
fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
|
||||
binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
|
||||
}
|
||||
|
||||
fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
|
||||
}
|
||||
|
||||
fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
|
||||
vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_max_dim_with_indices(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
) -> (FloatTensor<Self>, IntTensor<Self>) {
|
||||
multi_op!(
|
||||
inputs[(tensor, float)],
|
||||
outputs[(out, Float), (indices, Int)],
|
||||
B::float_max_dim_with_indices(tensor, dim)
|
||||
)
|
||||
}
|
||||
|
||||
fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_min_dim_with_indices(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
) -> (FloatTensor<Self>, IntTensor<Self>) {
|
||||
multi_op!(
|
||||
inputs[(tensor, float)],
|
||||
outputs[(out, Float), (indices, Int)],
|
||||
B::float_min_dim_with_indices(tensor, dim)
|
||||
)
|
||||
}
|
||||
|
||||
fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
|
||||
}
|
||||
|
||||
fn float_any(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_any(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn float_any_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim) => Bool)
|
||||
}
|
||||
|
||||
fn float_all(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_all(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn float_all_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim) => Bool)
|
||||
}
|
||||
|
||||
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
|
||||
}
|
||||
|
||||
fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
|
||||
}
|
||||
|
||||
fn float_sort_with_indices(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (FloatTensor<Self>, IntTensor<Self>) {
|
||||
multi_op!(
|
||||
inputs[(tensor, float)],
|
||||
outputs[(out, Float), (indices, Int)],
|
||||
B::float_sort_with_indices(tensor, dim, descending)
|
||||
)
|
||||
}
|
||||
|
||||
fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending) => Int)
|
||||
}
|
||||
|
||||
fn float_grid_sample_2d(
|
||||
tensor: FloatTensor<Self>,
|
||||
grid: FloatTensor<Self>,
|
||||
options: burn_backend::ops::GridSampleOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
|
||||
}
|
||||
|
||||
fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_is_nan(tensor) => Bool)
|
||||
}
|
||||
|
||||
fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
unary_float!(tensor, float, |tensor| B::float_is_inf(tensor) => Bool)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use burn_backend::{
|
||||
ExecutionError,
|
||||
ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},
|
||||
};
|
||||
|
||||
use crate::Dispatch;
|
||||
use crate::backends::*;
|
||||
|
||||
impl TransactionOps<Self> for Dispatch {
|
||||
async fn tr_execute(
|
||||
transaction: TransactionPrimitive<Self>,
|
||||
) -> Result<TransactionPrimitiveData, ExecutionError> {
|
||||
let first_tensor = transaction
|
||||
.read_floats
|
||||
.first()
|
||||
.or(transaction.read_ints.first())
|
||||
.or(transaction.read_bools.first());
|
||||
|
||||
match first_tensor {
|
||||
Some(tensor) => {
|
||||
transaction_op!(transaction, tensor)
|
||||
}
|
||||
None => Ok(TransactionPrimitiveData::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
use burn_backend::{Backend, QTensorPrimitive, TensorMetadata};
|
||||
|
||||
use crate::backends::*;
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
use burn_backend::tensor::FloatTensor;
|
||||
|
||||
// TODO: if we reduce the different associated types for float/int/bool/quantized tensor primitives down to a single
|
||||
// `B::TensorPrimitive` we can simplify this.
|
||||
|
||||
/// Tensor which points to a backend tensor primitive kind.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum BackendTensor<B: Backend> {
|
||||
/// Float tensor handle.
|
||||
Float(B::FloatTensorPrimitive),
|
||||
/// Int tensor handle.
|
||||
Int(B::IntTensorPrimitive),
|
||||
/// Bool tensor handle.
|
||||
Bool(B::BoolTensorPrimitive),
|
||||
/// Quantized tensor handle.
|
||||
Quantized(B::QuantizedTensorPrimitive),
|
||||
#[cfg(feature = "autodiff")]
|
||||
/// Autodiff float tensor handle.
|
||||
Autodiff(FloatTensor<Autodiff<B>>),
|
||||
}
|
||||
|
||||
impl<B: Backend> BackendTensor<B> {
|
||||
/// Returns the inner float tensor primitive.
|
||||
pub(crate) fn float(self) -> B::FloatTensorPrimitive {
|
||||
match self {
|
||||
BackendTensor::Float(tensor) => tensor,
|
||||
BackendTensor::Int(_) => panic!("Should be float, got int"),
|
||||
BackendTensor::Bool(_) => panic!("Should be float, got bool"),
|
||||
BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
|
||||
}
|
||||
}
|
||||
/// Returns the inner float tensor primitive.
|
||||
pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
|
||||
match self {
|
||||
BackendTensor::Float(tensor) => tensor,
|
||||
BackendTensor::Int(_) => panic!("Should be float, got int"),
|
||||
BackendTensor::Bool(_) => panic!("Should be float, got bool"),
|
||||
BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner int tensor primitive.
|
||||
pub(crate) fn int(self) -> B::IntTensorPrimitive {
|
||||
match self {
|
||||
BackendTensor::Int(tensor) => tensor,
|
||||
BackendTensor::Float(_) => panic!("Should be int, got float"),
|
||||
BackendTensor::Bool(_) => panic!("Should be int, got bool"),
|
||||
BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner bool tensor primitive.
|
||||
pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
|
||||
match self {
|
||||
BackendTensor::Bool(tensor) => tensor,
|
||||
BackendTensor::Float(_) => panic!("Should be bool, got float"),
|
||||
BackendTensor::Int(_) => panic!("Should be bool, got int"),
|
||||
BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner quantized tensor primitive.
|
||||
pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
|
||||
match self {
|
||||
BackendTensor::Quantized(tensor) => tensor,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
/// Returns the inner autodiff tensor primitive.
|
||||
pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
|
||||
match self {
|
||||
BackendTensor::Autodiff(tensor) => tensor,
|
||||
// NOTE: this is the panicking code reached in tensor.rs:74:18:
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
/// Returns the inner autodiff tensor primitive.
|
||||
pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
|
||||
match self {
|
||||
BackendTensor::Autodiff(tensor) => tensor,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
/// Returns the inner autodiff tensor primitive.
|
||||
pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
|
||||
match self {
|
||||
BackendTensor::Autodiff(tensor) => tensor.primitive,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the backend device.
|
||||
pub(crate) fn device(&self) -> B::Device {
|
||||
match self {
|
||||
BackendTensor::Float(tensor) => B::float_device(tensor),
|
||||
BackendTensor::Int(tensor) => B::int_device(tensor),
|
||||
BackendTensor::Bool(tensor) => B::bool_device(tensor),
|
||||
BackendTensor::Quantized(tensor) => B::q_device(tensor),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorMetadata for BackendTensor<B> {
|
||||
fn dtype(&self) -> burn_std::DType {
|
||||
match self {
|
||||
BackendTensor::Float(tensor) => tensor.dtype(),
|
||||
BackendTensor::Int(tensor) => tensor.dtype(),
|
||||
BackendTensor::Bool(tensor) => tensor.dtype(),
|
||||
BackendTensor::Quantized(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(tensor) => tensor.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> burn_std::Shape {
|
||||
match self {
|
||||
BackendTensor::Float(tensor) => tensor.shape(),
|
||||
BackendTensor::Int(tensor) => tensor.shape(),
|
||||
BackendTensor::Bool(tensor) => tensor.shape(),
|
||||
BackendTensor::Quantized(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
BackendTensor::Autodiff(tensor) => tensor.shape(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
|
||||
fn scheme(&self) -> &burn_std::QuantScheme {
|
||||
match self {
|
||||
BackendTensor::Quantized(tensor) => tensor.scheme(),
|
||||
_ => panic!(
|
||||
"Quantization scheme is not valid for dtype {:?}",
|
||||
self.dtype(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatch tensor that can hold tensors from any enabled backend.
|
||||
///
|
||||
/// This enum wraps backend-specific tensor types, allowing runtime selection
|
||||
/// of the backend to execute operations on.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum DispatchTensor {
|
||||
/// The [CPU backend](Cpu) tensor.
|
||||
#[cfg(feature = "cpu")]
|
||||
Cpu(BackendTensor<Cpu>),
|
||||
|
||||
/// The [CUDA backend](Cuda) tensor.
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(BackendTensor<Cuda>),
|
||||
|
||||
/// The [Metal backend](Metal) tensor.
|
||||
#[cfg(wgpu_metal)]
|
||||
Metal(BackendTensor<Metal>),
|
||||
|
||||
/// The [ROCm backend](Rocm) tensor.
|
||||
#[cfg(feature = "rocm")]
|
||||
Rocm(BackendTensor<Rocm>),
|
||||
|
||||
/// The [Vulkan backend](Vulkan) tensor.
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Vulkan(BackendTensor<Vulkan>),
|
||||
|
||||
/// The [WebGPU backend](WebGpu) tensor.
|
||||
#[cfg(wgpu_webgpu)]
|
||||
WebGpu(BackendTensor<WebGpu>),
|
||||
|
||||
/// The [NdArray backend](NdArray) tensor.
|
||||
#[cfg(feature = "ndarray")]
|
||||
NdArray(BackendTensor<NdArray>),
|
||||
|
||||
/// The [LibTorch backend](LibTorch) tensor.
|
||||
#[cfg(feature = "tch")]
|
||||
LibTorch(BackendTensor<LibTorch>),
|
||||
|
||||
/// The [autodiff enabled backend](Autodiff) tensor.
|
||||
#[cfg(feature = "autodiff")]
|
||||
Autodiff(Box<DispatchTensor>),
|
||||
}
|
||||
|
||||
impl TensorMetadata for DispatchTensor {
|
||||
fn dtype(&self) -> burn_std::DType {
|
||||
match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => tensor.dtype(),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => tensor.dtype(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => tensor.dtype(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "tch")]
|
||||
DispatchTensor::LibTorch(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => tensor.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> burn_std::Shape {
|
||||
match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => tensor.shape(),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => tensor.shape(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => tensor.shape(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "tch")]
|
||||
DispatchTensor::LibTorch(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => tensor.shape(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QTensorPrimitive for DispatchTensor {
|
||||
fn scheme(&self) -> &burn_std::QuantScheme {
|
||||
match self {
|
||||
#[cfg(feature = "cpu")]
|
||||
DispatchTensor::Cpu(tensor) => tensor.scheme(),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensor::Cuda(tensor) => tensor.scheme(),
|
||||
#[cfg(wgpu_metal)]
|
||||
DispatchTensor::Metal(tensor) => tensor.scheme(),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensor::Rocm(tensor) => tensor.scheme(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
DispatchTensor::Vulkan(tensor) => tensor.scheme(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
DispatchTensor::WebGpu(tensor) => tensor.scheme(),
|
||||
#[cfg(feature = "ndarray")]
|
||||
DispatchTensor::NdArray(tensor) => tensor.scheme(),
|
||||
#[cfg(feature = "tch")]
|
||||
DispatchTensor::LibTorch(tensor) => tensor.scheme(),
|
||||
#[cfg(feature = "autodiff")]
|
||||
DispatchTensor::Autodiff(tensor) => tensor.scheme(),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user