feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,48 @@
[package]
authors = [
"laggui <lagrange.guillaume.1@gmail.com>",
"nathanielsimard <nathaniel.simard.42@gmail.com>",
]
categories = ["science"]
description = "Multi-backend router decorator for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-router"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router"
documentation = "https://docs.rs/burn-router"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std"]
std = ["burn-backend/std", "burn-std/std", "burn-ir/std"]
doc = ["default"]
tracing = [
"burn-backend/tracing",
"burn-ir/tracing",
"burn-std/tracing",
]
[dependencies]
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false }
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
hashbrown = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", default-features = false, features = [
"std",
] }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1,3 @@
# Burn Router
A multi-backend extension that forwards the tensor operations to the appropriate backend.

View File

@@ -0,0 +1,75 @@
use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};
use alloc::{format, string::String};
use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme};
use core::marker::PhantomData;
/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).
///
/// The struct holds no data of its own: its single field is a [`PhantomData`] marker for the
/// [`RunnerChannel`] type parameter `R`.
pub struct BackendRouter<R: RunnerChannel> {
    // Marker only; no value of `R` is ever stored.
    r: PhantomData<R>,
}
impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
    /// Writes the fixed label `router`; the channel type is not part of the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("router")
    }
}
impl<R: RunnerChannel> Clone for BackendRouter<R> {
fn clone(&self) -> Self {
Self { r: PhantomData }
}
}
impl<R: RunnerChannel> Default for BackendRouter<R> {
fn default() -> Self {
Self { r: PhantomData }
}
}
impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
    /// Returns the quantization scheme stored in the tensor's dtype.
    ///
    /// # Panics
    ///
    /// Panics when the dtype is not [`DType::QFloat`].
    fn scheme(&self) -> &QuantScheme {
        match &self.dtype {
            DType::QFloat(scheme) => scheme,
            // TODO: maybe `tensor.scheme()` should return an option
            other => panic!("Expected quantized float dtype, got {:?}", other),
        }
    }
}
impl<R: RunnerChannel> Backend for BackendRouter<R> {
    type Device = R::Device;

    // Element types come straight from the channel.
    type FloatElem = R::FloatElem;
    type IntElem = R::IntElem;
    type BoolElem = R::BoolElem;

    // Every tensor kind is represented by the same router tensor type.
    type FloatTensorPrimitive = RouterTensor<R::Client>;
    type IntTensorPrimitive = RouterTensor<R::Client>;
    type BoolTensorPrimitive = RouterTensor<R::Client>;
    type QuantizedTensorPrimitive = RouterTensor<R::Client>;

    /// Returns the channel name for `device`, wrapped as `router<...>`.
    fn name(device: &Self::Device) -> String {
        let channel_name = R::name(device);
        format!("router<{}>", channel_name)
    }

    /// Forwards the seed to the client associated with `device`.
    fn seed(device: &Self::Device, seed: u64) {
        get_client::<R>(device).seed(seed);
    }

    /// Forwards the sync request to the client associated with `device`.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        get_client::<R>(device).sync()
    }

    /// Asks the client associated with `device` for the usage set of `dtype`.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        get_client::<R>(device).dtype_usage(dtype)
    }
}

View File

@@ -0,0 +1,32 @@
use burn_backend::{Shape, backend::DeviceOps};
/// Allows tensors to be transferred between multiple backends.
///
/// All methods are associated functions (no `self`): they take a handle, the shape passed
/// alongside it, and the device to move to, and return a handle of the same
/// [`TensorHandle`](MultiBackendBridge::TensorHandle) type.
pub trait MultiBackendBridge: Send + Sync + 'static {
    /// The type that can be used to point to a tensor of any kind.
    type TensorHandle;
    /// Device type used by the backends.
    type Device: DeviceOps;
    /// Change the backend of the given float tensor.
    ///
    /// Returns the handle to use with `target_device`.
    fn change_backend_float(
        tensor: Self::TensorHandle,
        shape: Shape,
        target_device: &Self::Device,
    ) -> Self::TensorHandle;
    /// Change the backend of the given int tensor.
    ///
    /// Returns the handle to use with `target_device`.
    fn change_backend_int(
        tensor: Self::TensorHandle,
        shape: Shape,
        target_device: &Self::Device,
    ) -> Self::TensorHandle;
    /// Change the backend of the given bool tensor.
    ///
    /// Returns the handle to use with `target_device`.
    fn change_backend_bool(
        tensor: Self::TensorHandle,
        shape: Shape,
        target_device: &Self::Device,
    ) -> Self::TensorHandle;
    // TODO: change_backend_quantized
}

View File

@@ -0,0 +1,6 @@
use core::marker::PhantomData;
/// Simply transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData).
///
/// The struct itself only carries the `Backends` type parameter as a marker.
pub struct ByteBridge<Backends> {
    // Marker only; no backend value is stored.
    backends: PhantomData<Backends>,
}

View File

@@ -0,0 +1,5 @@
mod base;
mod byte;
pub use base::*;
pub use byte::*;

View File

@@ -0,0 +1,66 @@
use alloc::string::String;
use burn_backend::{DType, Element, Shape, backend::DeviceOps};
use burn_ir::TensorIr;
use crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
/// Type alias for `<Br as MultiBackendBridge>::TensorHandle`.
///
/// `Br` is the bridge type whose handle type is being named.
pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
/// Defines the connection channel and operations for a setup with multiple backend runner clients.
pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
    /// Device type.
    type Device: DeviceOps;
    /// A bridge that can transfer tensors between multiple backends.
    type Bridge: MultiBackendBridge<Device = Self::Device>;
    /// Client type.
    type Client: RunnerClient<Device = Self::Device>;
    /// Float element type.
    type FloatElem: Element;
    /// Int element type.
    type IntElem: Element;
    /// Bool element type.
    type BoolElem: Element;

    /// Name of the channel.
    fn name(device: &Self::Device) -> String;

    /// Initialize a new client for the given device.
    fn init_client(device: &Self::Device) -> Self::Client;

    /// Get the tensor handle corresponding to the [tensor representation](TensorIr).
    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge>;

    /// Create a tensor with the given handle and shape.
    fn register_tensor(
        client: &Self::Client,
        handle: TensorHandle<Self::Bridge>,
        shape: Shape,
        dtype: DType,
    ) -> RouterTensor<Self::Client>;

    /// Change the tensor to a different client backend.
    ///
    /// The handle is fetched from the tensor's current client, passed through the bridge
    /// function matching the tensor dtype (float, int or bool), and registered on the client
    /// of the target `device`.
    ///
    /// # Panics
    ///
    /// Panics with `unimplemented!()` when the dtype is neither float, int nor bool.
    fn change_client_backend(
        tensor: RouterTensor<Self::Client>,
        device: &Self::Device, // target device
    ) -> RouterTensor<Self::Client> {
        // Resolve the handle on the client that currently owns the tensor.
        let source_client = tensor.client.clone();
        let ir = tensor.into_ir();
        let source_handle = Self::get_tensor_handle(&ir, &source_client);

        // Pick the bridge conversion from the tensor kind.
        let target_handle = if ir.dtype.is_float() {
            Self::Bridge::change_backend_float(source_handle, ir.shape.clone(), device)
        } else if ir.dtype.is_int() {
            Self::Bridge::change_backend_int(source_handle, ir.shape.clone(), device)
        } else if ir.dtype.is_bool() {
            Self::Bridge::change_backend_bool(source_handle, ir.shape.clone(), device)
        } else {
            unimplemented!()
        };

        // Make the converted handle known to the client of the target device.
        let target_client = get_client::<Self>(device);
        Self::register_tensor(&target_client, target_handle, ir.shape, ir.dtype)
    }
}

View File

@@ -0,0 +1,16 @@
use core::marker::PhantomData;
/// A local channel with direct connection to the backend runner clients.
///
/// Both fields are [`PhantomData`] markers: the struct only carries its two type parameters.
pub struct DirectChannel<Backends, Bridge> {
    // Marker for the set of backends.
    backends: PhantomData<Backends>,
    // Marker for the bridge type.
    bridge: PhantomData<Bridge>,
}
impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
fn clone(&self) -> Self {
Self {
backends: self.backends,
bridge: self.bridge,
}
}
}

View File

@@ -0,0 +1,5 @@
mod base;
mod direct;
pub use base::*;
pub use direct::*;

View File

@@ -0,0 +1,128 @@
use crate::{RouterTensor, RunnerChannel};
use alloc::boxed::Box;
use alloc::vec::Vec;
use burn_backend::{
DType, TensorData,
backend::{DeviceId, DeviceOps, ExecutionError},
};
use burn_ir::{OperationIr, TensorId, TensorIr};
use burn_std::future::DynFut;
use core::ops::DerefMut;
use hashbrown::HashMap;
use spin::Mutex;
/// Type alias for `<R as RunnerChannel>::Client`.
pub type Client<R> = <R as RunnerChannel>::Client;
// Crate-wide registry used by `get_client` to look up (or create) runner clients.
pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();
// Registry key: the `TypeId` of the channel type paired with the device id.
type Key = (core::any::TypeId, DeviceId);
/// Define how to interact with the runner.
pub trait RunnerClient: Clone + Send + Sync + Sized {
    /// Device type.
    type Device: DeviceOps;

    /// Register a new tensor operation to be executed by the (runner) server.
    fn register_op(&self, op: OperationIr);

    /// Register a new tensor operation to be executed by the (runner) server.
    ///
    /// Returns the new (uninitialized) output tensor(s) generated by the registered operation.
    fn register(&self, op: OperationIr) -> Vec<RouterTensor<Self>> {
        // Build the output tensors first: `op` is moved into `register_op` afterwards.
        let mut tensors = Vec::new();
        for ir in op.outputs() {
            tensors.push(RouterTensor::new(
                ir.id,
                ir.shape.clone(),
                ir.dtype,
                self.clone(),
            ));
        }
        self.register_op(op);
        tensors
    }

    /// Read the values contained by a tensor.
    fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>>;

    /// Sync the runner, ensure that all computations are finished.
    fn sync(&self) -> Result<(), ExecutionError>;

    /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId).
    fn create_empty_handle(&self) -> TensorId;

    /// Create a new [RouterTensor] from the tensor data.
    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;

    /// Get the current device used by all operations handled by this client.
    fn device(&self) -> Self::Device;

    /// Seed the runner.
    fn seed(&self, seed: u64);

    /// Returns the supported data type usage set
    fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet;
}
// Registry of runner clients, keyed by (channel type, device id).
//
// The map starts out as `None` (see `new`) and is created on first registration.
// Clients are stored type-erased and downcast back to the concrete client type on lookup.
pub(crate) struct RunnerClientLocator {
    clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,
}
/// Get the client for the given device
pub fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
CLIENTS.client::<R>(device)
}
/// Initialize a new client for the given device.
///
/// This simply delegates to [`RunnerChannel::init_client`]; no other setup is done here.
fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
    <R as RunnerChannel>::init_client(device)
}
impl RunnerClientLocator {
/// Create a new client locator.
pub const fn new() -> Self {
Self {
clients: Mutex::new(None),
}
}
/// Get the runner client for the given device.
///
/// If a client isn't already initialized, it is created.
pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
let device_id = device.id();
let client_id = (core::any::TypeId::of::<R>(), device_id);
let mut clients = self.clients.lock();
if clients.is_none() {
let client = new_client::<R>(device);
Self::register_inner::<R>(client_id, client, &mut clients);
}
match clients.deref_mut() {
Some(clients) => match clients.get(&client_id) {
Some(client) => {
let client: &Client<R> = client.downcast_ref().unwrap();
client.clone()
}
None => {
let client = new_client::<R>(device);
let any = Box::new(client.clone());
clients.insert(client_id, any);
client
}
},
_ => unreachable!(),
}
}
fn register_inner<R: RunnerChannel + 'static>(
key: Key,
client: Client<R>,
clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,
) {
if clients.is_none() {
*clients = Some(HashMap::new());
}
if let Some(clients) = clients {
if clients.contains_key(&key) {
panic!("Client already created for device {key:?}");
}
clients.insert(key, Box::new(client));
}
}
}

View File

@@ -0,0 +1,3 @@
mod base;
pub use base::*;

View File

@@ -0,0 +1,49 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "138"]
//! Burn multi-backend router.
mod backend;
mod bridge;
mod channel;
mod client;
mod ops;
mod runner;
mod tensor;
mod types;
pub use backend::*;
pub use bridge::*;
pub use channel::*;
pub use client::*;
pub use runner::*;
pub use tensor::*;
pub use types::*;
/// A local channel with a simple byte bridge between backends.
/// It transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData).
///
/// Shorthand for a [`DirectChannel`] whose bridge is a [`ByteBridge`] over the same `Backends`.
pub type DirectByteChannel<Backends> = DirectChannel<Backends, ByteBridge<Backends>>;
/// Router backend.
///
/// Shorthand for a [`BackendRouter`] running over a [`DirectByteChannel`] of the given backends.
///
/// # Example
///
/// ```ignore
/// type MyBackend = Router<(NdArray, Wgpu)>;
/// ```
pub type Router<Backends> = BackendRouter<DirectByteChannel<Backends>>;
extern crate alloc;
// Backend type aliases for tests. This module defines aliases only, no test functions.
#[cfg(test)]
#[allow(unused)]
mod tests {
    use crate::BackendRouter;
    use crate::DirectByteChannel;
    // First routed backend: ndarray with `f32` / `i32` elements.
    pub type TestBackend1 = burn_ndarray::NdArray<f32, i32>;
    // Second routed backend: wgpu with `f32` / `i32` elements.
    pub type TestBackend2 = burn_wgpu::Wgpu<f32, i32>;
    // Router over both backends, connected through the direct byte channel.
    pub type TestBackend = BackendRouter<DirectByteChannel<(TestBackend1, TestBackend2)>>;
}

View File

@@ -0,0 +1,4 @@
use crate::{BackendRouter, RunnerChannel};
use burn_backend::ops::ActivationOps;
// No methods are overridden: the router uses the trait's provided implementations as-is.
impl<R: RunnerChannel> ActivationOps<Self> for BackendRouter<R> {}

View File

@@ -0,0 +1,69 @@
/// Reads the float tensors referenced by `$desc.lhs` and `$desc.rhs` from `$handles`,
/// applies `$ops` to them, and registers the result as a float tensor under `$desc.out.id`.
///
/// A backend type named `B` must be in scope at the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_float_tensor::<B>(&$desc.lhs);
        let right = $handles.get_float_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);
        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Reads the float tensors referenced by `$desc.lhs` and `$desc.rhs` from `$handles`,
/// applies the comparison `$ops`, and registers the result as a bool tensor under `$desc.out.id`.
///
/// A backend type named `B` must be in scope at the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_float_tensor::<B>(&$desc.lhs);
        let right = $handles.get_float_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);
        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Reads the int tensors referenced by `$desc.lhs` and `$desc.rhs` from `$handles`,
/// applies `$ops` to them, and registers the result as an int tensor under `$desc.out.id`.
///
/// A backend type named `B` must be in scope at the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_int_tensor::<B>(&$desc.lhs);
        let right = $handles.get_int_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);
        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Reads the int tensors referenced by `$desc.lhs` and `$desc.rhs` from `$handles`,
/// applies the comparison `$ops`, and registers the result as a bool tensor under `$desc.out.id`.
///
/// A backend type named `B` must be in scope at the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_int_tensor::<B>(&$desc.lhs);
        let right = $handles.get_int_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);
        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Reads the bool tensors referenced by `$desc.lhs` and `$desc.rhs` from `$handles`,
/// applies `$ops` to them, and registers the result as a bool tensor under `$desc.out.id`.
///
/// A backend type named `B` must be in scope at the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_bool_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_bool_tensor::<B>(&$desc.lhs);
        let right = $handles.get_bool_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);
        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}

View File

@@ -0,0 +1,333 @@
use alloc::vec::Vec;
use burn_backend::backend::ExecutionError;
use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
use burn_backend::ops::BoolTensorOps;
use burn_backend::tensor::{
BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
};
use burn_backend::{Element, Scalar, Shape, Slice, TensorData};
use burn_ir::{
BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr,
SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
};
// Boolean tensor operations for the router backend.
//
// Most methods follow one pattern: build the operation's IR description (allocating the
// output id through `create_empty_handle`), register the operation on a client, and return
// the registered output. For operations with tensor inputs, the client of the first tensor
// argument is used.
impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
    /// Registers an `Empty` creation op on the client of `device`.
    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
        let runner = get_client::<R>(device);
        let ir = CreationOpIr::create(shape, R::BoolElem::dtype(), || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Empty(ir));
        runner.register(op).output()
    }

    /// Registers a `Zeros` creation op on the client of `device`.
    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
        let runner = get_client::<R>(device);
        let ir = CreationOpIr::create(shape, R::BoolElem::dtype(), || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Zeros(ir));
        runner.register(op).output()
    }

    /// Registers a `Ones` creation op on the client of `device`.
    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
        let runner = get_client::<R>(device);
        let ir = CreationOpIr::create(shape, R::BoolElem::dtype(), || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Ones(ir));
        runner.register(op).output()
    }

    /// Reads the tensor values by awaiting the tensor's own `into_data`.
    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
        tensor.into_data().await
    }

    /// Registers `data` on the client of `device` and records an `Init` op for it.
    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
        let runner = get_client::<R>(device);
        let tensor = runner.register_tensor_data(data);
        let init = InitOperationIr {
            out: tensor.to_ir_out(),
        };
        // Call register op when output is already initialized
        runner.register_op(OperationIr::Init(init));
        tensor
    }

    /// Registers a cast of the bool tensor to the int element type.
    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = CastOpIr::create(input, IntElem::<Self>::dtype(), || {
            runner.create_empty_handle()
        });
        let op = OperationIr::Bool(BoolOperationIr::IntoInt(ir));
        runner.register(op).output()
    }

    /// Registers a cast of the bool tensor to the float element type.
    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = CastOpIr::create(input, FloatElem::<Self>::dtype(), || {
            runner.create_empty_handle()
        });
        let op = OperationIr::Bool(BoolOperationIr::IntoFloat(ir));
        runner.register(op).output()
    }

    /// Returns the device reported by the tensor's client.
    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
        tensor.client.device()
    }

    /// Returns the tensor unchanged when it is already on `device`; otherwise moves it
    /// through [`RunnerChannel::change_client_backend`].
    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
        let already_there = tensor.client.device() == *device;
        if already_there {
            tensor
        } else {
            R::change_client_backend(tensor, device)
        }
    }

    /// Registers a `Reshape` op.
    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let ir = ShapeOpIr::reshape(tensor.into_ir(), shape, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Reshape(ir));
        runner.register(op).output()
    }

    /// Registers a `Slice` op.
    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = SliceOpIr::create(input, slices.into(), || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Slice(ir));
        runner.register(op).output()
    }

    /// Registers a `SliceAssign` op.
    fn bool_slice_assign(
        tensor: BoolTensor<Self>,
        slices: &[burn_backend::Slice],
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let ir = SliceAssignOpIr::create(
            tensor.into_ir(),
            slices.into(),
            value.into_ir(),
            || runner.create_empty_handle(),
        );
        let op = OperationIr::BaseBool(BaseOperationIr::SliceAssign(ir));
        runner.register(op).output()
    }

    /// Registers an element-wise `Equal` op.
    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let runner = lhs.client.clone();
        let left = lhs.into_ir();
        let right = rhs.into_ir();
        let ir = BinaryOpIr::create(left, right, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Equal(ir));
        runner.register(op).output()
    }

    /// Registers a `Not` op.
    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let ir = UnaryOpIr::create(tensor.into_ir(), || runner.create_empty_handle());
        let op = OperationIr::Bool(BoolOperationIr::Not(ir));
        runner.register(op).output()
    }

    /// Registers an `And` op.
    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let runner = lhs.client.clone();
        let left = lhs.into_ir();
        let right = rhs.into_ir();
        let ir = BinaryOpIr::create(left, right, || runner.create_empty_handle());
        let op = OperationIr::Bool(BoolOperationIr::And(ir));
        runner.register(op).output()
    }

    /// Registers an `Or` op.
    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
        let runner = lhs.client.clone();
        let left = lhs.into_ir();
        let right = rhs.into_ir();
        let ir = BinaryOpIr::create(left, right, || runner.create_empty_handle());
        let op = OperationIr::Bool(BoolOperationIr::Or(ir));
        runner.register(op).output()
    }

    /// Registers a `SwapDims` op.
    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = SwapDimsOpIr::create(input, dim1, dim2, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::SwapDims(ir));
        runner.register(op).output()
    }

    /// Registers a `Permute` op.
    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = PermuteOpIr::create(input, axes.into(), || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Permute(ir));
        runner.register(op).output()
    }

    /// Registers a `Flip` op.
    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = FlipOpIr::create(input, axes.into(), || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Flip(ir));
        runner.register(op).output()
    }

    /// Registers an `Expand` op.
    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let ir = ShapeOpIr::expand(tensor.into_ir(), shape, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Expand(ir));
        runner.register(op).output()
    }

    /// Registers a `Cat` op on the client of the first tensor.
    ///
    /// # Panics
    ///
    /// Panics if `tensors` is empty.
    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
        let runner = tensors.first().unwrap().client.clone();
        let inputs = tensors.into_iter().map(|t| t.into_ir()).collect();
        let ir = CatOpIr::create(inputs, dim, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Cat(ir));
        runner.register(op).output()
    }

    /// Registers a `RepeatDim` op.
    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = RepeatDimOpIr::create(input, dim, times, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::RepeatDim(ir));
        runner.register(op).output()
    }

    /// Registers an `Unfold` op.
    fn bool_unfold(
        tensor: BoolTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let input = tensor.into_ir();
        let ir = UnfoldOpIr::create(input, dim, size, step, || runner.create_empty_handle());
        let op = OperationIr::BaseBool(BaseOperationIr::Unfold(ir));
        runner.register(op).output()
    }

    /// Registers a `MaskWhere` op.
    fn bool_mask_where(
        tensor: BoolTensor<Self>,
        mask: BoolTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let ir = MaskWhereOpIr::create(
            tensor.into_ir(),
            mask.into_ir(),
            value.into_ir(),
            || runner.create_empty_handle(),
        );
        let op = OperationIr::BaseBool(BaseOperationIr::MaskWhere(ir));
        runner.register(op).output()
    }

    /// Registers a `MaskFill` op.
    fn bool_mask_fill(
        tensor: BoolTensor<Self>,
        mask: BoolTensor<Self>,
        value: Scalar,
    ) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        // Convert the scalar into the representation expected by the IR first.
        let fill = value.into();
        let ir = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), fill, || {
            runner.create_empty_handle()
        });
        let op = OperationIr::BaseBool(BaseOperationIr::MaskFill(ir));
        runner.register(op).output()
    }

    /// Registers a `Gather` op.
    fn bool_gather(
        dim: usize,
        tensor: BoolTensor<Self>,
        indices: IntTensor<Self>,
    ) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        let ir = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
            runner.create_empty_handle()
        });
        let op = OperationIr::BaseBool(BaseOperationIr::Gather(ir));
        runner.register(op).output()
    }

    /// Registers a `Scatter` op.
    fn bool_scatter_or(
        dim: usize,
        tensor: BoolTensor<Self>,
        indices: IntTensor<Self>,
        value: BoolTensor<Self>,
    ) -> BoolTensor<Self> {
        let runner = tensor.client.clone();
        // NOTE(review): the update op passed is `Add` although the method is named
        // `scatter_or`; presumably the runner treats it as OR for bool tensors. Confirm.
        let ir = ScatterOpIr::create(
            tensor.into_ir(),
            dim,
            indices.into_ir(),
            value.into_ir(),
            IndexingUpdateOp::Add,
            || runner.create_empty_handle(),
        );
        let op = OperationIr::BaseBool(BaseOperationIr::Scatter(ir));
        runner.register(op).output()
    }

    /// Registers an `EqualElem` comparison against a scalar.
    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        let runner = lhs.client.clone();
        // Convert the scalar into the representation expected by the IR first.
        let scalar = rhs.into();
        let ir = ScalarOpIr::create_comparison(lhs.into_ir(), scalar, R::BoolElem::dtype(), || {
            runner.create_empty_handle()
        });
        let op = OperationIr::BaseBool(BaseOperationIr::EqualElem(ir));
        runner.register(op).output()
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,9 @@
mod activation;
mod binary;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;
mod unary;

View File

@@ -0,0 +1,796 @@
use alloc::boxed::Box;
use burn_backend::Element;
use burn_backend::ops::{
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntElem, IntTensor};
use burn_ir::*;
use crate::{BackendRouter, RunnerChannel, RunnerClient};
impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
/// Registers a `Conv1d` module op on the client of `x` and returns its output.
fn conv1d(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    bias: Option<FloatTensor<Self>>,
    options: ConvOptions<1>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv1dOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        bias.map(|b| b.into_ir()),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv1d(ir));
    runner.register(op).output()
}
/// Registers a `Conv1dXBackward` module op on the client of `x` and returns its output.
fn conv1d_x_backward(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
    options: ConvOptions<1>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv1dXBackwardOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        output_grad.into_ir(),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv1dXBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv1dWeightBackward` module op on the client of `x` and returns its output.
fn conv1d_weight_backward(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
    options: ConvOptions<1>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv1dWeightBackwardOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        output_grad.into_ir(),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv1dWeightBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv1dBiasBackward` module op on the client of `x` and returns its output.
fn conv1d_bias_backward(
    x: FloatTensor<Self>,
    bias: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv1dBiasBackwardOpIr::create(
        x.into_ir(),
        bias.into_ir(),
        output_grad.into_ir(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv2d` module op on the client of `x` and returns its output.
fn conv2d(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    bias: Option<FloatTensor<Self>>,
    options: ConvOptions<2>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv2dOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        bias.map(|b| b.into_ir()),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv2d(ir));
    runner.register(op).output()
}
/// Registers a `Conv2dXBackward` module op on the client of `x` and returns its output.
fn conv2d_x_backward(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
    options: ConvOptions<2>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv2dXBackwardOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        output_grad.into_ir(),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv2dXBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv2dWeightBackward` module op on the client of `x` and returns its output.
fn conv2d_weight_backward(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
    options: ConvOptions<2>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv2dWeightBackwardOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        output_grad.into_ir(),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv2dWeightBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv2dBiasBackward` module op on the client of `x` and returns its output.
fn conv2d_bias_backward(
    x: FloatTensor<Self>,
    bias: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv2dBiasBackwardOpIr::create(
        x.into_ir(),
        bias.into_ir(),
        output_grad.into_ir(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv3d` module op on the client of `x` and returns its output.
fn conv3d(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    bias: Option<FloatTensor<Self>>,
    options: ConvOptions<3>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv3dOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        bias.map(|b| b.into_ir()),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv3d(ir));
    runner.register(op).output()
}
/// Registers a `Conv3dXBackward` module op on the client of `x` and returns its output.
fn conv3d_x_backward(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
    options: ConvOptions<3>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv3dXBackwardOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        output_grad.into_ir(),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv3dXBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv3dWeightBackward` module op on the client of `x` and returns its output.
fn conv3d_weight_backward(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
    options: ConvOptions<3>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv3dWeightBackwardOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        output_grad.into_ir(),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv3dWeightBackward(ir));
    runner.register(op).output()
}
/// Registers a `Conv3dBiasBackward` module op on the client of `x` and returns its output.
fn conv3d_bias_backward(
    x: FloatTensor<Self>,
    bias: FloatTensor<Self>,
    output_grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = Conv3dBiasBackwardOpIr::create(
        x.into_ir(),
        bias.into_ir(),
        output_grad.into_ir(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(ir));
    runner.register(op).output()
}
/// Registers a `ConvTranspose1d` module op on the client of `x` and returns its output.
fn conv_transpose1d(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    bias: Option<FloatTensor<Self>>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = ConvTranspose1dOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        bias.map(|b| b.into_ir()),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::ConvTranspose1d(ir));
    runner.register(op).output()
}
/// Registers a `ConvTranspose2d` module op on the client of `x` and returns its output.
fn conv_transpose2d(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    bias: Option<FloatTensor<Self>>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = ConvTranspose2dOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        bias.map(|b| b.into_ir()),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::ConvTranspose2d(ir));
    runner.register(op).output()
}
/// Registers a `ConvTranspose3d` module op on the client of `x` and returns its output.
fn conv_transpose3d(
    x: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    bias: Option<FloatTensor<Self>>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let ir = ConvTranspose3dOpIr::create(
        x.into_ir(),
        weight.into_ir(),
        bias.map(|b| b.into_ir()),
        options.into(),
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::ConvTranspose3d(ir));
    runner.register(op).output()
}
/// Registers an `AvgPool1d` module op on the client of `x` and returns its output.
fn avg_pool1d(
    x: FloatTensor<Self>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
    ceil_mode: bool,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let input = x.into_ir();
    let ir = AvgPool1dOpIr::create(
        input,
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::AvgPool1d(ir));
    runner.register(op).output()
}
/// Registers an `AvgPool2d` module op on the client of `x` and returns its output.
fn avg_pool2d(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    ceil_mode: bool,
) -> FloatTensor<Self> {
    let runner = x.client.clone();
    let input = x.into_ir();
    let ir = AvgPool2dOpIr::create(
        input,
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
        || runner.create_empty_handle(),
    );
    let op = OperationIr::Module(ModuleOperationIr::AvgPool2d(ir));
    runner.register(op).output()
}
/// Registers the backward pass of 1D average pooling on the client taken from `x`.
///
/// Builds an `AvgPool1dBackwardOpIr` from the lowered `x` and `grad` plus the pooling
/// parameters, registers it, and returns the operation's single output.
fn avg_pool1d_backward(
    x: FloatTensor<Self>,
    grad: FloatTensor<Self>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
    ceil_mode: bool,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    // `x` is lowered before `grad`, matching the parameter order.
    let input = x.into_ir();
    let grad = grad.into_ir();
    let ir = AvgPool1dBackwardOpIr::create(
        input,
        grad,
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(ir));
    client.register(op).output()
}
/// Registers the backward pass of 2D average pooling on the client taken from `x`.
///
/// Builds an `AvgPool2dBackwardOpIr` from the lowered `x` and `grad` plus the pooling
/// parameters, registers it, and returns the operation's single output.
fn avg_pool2d_backward(
    x: FloatTensor<Self>,
    grad: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    ceil_mode: bool,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    // `x` is lowered before `grad`, matching the parameter order.
    let input = x.into_ir();
    let grad = grad.into_ir();
    let ir = AvgPool2dBackwardOpIr::create(
        input,
        grad,
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(ir));
    client.register(op).output()
}
/// Registers a 1D max pooling operation on the client taken from `x`.
///
/// Builds a `MaxPool1dOpIr` from the lowered input and the pooling parameters, registers it
/// as `ModuleOperationIr::MaxPool1d`, and returns the operation's single output.
fn max_pool1d(
    x: FloatTensor<Self>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = MaxPool1dOpIr::create(
        x.into_ir(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::MaxPool1d(ir));
    client.register(op).output()
}
/// Registers a 2D max pooling operation on the client taken from `x`.
///
/// Builds a `MaxPool2dOpIr` from the lowered input and the pooling parameters, registers it
/// as `ModuleOperationIr::MaxPool2d`, and returns the operation's single output.
fn max_pool2d(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = MaxPool2dOpIr::create(
        x.into_ir(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::MaxPool2d(ir));
    client.register(op).output()
}
/// Registers a 1D max pooling operation that also produces indices.
///
/// The descriptor additionally receives `IntElem::<Self>::dtype()`. The registered operation
/// yields two outputs, which are wrapped in `MaxPool1dWithIndices` in the order returned.
fn max_pool1d_with_indices(
    x: FloatTensor<Self>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> MaxPool1dWithIndices<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = MaxPool1dWithIndicesOpIr::create(
        x.into_ir(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        IntElem::<Self>::dtype(),
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(ir));
    let [pooled, pooled_indices] = client.register(op).outputs();
    MaxPool1dWithIndices::new(pooled, pooled_indices)
}
/// Registers a 2D max pooling operation that also produces indices.
///
/// The descriptor additionally receives `IntElem::<Self>::dtype()`. The registered operation
/// yields two outputs, which are wrapped in `MaxPool2dWithIndices` in the order returned.
fn max_pool2d_with_indices(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> MaxPool2dWithIndices<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = MaxPool2dWithIndicesOpIr::create(
        x.into_ir(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        IntElem::<Self>::dtype(),
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(ir));
    let [pooled, pooled_indices] = client.register(op).outputs();
    MaxPool2dWithIndices::new(pooled, pooled_indices)
}
/// Registers the backward pass of 1D max pooling with indices.
///
/// `x`, `output_grad` and `indices` are lowered to IR (in that order) and passed to
/// `MaxPool1dWithIndicesBackwardOpIr::create`; the single registered output is wrapped in
/// `MaxPool1dBackward`.
fn max_pool1d_with_indices_backward(
    x: FloatTensor<Self>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
    output_grad: FloatTensor<Self>,
    indices: IntTensor<Self>,
) -> MaxPool1dBackward<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let grad = output_grad.into_ir();
    let indices = indices.into_ir();
    let ir = MaxPool1dWithIndicesBackwardOpIr::create(
        input,
        grad,
        indices,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(ir));
    MaxPool1dBackward::new(client.register(op).output())
}
/// Registers the backward pass of 2D max pooling with indices.
///
/// `x`, `output_grad` and `indices` are lowered to IR (in that order) and passed to
/// `MaxPool2dWithIndicesBackwardOpIr::create`; the single registered output is wrapped in
/// `MaxPool2dBackward`.
fn max_pool2d_with_indices_backward(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
    output_grad: FloatTensor<Self>,
    indices: IntTensor<Self>,
) -> MaxPool2dBackward<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let grad = output_grad.into_ir();
    let indices = indices.into_ir();
    let ir = MaxPool2dWithIndicesBackwardOpIr::create(
        input,
        grad,
        indices,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(ir));
    MaxPool2dBackward::new(client.register(op).output())
}
/// Registers a 1D adaptive average pooling operation on the client taken from `x`.
///
/// Builds an `AdaptiveAvgPool1dOpIr` from the lowered input and `output_size`, registers it,
/// and returns the operation's single output.
fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, new_handle);
    let op = OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(ir));
    client.register(op).output()
}
/// Registers a 2D adaptive average pooling operation on the client taken from `x`.
///
/// Builds an `AdaptiveAvgPool2dOpIr` from the lowered input and `output_size`, registers it,
/// and returns the operation's single output.
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, new_handle);
    let op = OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(ir));
    client.register(op).output()
}
/// Registers the backward pass of 1D adaptive average pooling.
///
/// `x` and `grad` are lowered to IR (in that order), wrapped in an
/// `AdaptiveAvgPool1dBackwardOpIr`, and the single registered output is returned.
fn adaptive_avg_pool1d_backward(
    x: FloatTensor<Self>,
    grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let grad = grad.into_ir();
    let ir = AdaptiveAvgPool1dBackwardOpIr::create(input, grad, new_handle);
    let op = OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(ir));
    client.register(op).output()
}
/// Registers the backward pass of 2D adaptive average pooling.
///
/// `x` and `grad` are lowered to IR (in that order), wrapped in an
/// `AdaptiveAvgPool2dBackwardOpIr`, and the single registered output is returned.
fn adaptive_avg_pool2d_backward(
    x: FloatTensor<Self>,
    grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let grad = grad.into_ir();
    let ir = AdaptiveAvgPool2dBackwardOpIr::create(input, grad, new_handle);
    let op = OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(ir));
    client.register(op).output()
}
/// Registers an interpolation operation on the client taken from `x`.
///
/// Builds an `InterpolateOpIr` from the lowered input, `output_size` and the converted
/// `options`, registers it, and returns the operation's single output.
fn interpolate(
    x: FloatTensor<Self>,
    output_size: [usize; 2],
    options: InterpolateOptions,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let ir = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), new_handle);
    let op = OperationIr::Module(ModuleOperationIr::Interpolate(ir));
    client.register(op).output()
}
/// Registers the backward pass of interpolation on the client taken from `x`.
///
/// `x` and `grad` are lowered to IR (in that order) and combined with `output_size` and the
/// converted `options` into an `InterpolateBackwardOpIr`; the single output is returned.
fn interpolate_backward(
    x: FloatTensor<Self>,
    grad: FloatTensor<Self>,
    output_size: [usize; 2],
    options: InterpolateOptions,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let grad = grad.into_ir();
    let ir =
        InterpolateBackwardOpIr::create(input, grad, output_size, options.into(), new_handle);
    let op = OperationIr::Module(ModuleOperationIr::InterpolateBackward(ir));
    client.register(op).output()
}
/// Registers a 2D deformable convolution on the client taken from `x`.
///
/// All tensor operands (including the optional `mask` and `bias`) are lowered to IR in
/// parameter order. The resulting `DeformConv2dOpIr` is boxed before being wrapped in
/// `ModuleOperationIr::DeformableConv2d`, and the operation's single output is returned.
fn deform_conv2d(
    x: FloatTensor<Self>,
    offset: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    mask: Option<FloatTensor<Self>>,
    bias: Option<FloatTensor<Self>>,
    options: DeformConvOptions<2>,
) -> FloatTensor<Self> {
    let client = x.client.clone();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let offset = offset.into_ir();
    let kernel = weight.into_ir();
    let mask = mask.map(|m| m.into_ir());
    let bias = bias.map(|b| b.into_ir());
    let ir = DeformConv2dOpIr::create(
        input,
        offset,
        kernel,
        mask,
        bias,
        options.into(),
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(ir)));
    client.register(op).output()
}
/// Registers the backward pass of a 2D deformable convolution.
///
/// The registered operation yields a variable number of outputs. They are consumed in a fixed
/// order: input, offset and weight gradients first, then the mask gradient (only when `mask`
/// was provided) and the bias gradient (only when `bias` was provided).
///
/// # Panics
///
/// Panics if the registered operation yields fewer outputs than expected.
fn deform_conv2d_backward(
    x: FloatTensor<Self>,
    offset: FloatTensor<Self>,
    weight: FloatTensor<Self>,
    mask: Option<FloatTensor<Self>>,
    bias: Option<FloatTensor<Self>>,
    output_grad: FloatTensor<Self>,
    options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
    let client = x.client.clone();
    // Remember which optional operands exist before they are consumed below.
    let has_bias = bias.is_some();
    let has_mask = mask.is_some();
    let new_handle = || client.create_empty_handle();
    let input = x.into_ir();
    let offset = offset.into_ir();
    let kernel = weight.into_ir();
    let mask = mask.map(|m| m.into_ir());
    let bias = bias.map(|b| b.into_ir());
    let grad = output_grad.into_ir();
    let ir = DeformConv2dBackwardOpIr::create(
        input,
        offset,
        kernel,
        mask,
        bias,
        grad,
        options.into(),
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(ir)));
    let mut grads = client.register(op).into_iter();
    // When the number of outputs is variable, the order is important
    let input_grad = grads.next().unwrap();
    let offset_grad = grads.next().unwrap();
    let weight_grad = grads.next().unwrap();
    let mask_grad = has_mask.then(|| grads.next().unwrap());
    let bias_grad = has_bias.then(|| grads.next().unwrap());
    DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
}
/// Registers an attention operation on the client taken from `query`.
///
/// `query`, `key`, `value`, the optional boolean `mask` and the optional `attn_bias` are
/// lowered to IR in parameter order and combined with the converted `options` into an
/// `AttentionOpIr`; the operation's single output is returned.
fn attention(
    query: FloatTensor<Self>,
    key: FloatTensor<Self>,
    value: FloatTensor<Self>,
    mask: Option<BoolTensor<Self>>,
    attn_bias: Option<FloatTensor<Self>>,
    options: AttentionModuleOptions,
) -> FloatTensor<Self> {
    let client = query.client.clone();
    let new_handle = || client.create_empty_handle();
    let query = query.into_ir();
    let key = key.into_ir();
    let value = value.into_ir();
    let mask = mask.map(|m: BoolTensor<Self>| m.into_ir());
    let attn_bias = attn_bias.map(|ab| ab.into_ir());
    let ir = AttentionOpIr::create(
        query,
        key,
        value,
        mask,
        attn_bias,
        options.into(),
        new_handle,
    );
    let op = OperationIr::Module(ModuleOperationIr::Attention(ir));
    client.register(op).output()
}
}

View File

@@ -0,0 +1,92 @@
use burn_backend::{
ExecutionError, Shape, Slice, TensorData,
ops::QTensorOps,
quantization::{QuantScheme, QuantizationParametersPrimitive},
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
};
use crate::{BackendRouter, RunnerChannel};
// Quantized tensor operations for the router backend.
//
// Every method below is a stub: each body is `unimplemented!()`, so calling any of them
// panics. All parameters are underscore-prefixed because none of them is read.
impl<R: RunnerChannel> QTensorOps<Self> for BackendRouter<R> {
// Stub: panics with `unimplemented!()`.
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_to_device(
_tensor: QuantizedTensor<Self>,
_device: &Device<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()` when the returned future is polled.
async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_swap_dims(
_tensor: QuantizedTensor<Self>,
_dim1: usize,
_dim2: usize,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_gather(
_dim: usize,
_tensor: QuantizedTensor<Self>,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_select(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
unimplemented!()
}
// Stub: panics with `unimplemented!()`.
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
use burn_backend::ops::TransactionOps;
use crate::{BackendRouter, RunnerChannel};
// The impl body is empty: the router backend overrides nothing and therefore uses whatever
// default method implementations `TransactionOps` provides.
impl<R: RunnerChannel> TransactionOps<Self> for BackendRouter<R> {}

View File

@@ -0,0 +1,155 @@
/// Applies `$ops` to the float tensor referenced by `$desc.lhs` and to `$desc.rhs` converted
/// with `.into()`, then registers the result as a float tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs.into());
        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Applies `$ops` to the float tensor referenced by `$desc.lhs` and to `$desc.rhs` passed
/// as-is (no conversion), then registers the result as a float tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);
        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Applies `$ops` to the float tensor referenced by `$desc.input` and to `$desc.axis`, then
/// registers the result as a float tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_float_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.input);
        let reduced = $ops(tensor, $desc.axis);
        $handles.register_float_tensor::<B>(&$desc.out.id, reduced);
    }};
}
/// Applies `$ops` to the float tensor referenced by `$desc.input` and to `$desc.axis`, then
/// registers the result as an int tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_float2int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.input);
        let reduced = $ops(tensor, $desc.axis);
        $handles.register_int_tensor::<B>(&$desc.out.id, reduced);
    }};
}
/// Applies `$ops` to the int tensor referenced by `$desc.input` and to `$desc.axis`, then
/// registers the result as an int tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.input);
        let reduced = $ops(tensor, $desc.axis);
        $handles.register_int_tensor::<B>(&$desc.out.id, reduced);
    }};
}
/// Applies `$ops` to the float tensor referenced by `$desc.lhs` and to `$desc.rhs` passed
/// as-is (no conversion), then registers the result as an int tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float2int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);
        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Applies `$ops` to the float tensor referenced by `$desc.lhs` and to `$desc.rhs` converted
/// with `.into()`, then registers the result as a bool tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let verdict = $ops(tensor, $desc.rhs.into());
        $handles.register_bool_tensor::<B>(&$desc.out.id, verdict);
    }};
}
/// Applies the unary `$ops` to the float tensor referenced by `$desc.input`, then registers
/// the result as a float tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let input = $handles.get_float_tensor::<B>(&$desc.input);
        let result = $ops(input);
        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Applies `$ops` to the int tensor referenced by `$desc.lhs` and to `$desc.rhs` converted
/// with `.into()`, then registers the result as an int tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs.into());
        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Applies `$ops` to the int tensor referenced by `$desc.lhs` and to `$desc.rhs` passed as-is
/// (no conversion), then registers the result as an int tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);
        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
/// Applies `$ops` to the int tensor referenced by `$desc.lhs` and to `$desc.rhs` converted
/// with `.into()`, then registers the result as a bool tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let verdict = $ops(tensor, $desc.rhs.into());
        $handles.register_bool_tensor::<B>(&$desc.out.id, verdict);
    }};
}
/// Applies the unary `$ops` to the int tensor referenced by `$desc.input`, then registers the
/// result as an int tensor under `$desc.out.id`.
///
/// A type named `B` must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let input = $handles.get_int_tensor::<B>(&$desc.input);
        let result = $ops(input);
        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,142 @@
use core::sync::atomic::{AtomicU32, Ordering};
use alloc::format;
use alloc::{sync::Arc, vec::Vec};
use super::RunnerClient;
use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};
use burn_ir::{TensorId, TensorIr, TensorStatus};
/// Tensor primitive for the [router backend](crate::BackendRouter).
///
/// Clones share the `count` counter: `Clone` increments it and `Drop` decrements it, and its
/// value decides whether the tensor is described as `ReadWrite` or `ReadOnly` in the IR.
pub struct RouterTensor<C: RunnerClient> {
/// Identifier copied into every `TensorIr` built from this tensor.
pub(crate) id: TensorId,
/// Shape reported through `TensorMetadata` and copied/moved into the IR.
pub(crate) shape: Shape,
/// Element type reported through `TensorMetadata` and copied into the IR.
pub(crate) dtype: DType,
/// The client that has this tensor
pub client: C,
/// Counter shared between clones of this tensor; starts at 1 in `RouterTensor::new`.
pub(crate) count: Arc<AtomicU32>,
}
impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
    /// Returns the element type stored on the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Returns a copy of the stored shape.
    fn shape(&self) -> Shape {
        Clone::clone(&self.shape)
    }

    /// Returns the number of dimensions of the stored shape.
    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}
impl<C: RunnerClient> RouterTensor<C> {
    /// Create a new router tensor.
    ///
    /// The shared counter starts at 1 for the tensor being created.
    pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {
        let count = Arc::new(AtomicU32::new(1));
        Self {
            id,
            shape,
            dtype,
            client,
            count,
        }
    }

    /// Consumes the tensor and reads its data through a clone of its client.
    pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {
        let client = self.client.clone();
        client.read_tensor_async(self.into_ir()).await
    }

    /// Get the ir for this tensor
    ///
    /// The status is derived from the current counter value; the shape is moved out of the
    /// tensor and replaced by an empty one.
    pub fn into_ir(mut self) -> TensorIr {
        let status = self.status(self.count.load(Ordering::Relaxed));
        let shape = core::mem::replace(&mut self.shape, Shape::from(Vec::<usize>::new()));
        if matches!(status, TensorStatus::ReadWrite) {
            // Avoids an unwanted drop on the same thread.
            //
            // Since `drop` is called after `into_ir`, we must not register a drop if the tensor
            // was consumed with a `ReadWrite` status.
            self.count.fetch_add(1, Ordering::Relaxed);
        }
        TensorIr {
            id: self.id,
            shape,
            dtype: self.dtype,
            status,
        }
    }

    /// Describes this tensor as a not-yet-initialized output (`TensorStatus::NotInit`).
    pub(crate) fn to_ir_out(&self) -> TensorIr {
        let shape = self.shape.clone();
        TensorIr {
            id: self.id,
            shape,
            dtype: self.dtype,
            status: TensorStatus::NotInit,
        }
    }

    /// Maps a counter value to a status: 0 or 1 is `ReadWrite`, anything larger is `ReadOnly`.
    pub(crate) fn status(&self, count: u32) -> TensorStatus {
        match count {
            0 | 1 => TensorStatus::ReadWrite,
            _ => TensorStatus::ReadOnly,
        }
    }
}
impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
    /// Writes the tensor as `{ id: .., shape: .., dtype: .., device: .. }`.
    ///
    /// The device is obtained from the client. `device()` already returns an owned value, so it
    /// is formatted directly instead of being cloned first.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!(
            "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}",
            self.id,
            self.shape,
            self.dtype,
            self.client.device(),
        );
        f.write_str(&rendered)
    }
}
impl<C: RunnerClient> Clone for RouterTensor<C> {
fn clone(&self) -> Self {
self.count.fetch_add(1, Ordering::Relaxed);
Self {
id: self.id,
shape: self.shape.clone(),
client: self.client.clone(),
dtype: self.dtype,
count: self.count.clone(),
}
}
}
impl<C: RunnerClient> Drop for RouterTensor<C> {
    /// Decrements the shared counter and, when the value before the decrement maps to
    /// `ReadWrite`, registers an `OperationIr::Drop` for this tensor on its client.
    fn drop(&mut self) {
        let previous = self.count.fetch_sub(1, Ordering::Relaxed);
        if !matches!(self.status(previous), TensorStatus::ReadWrite) {
            // `ReadOnly` and `NotInit` require no action.
            return;
        }
        // Move the shape out, leaving an empty one behind.
        let shape = core::mem::replace(&mut self.shape, Shape::from(Vec::<usize>::new()));
        let ir = TensorIr {
            id: self.id,
            shape,
            status: TensorStatus::ReadWrite,
            dtype: self.dtype,
        };
        self.client.register_op(burn_ir::OperationIr::Drop(ir));
    }
}

View File

@@ -0,0 +1,386 @@
use alloc::format;
use alloc::string::String;
use burn_backend::{
DType, Shape, TensorData,
backend::{Backend, DeviceId, DeviceOps, ExecutionError},
try_read_sync,
};
use burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr};
use burn_std::future::DynFut;
use crate::{
ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel,
RunnerClient,
};
/// Implement multi backend types, with enums having one variant per backend.
///
/// Arguments: the name of the module to generate, the identifier of the default backend, and
/// one or more identifiers for the other backends. The identifiers are used both as generic
/// type parameter names and as enum variant names.
macro_rules! impl_multi_backend_types {
// Match the default backend and at least one other backend, with rest being optional
($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => {
/// Module containing the essential types for multi-backend operations.
///
/// - `Handle`: the type used to point to a tensor (defined for all backends).
/// - `MultiRunnerClient`: a client for multiple runners (each responsible to execute tensor operations on a given backend).
/// - `DirectChannel`: a local channel with direct connection to the backend runner clients.
/// - `ByteBridge`: a simple multi-backend bridge that transfers tensors via the underlying [tensor data](burn_backend::TensorData).
///
/// Each enum type is defined with backend identifiers as variant names (e.g., `B1` and `B2` for dual backends).
pub mod $module_name {
use super::*;
/// The type that can be used to point to a tensor of any kind.
/// Each backend has its own variant.
pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
#[allow(missing_docs)]
$DefaultBackend($DefaultBackend::Handle),
$(
#[allow(missing_docs)]
$OtherBackend($OtherBackend::Handle),
)+
}
/// The device type used by a backend.
/// Each backend has its own variant.
#[derive(Clone, Debug)]
pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> {
#[allow(missing_docs)]
$DefaultBackend($DefaultBackend::Device),
$(
#[allow(missing_docs)]
$OtherBackend($OtherBackend::Device),
)+
}
// Two devices are equal only when they are the same variant and the wrapped devices compare equal.
impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs,
$(
(Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs,
)+
_ => false,
}
}
}
// Default implementation always returns the first backend's device
impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
fn default() -> Self {
Self::$DefaultBackend($DefaultBackend::Device::default())
}
}
impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
// NOTE(review): the id argument is ignored; the default device is always returned.
fn from_id(_device_id: DeviceId) -> Self {
// TODO: Should be fixed with the new router backend.
Default::default()
}
// Delegates to the `id()` of whichever backend device is wrapped.
fn to_id(&self) -> DeviceId {
match self {
Self::$DefaultBackend(device) => device.id(),
$(
Self::$OtherBackend(device) => device.id(),
)+
}
}
// Always reports a single device, regardless of the type id.
fn device_count(_type_id: u16) -> usize {
1
}
}
impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {}
/// A local client with multiple runners (each responsible to execute tensor operations on a given backend).
#[derive(Clone)]
pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
#[allow(missing_docs)]
$DefaultBackend(Runner<$DefaultBackend>),
$(
#[allow(missing_docs)]
$OtherBackend(Runner<$OtherBackend>),
)+
}
// Every method below forwards to the runner held by the active variant.
impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>
{
type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;
fn register_op(&self, op: OperationIr) {
match self {
Self::$DefaultBackend(runner) => runner.register_op(op),
$(
Self::$OtherBackend(runner) => runner.register_op(op),
)+
}
}
fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {
match self {
Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor),
$(
Self::$OtherBackend(runner) => runner.read_tensor_async(tensor),
)+
}
}
// The runner registers the data and returns a description; the router tensor is then
// built around a clone of this multi-backend client rather than the inner runner.
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
match self {
Self::$DefaultBackend(runner) => {
let desc = runner.register_tensor_data_desc(data);
RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
}
$(
Self::$OtherBackend(runner) => {
let desc = runner.register_tensor_data_desc(data);
RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
}
)+
}
}
// Wraps the runner's device in the `MultiDevice` variant of the same backend.
fn device(&self) -> Self::Device {
match self {
Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()),
$(
Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()),
)+
}
}
fn sync(&self) -> Result<(), ExecutionError> {
match self {
Self::$DefaultBackend(runner) => runner.sync(),
$(
Self::$OtherBackend(runner) => runner.sync(),
)+
}
}
fn seed(&self, seed: u64) {
match self {
Self::$DefaultBackend(runner) => runner.seed(seed),
$(
Self::$OtherBackend(runner) => runner.seed(seed),
)+
}
}
fn create_empty_handle(&self) -> TensorId {
match self {
Self::$DefaultBackend(runner) => runner.create_empty_handle(),
$(
Self::$OtherBackend(runner) => runner.create_empty_handle(),
)+
}
}
fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
match self {
Self::$DefaultBackend(runner) => runner.dtype_usage(dtype),
$(
Self::$OtherBackend(runner) => runner.dtype_usage(dtype),
)+
}
}
}
impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br>
where
Br: MultiBackendBridge<TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>,
{
type Device = Br::Device;
type Bridge = Br;
// The element types of the channel are those of the default backend.
type FloatElem = $DefaultBackend::FloatElem;
type IntElem = $DefaultBackend::IntElem;
type BoolElem = $DefaultBackend::BoolElem;
type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>;
// Creates a new runner for the backend matching the device variant.
fn init_client(device: &Self::Device) -> Self::Client {
match device {
MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())),
$(
MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())),
)+
}
}
fn get_tensor_handle(
tensor: &TensorIr,
client: &Self::Client,
) -> <Self::Bridge as MultiBackendBridge>::TensorHandle {
match client {
MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)),
$(
MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)),
)+
}
}
// The handle variant must match the client variant; any other combination hits
// `unreachable!` and panics.
fn register_tensor(
client: &Self::Client,
handle: <Self::Bridge as MultiBackendBridge>::TensorHandle,
shape: Shape,
dtype: DType,
) -> RouterTensor<Self::Client> {
match client {
MultiRunnerClient::$DefaultBackend(runner) => match handle {
Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
_ => unreachable!("Can't register tensor handle for another backend."),
},
$(
MultiRunnerClient::$OtherBackend(runner) => match handle {
Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
_ => unreachable!("Can't register tensor handle for another backend."),
},
)+
}
}
// Builds `direct<(name1, name2, ...)>` from each backend's name for its *default*
// device; the `_device` argument is not used.
fn name(_device: &Self::Device) -> String {
let mut name = format!("{}", $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default()));
$(
name.push_str(&format!(", {}", $OtherBackend::name(&<$OtherBackend::Device as Default>::default())));
)+
format!("direct<({})>", name)
}
}
impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> {
type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>;
type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;
fn change_backend_float(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
}
// NOTE(review): this expands to the same `multi_backend_match!`/`bridge!` code as the float
// variant, and `bridge!` only calls `float_*` backend functions — confirm this is intended
// for int tensors.
fn change_backend_int(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
}
// NOTE(review): same expansion as the float variant (see `bridge!`) — confirm this is
// intended for bool tensors.
fn change_backend_bool(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
}
}
}
};
}
/// Produces a `Handle` for a tensor moved to `$device`.
///
/// - With one backend identifier, the tensor stays on that backend and is moved with
///   `float_to_device`.
/// - With two backend identifiers, the tensor is read into tensor data on the first backend
///   and re-created from that data on the second.
///
/// Both arms use only the `float_*` functions of the backends.
macro_rules! bridge {
    ($Backend:ident, $handle:expr, $device:expr, $shape:expr) => {{
        // Bridge for the same backend
        let source = $Backend::float_tensor(TensorHandle {
            handle: $handle,
            shape: $shape,
        });
        let moved = $Backend::float_to_device(source, $device);
        Handle::$Backend($Backend::float_tensor_handle(moved))
    }};
    ($BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr) => {{
        // Byte bridge between two backends
        let source = $BackendA::float_tensor(TensorHandle { handle: $handle, shape: $shape });
        // NOTE(review): the bare `unwrap()` is applied to the result of `try_read_sync`, while
        // the descriptive message is attached to the inner value — confirm the message sits on
        // the intended layer.
        let bytes = try_read_sync($BackendA::float_into_data(source)).unwrap().expect(
            "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM."
        );
        let target = $BackendB::float_from_data(bytes, $device);
        Handle::$BackendB($BackendB::float_tensor_handle(target))
    }};
}
// Builds a `match` over `(handle, device)` with one arm per (source backend, target backend)
// pair, each arm expanding to `bridge!`.
//
// The arms are accumulated recursively through the internal `@step` rules:
// 1. the entry rule emits the default/default arm plus, for every other backend, the
//    default<->other arms and the other/other (same backend) arm;
// 2. the first `@step` rule peels the head of the remaining list and emits the cross arms
//    between it and every backend still in the list, then recurses on the rest;
// 3. the last `@step` rule (zero or one backend left) emits the final `match`.
macro_rules! multi_backend_match {
// Entry point: `shape, (handle, device) : Default, Other1, Other2, ...`
($shape:expr, ($handle:expr, $device:expr) : $DefaultBackend:ident, $($OtherBackend:ident),+) => {
multi_backend_match! (
@step
$shape,
($handle, $device);
{
(Handle::$DefaultBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($DefaultBackend, handle, device, $shape),
$(
(Handle::$DefaultBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($DefaultBackend, $OtherBackend, handle, device, $shape),
(Handle::$OtherBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($OtherBackend, $DefaultBackend, handle, device, $shape),
(Handle::$OtherBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($OtherBackend, handle, device, $shape),
)+
};
$($OtherBackend),+
)
};
// Recursive step: at least two backends remain; pair the head with each of the others.
(@step
$shape:expr,
$pats:tt;
{ $($arms:tt)* };
$BackendA:ident,
$($OtherBackend:ident),+
) => {
multi_backend_match! (
@step
$shape,
$pats;
{
$($arms)*
$(
(Handle::$BackendA(handle), MultiDevice::$OtherBackend(device)) => bridge!($BackendA, $OtherBackend, handle, device, $shape),
(Handle::$OtherBackend(handle), MultiDevice::$BackendA(device)) => bridge!($OtherBackend, $BackendA, handle, device, $shape),
)*
};
$($OtherBackend),*
)
};
// Terminal step: no pairs left to generate; emit the match with the accumulated arms.
(@step
$shape:expr,
($handle:expr, $device:expr);
{ $($arms:tt)* };
$($BackendA:ident)?
) => {
match ($handle, $device) {
$($arms)*
}
};
}
// Implement multi-backend types and byte bridge for up to 4 backends.
// Each invocation generates one module (`duo`, `trio`, `quad`) whose generic parameters and
// enum variants are named after the listed identifiers (`B1`, `B2`, ...); the first one is
// the default backend.
impl_multi_backend_types!(duo, B1, B2);
impl_multi_backend_types!(trio, B1, B2, B3);
impl_multi_backend_types!(quad, B1, B2, B3, B4);
#[cfg(not(target_os = "windows"))] // cannot find a wgpu adapter on windows CI
#[cfg(test)]
mod tests {
    use burn_tensor::{Tensor, backend::Backend};

    use super::*;
    use crate::tests::{TestBackend, TestBackend1, TestBackend2};

    /// Moves a tensor from each of the two backends to the other one and checks that the data
    /// read back is equal (strict comparison) to the data of the original tensor.
    #[test]
    fn should_support_dual_byte_bridge() {
        let first_device = duo::MultiDevice::B1(<TestBackend1 as Backend>::Device::default());
        let second_device = duo::MultiDevice::B2(<TestBackend2 as Backend>::Device::default());

        let on_first = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &first_device);
        let on_second =
            Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &second_device);

        // First backend -> second backend.
        let moved_to_second = on_first.clone().to_device(&second_device);
        let expected = on_first.into_data();
        let actual = moved_to_second.into_data();
        expected.assert_eq(&actual, true);

        // Second backend -> first backend.
        let moved_to_first = on_second.clone().to_device(&first_device);
        let expected = on_second.into_data();
        let actual = moved_to_first.into_data();
        expected.assert_eq(&actual, true);
    }
}