feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
[package]
|
||||
authors = [
|
||||
"laggui <lagrange.guillaume.1@gmail.com>",
|
||||
"nathanielsimard <nathaniel.simard.42@gmail.com>",
|
||||
]
|
||||
categories = ["science"]
|
||||
description = "Multi-backend router decorator for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-router"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router"
|
||||
documentation = "https://docs.rs/burn-router"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["burn-backend/std", "burn-std/std", "burn-ir/std"]
|
||||
doc = ["default"]
|
||||
tracing = [
|
||||
"burn-backend/tracing",
|
||||
"burn-ir/tracing",
|
||||
"burn-std/tracing",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
hashbrown = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", default-features = false, features = [
|
||||
"std",
|
||||
] }
|
||||
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Burn Router
|
||||
|
||||
A multi-backend extension that forwards the tensor operations to the appropriate backend.
|
||||
@@ -0,0 +1,75 @@
|
||||
use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};
|
||||
use alloc::{format, string::String};
|
||||
use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).
|
||||
pub struct BackendRouter<R: RunnerChannel> {
|
||||
r: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
    /// Writes the fixed label `router`; the channel type is not part of the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("router")
    }
}
|
||||
|
||||
impl<R: RunnerChannel> Clone for BackendRouter<R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self { r: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RunnerChannel> Default for BackendRouter<R> {
|
||||
fn default() -> Self {
|
||||
Self { r: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
|
||||
fn scheme(&self) -> &QuantScheme {
|
||||
if let DType::QFloat(scheme) = &self.dtype {
|
||||
scheme
|
||||
} else {
|
||||
// TODO: maybe `tensor.scheme()` should return an option
|
||||
panic!("Expected quantized float dtype, got {:?}", self.dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RunnerChannel> Backend for BackendRouter<R> {
|
||||
type Device = R::Device;
|
||||
|
||||
type FloatTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type FloatElem = R::FloatElem;
|
||||
|
||||
type IntTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type IntElem = R::IntElem;
|
||||
|
||||
type BoolTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type BoolElem = R::BoolElem;
|
||||
|
||||
type QuantizedTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
format!("router<{}>", R::name(device))
|
||||
}
|
||||
|
||||
fn seed(device: &Self::Device, seed: u64) {
|
||||
let client = get_client::<R>(device);
|
||||
client.seed(seed);
|
||||
}
|
||||
|
||||
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
|
||||
let client = get_client::<R>(device);
|
||||
client.sync()
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
|
||||
let client = get_client::<R>(device);
|
||||
client.dtype_usage(dtype)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use burn_backend::{Shape, backend::DeviceOps};
|
||||
|
||||
/// Allows tensors to be transferred between multiple backends.
|
||||
pub trait MultiBackendBridge: Send + Sync + 'static {
|
||||
/// The type that can be used to point to a tensor of any kind.
|
||||
type TensorHandle;
|
||||
/// Device type used by the backends.
|
||||
type Device: DeviceOps;
|
||||
|
||||
/// Change the backend of the given float tensor.
|
||||
fn change_backend_float(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle;
|
||||
|
||||
/// Change the backend of the given int tensor.
|
||||
fn change_backend_int(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle;
|
||||
|
||||
/// Change the backend of the given bool tensor.
|
||||
fn change_backend_bool(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle;
|
||||
|
||||
// TODO: change_backend_quantized
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Simply transfers tensors between backends via the underlying
/// [tensor data](burn_backend::TensorData).
///
/// The struct holds no runtime data; `Backends` is tracked at the type level only.
pub struct ByteBridge<Backends> {
    // Type-level marker for the set of backends.
    backends: PhantomData<Backends>,
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod byte;
|
||||
|
||||
pub use base::*;
|
||||
pub use byte::*;
|
||||
@@ -0,0 +1,66 @@
|
||||
use alloc::string::String;
|
||||
use burn_backend::{DType, Element, Shape, backend::DeviceOps};
|
||||
use burn_ir::TensorIr;
|
||||
|
||||
use crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
|
||||
|
||||
/// Type alias for `<Br as MultiBackendBridge>::TensorHandle`.
|
||||
pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
|
||||
|
||||
/// Defines the connection channel and operations for a setup with multiple backend runner clients.
|
||||
pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
|
||||
/// Device type.
|
||||
type Device: DeviceOps;
|
||||
/// A bridge that can transfer tensors between multiple backends.
|
||||
type Bridge: MultiBackendBridge<Device = Self::Device>;
|
||||
/// Client type.
|
||||
type Client: RunnerClient<Device = Self::Device>;
|
||||
/// Float element type.
|
||||
type FloatElem: Element;
|
||||
/// Int element type.
|
||||
type IntElem: Element;
|
||||
/// Bool element type.
|
||||
type BoolElem: Element;
|
||||
|
||||
/// Name of the channel.
|
||||
fn name(device: &Self::Device) -> String;
|
||||
|
||||
/// Initialize a new client for the given device.
|
||||
fn init_client(device: &Self::Device) -> Self::Client;
|
||||
|
||||
/// Get the tensor handle corresponding to the [tensor representation](TensorIr).
|
||||
fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge>;
|
||||
|
||||
/// Create a tensor with the given handle and shape.
|
||||
fn register_tensor(
|
||||
client: &Self::Client,
|
||||
handle: TensorHandle<Self::Bridge>,
|
||||
shape: Shape,
|
||||
dtype: DType,
|
||||
) -> RouterTensor<Self::Client>;
|
||||
|
||||
/// Change the tensor to a different client backend.
|
||||
fn change_client_backend(
|
||||
tensor: RouterTensor<Self::Client>,
|
||||
device: &Self::Device, // target device
|
||||
) -> RouterTensor<Self::Client> {
|
||||
// Get tensor handle from current client
|
||||
let original_client = tensor.client.clone();
|
||||
let desc = tensor.into_ir();
|
||||
let mut handle = Self::get_tensor_handle(&desc, &original_client);
|
||||
|
||||
if desc.dtype.is_float() {
|
||||
handle = Self::Bridge::change_backend_float(handle, desc.shape.clone(), device);
|
||||
} else if desc.dtype.is_int() {
|
||||
handle = Self::Bridge::change_backend_int(handle, desc.shape.clone(), device);
|
||||
} else if desc.dtype.is_bool() {
|
||||
handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone(), device);
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// Register tensor handle on target client
|
||||
let target_client = get_client::<Self>(device);
|
||||
Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// A local channel with direct connection to the backend runner clients.
///
/// Both type parameters are tracked at the type level only; the struct stores no runtime data.
pub struct DirectChannel<Backends, Bridge> {
    // Marker for the set of backends.
    backends: PhantomData<Backends>,
    // Marker for the bridge type.
    bridge: PhantomData<Bridge>,
}
|
||||
|
||||
impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
backends: self.backends,
|
||||
bridge: self.bridge,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod direct;
|
||||
|
||||
pub use base::*;
|
||||
pub use direct::*;
|
||||
@@ -0,0 +1,128 @@
|
||||
use crate::{RouterTensor, RunnerChannel};
|
||||
use alloc::boxed::Box;
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::{
|
||||
DType, TensorData,
|
||||
backend::{DeviceId, DeviceOps, ExecutionError},
|
||||
};
|
||||
use burn_ir::{OperationIr, TensorId, TensorIr};
|
||||
use burn_std::future::DynFut;
|
||||
use core::ops::DerefMut;
|
||||
use hashbrown::HashMap;
|
||||
use spin::Mutex;
|
||||
|
||||
/// Type alias for `<R as RunnerChannel>::Client`.
|
||||
pub type Client<R> = <R as RunnerChannel>::Client;
|
||||
pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();
|
||||
|
||||
type Key = (core::any::TypeId, DeviceId);
|
||||
|
||||
/// Define how to interact with the runner.
|
||||
pub trait RunnerClient: Clone + Send + Sync + Sized {
|
||||
/// Device type.
|
||||
type Device: DeviceOps;
|
||||
|
||||
/// Register a new tensor operation to be executed by the (runner) server.
|
||||
fn register_op(&self, op: OperationIr);
|
||||
/// Register a new tensor operation to be executed by the (runner) server.
|
||||
///
|
||||
/// Returns the new (uninitialized) output tensor(s) generated by the registered operation.
|
||||
fn register(&self, op: OperationIr) -> Vec<RouterTensor<Self>> {
|
||||
let out = op
|
||||
.outputs()
|
||||
.map(|output| {
|
||||
RouterTensor::new(output.id, output.shape.clone(), output.dtype, self.clone())
|
||||
})
|
||||
.collect();
|
||||
self.register_op(op);
|
||||
|
||||
out
|
||||
}
|
||||
/// Read the values contained by a tensor.
|
||||
fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>>;
|
||||
/// Sync the runner, ensure that all computations are finished.
|
||||
fn sync(&self) -> Result<(), ExecutionError>;
|
||||
/// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId).
|
||||
fn create_empty_handle(&self) -> TensorId;
|
||||
/// Create a new [RouterTensor] from the tensor data.
|
||||
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;
|
||||
/// Get the current device used by all operations handled by this client.
|
||||
fn device(&self) -> Self::Device;
|
||||
/// Seed the runner.
|
||||
fn seed(&self, seed: u64);
|
||||
/// Returns the supported data type usage set
|
||||
fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet;
|
||||
}
|
||||
|
||||
pub(crate) struct RunnerClientLocator {
|
||||
clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,
|
||||
}
|
||||
|
||||
/// Get the client for the given device
|
||||
pub fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
|
||||
CLIENTS.client::<R>(device)
|
||||
}
|
||||
|
||||
/// Initialize a new client for the given device.
|
||||
///
|
||||
/// If a (global) seed was previously set, the client seed is set.
|
||||
fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
|
||||
R::init_client(device)
|
||||
}
|
||||
|
||||
impl RunnerClientLocator {
|
||||
/// Create a new client locator.
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
clients: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the runner client for the given device.
|
||||
///
|
||||
/// If a client isn't already initialized, it is created.
|
||||
pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
|
||||
let device_id = device.id();
|
||||
let client_id = (core::any::TypeId::of::<R>(), device_id);
|
||||
let mut clients = self.clients.lock();
|
||||
|
||||
if clients.is_none() {
|
||||
let client = new_client::<R>(device);
|
||||
Self::register_inner::<R>(client_id, client, &mut clients);
|
||||
}
|
||||
|
||||
match clients.deref_mut() {
|
||||
Some(clients) => match clients.get(&client_id) {
|
||||
Some(client) => {
|
||||
let client: &Client<R> = client.downcast_ref().unwrap();
|
||||
client.clone()
|
||||
}
|
||||
None => {
|
||||
let client = new_client::<R>(device);
|
||||
let any = Box::new(client.clone());
|
||||
clients.insert(client_id, any);
|
||||
client
|
||||
}
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_inner<R: RunnerChannel + 'static>(
|
||||
key: Key,
|
||||
client: Client<R>,
|
||||
clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,
|
||||
) {
|
||||
if clients.is_none() {
|
||||
*clients = Some(HashMap::new());
|
||||
}
|
||||
|
||||
if let Some(clients) = clients {
|
||||
if clients.contains_key(&key) {
|
||||
panic!("Client already created for device {key:?}");
|
||||
}
|
||||
|
||||
clients.insert(key, Box::new(client));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,49 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![recursion_limit = "138"]
|
||||
|
||||
//! Burn multi-backend router.
|
||||
|
||||
mod backend;
|
||||
mod bridge;
|
||||
mod channel;
|
||||
mod client;
|
||||
mod ops;
|
||||
mod runner;
|
||||
mod tensor;
|
||||
mod types;
|
||||
|
||||
pub use backend::*;
|
||||
pub use bridge::*;
|
||||
pub use channel::*;
|
||||
pub use client::*;
|
||||
pub use runner::*;
|
||||
pub use tensor::*;
|
||||
pub use types::*;
|
||||
|
||||
/// A local channel with a simple byte bridge between backends.
|
||||
/// It transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData).
|
||||
pub type DirectByteChannel<Backends> = DirectChannel<Backends, ByteBridge<Backends>>;
|
||||
|
||||
/// Router backend.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// type MyBackend = Router<(NdArray, Wgpu)>;
|
||||
/// ```
|
||||
pub type Router<Backends> = BackendRouter<DirectByteChannel<Backends>>;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
#[cfg(test)]
#[allow(unused)]
mod tests {
    use crate::{BackendRouter, DirectByteChannel};

    // First routed backend used by the tests.
    pub type TestBackend1 = burn_ndarray::NdArray<f32, i32>;
    // Second routed backend used by the tests.
    pub type TestBackend2 = burn_wgpu::Wgpu<f32, i32>;
    // Router over both test backends, connected by the direct byte channel.
    pub type TestBackend = BackendRouter<DirectByteChannel<(TestBackend1, TestBackend2)>>;
}
|
||||
@@ -0,0 +1,4 @@
|
||||
use crate::{BackendRouter, RunnerChannel};
|
||||
use burn_backend::ops::ActivationOps;
|
||||
|
||||
// No overrides: the router relies entirely on the trait's provided methods.
impl<R: RunnerChannel> ActivationOps<Self> for BackendRouter<R> {}
|
||||
@@ -0,0 +1,69 @@
|
||||
// Float x float -> float: fetches both operands from `$handles`, applies `$ops`, and registers
// the result as a float tensor under `$desc.out.id`. The backend type `B` must be in scope at
// the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_float_tensor::<B>(&$desc.lhs);
        let right = $handles.get_float_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
// Float x float -> bool: fetches both float operands from `$handles`, applies `$ops`, and
// registers the result as a bool tensor under `$desc.out.id`. The backend type `B` must be in
// scope at the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_float_tensor::<B>(&$desc.lhs);
        let right = $handles.get_float_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
// Int x int -> int: fetches both operands from `$handles`, applies `$ops`, and registers the
// result as an int tensor under `$desc.out.id`. The backend type `B` must be in scope at the
// call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_int_tensor::<B>(&$desc.lhs);
        let right = $handles.get_int_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
// Int x int -> bool: fetches both int operands from `$handles`, applies `$ops`, and registers
// the result as a bool tensor under `$desc.out.id`. The backend type `B` must be in scope at
// the call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_int_tensor::<B>(&$desc.lhs);
        let right = $handles.get_int_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
// Bool x bool -> bool: fetches both operands from `$handles`, applies `$ops`, and registers the
// result as a bool tensor under `$desc.out.id`. The backend type `B` must be in scope at the
// call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_bool_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_bool_tensor::<B>(&$desc.lhs);
        let right = $handles.get_bool_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
@@ -0,0 +1,333 @@
|
||||
use alloc::vec::Vec;
|
||||
use burn_backend::backend::ExecutionError;
|
||||
|
||||
use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
|
||||
use burn_backend::ops::BoolTensorOps;
|
||||
use burn_backend::tensor::{
|
||||
BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
|
||||
};
|
||||
use burn_backend::{Element, Scalar, Shape, Slice, TensorData};
|
||||
use burn_ir::{
|
||||
BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
|
||||
GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,
|
||||
PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr,
|
||||
SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
|
||||
};
|
||||
|
||||
impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
|
||||
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let client = get_client::<R>(device);
|
||||
let desc =
|
||||
CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let client = get_client::<R>(device);
|
||||
let desc =
|
||||
CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let client = get_client::<R>(device);
|
||||
let desc =
|
||||
CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
tensor.into_data().await
|
||||
}
|
||||
|
||||
fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let client = get_client::<R>(device);
|
||||
let out = client.register_tensor_data(data);
|
||||
let desc = InitOperationIr {
|
||||
out: out.to_ir_out(),
|
||||
};
|
||||
|
||||
// Call register op when output is already initialized
|
||||
client.register_op(OperationIr::Init(desc));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
|
||||
tensor.client.device()
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
if &tensor.client.device() == device {
|
||||
return tensor;
|
||||
}
|
||||
R::change_client_backend(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc =
|
||||
SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let client = lhs.client.clone();
|
||||
let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::Bool(BoolOperationIr::Not(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let client = lhs.client.clone();
|
||||
let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Bool(BoolOperationIr::And(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let client = lhs.client.clone();
|
||||
let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Bool(BoolOperationIr::Or(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
|
||||
let client = tensors.first().unwrap().client.clone();
|
||||
let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
|
||||
let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let value = value.into();
|
||||
let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let desc = ScatterOpIr::create(
|
||||
tensor.into_ir(),
|
||||
dim,
|
||||
indices.into_ir(),
|
||||
value.into_ir(),
|
||||
IndexingUpdateOp::Add,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
let client = lhs.client.clone();
|
||||
let rhs = rhs.into();
|
||||
let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))
|
||||
.output()
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,9 @@
|
||||
mod activation;
|
||||
mod binary;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
mod unary;
|
||||
@@ -0,0 +1,796 @@
|
||||
use alloc::boxed::Box;
|
||||
|
||||
use burn_backend::Element;
|
||||
use burn_backend::ops::{
|
||||
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
|
||||
DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices,
|
||||
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
};
|
||||
use burn_backend::tensor::{BoolTensor, FloatTensor, IntElem, IntTensor};
|
||||
use burn_ir::*;
|
||||
|
||||
use crate::{BackendRouter, RunnerChannel, RunnerClient};
|
||||
|
||||
impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
|
||||
fn conv1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv1dOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv1d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv1d_x_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv1dXBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv1dXBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv1d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv1dWeightBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::Conv1dWeightBackward(desc),
|
||||
))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv1d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv1dBiasBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
bias.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv2dOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv2d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv2d_x_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv2dXBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv2dXBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv2d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv2dWeightBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::Conv2dWeightBackward(desc),
|
||||
))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv2d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv2dBiasBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
bias.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv3dOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv3d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv3d_x_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv3dXBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv3dXBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv3d_weight_backward(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv3dWeightBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::Conv3dWeightBackward(desc),
|
||||
))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv3d_bias_backward(
|
||||
x: FloatTensor<Self>,
|
||||
bias: FloatTensor<Self>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = Conv3dBiasBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
bias.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = ConvTranspose1dOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = ConvTranspose2dOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = ConvTranspose3dOpIr::create(
|
||||
x.into_ir(),
|
||||
weight.into_ir(),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn avg_pool1d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = AvgPool1dOpIr::create(
|
||||
x.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = AvgPool2dOpIr::create(
|
||||
x.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn avg_pool1d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = AvgPool1dBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
grad.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = AvgPool2dBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
grad.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn max_pool1d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = MaxPool1dOpIr::create(
|
||||
x.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = MaxPool2dOpIr::create(
|
||||
x.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool1dWithIndices<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = MaxPool1dWithIndicesOpIr::create(
|
||||
x.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
IntElem::<Self>::dtype(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
let [out, out_indices] = client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::MaxPool1dWithIndices(desc),
|
||||
))
|
||||
.outputs();
|
||||
|
||||
MaxPool1dWithIndices::new(out, out_indices)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool2dWithIndices<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = MaxPool2dWithIndicesOpIr::create(
|
||||
x.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
IntElem::<Self>::dtype(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
let [out, out_indices] = client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::MaxPool2dWithIndices(desc),
|
||||
))
|
||||
.outputs();
|
||||
|
||||
MaxPool2dWithIndices::new(out, out_indices)
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool1dBackward<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let desc = MaxPool1dWithIndicesBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
indices.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
let out = client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::MaxPool1dWithIndicesBackward(desc),
|
||||
))
|
||||
.output();
|
||||
|
||||
MaxPool1dBackward::new(out)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool2dBackward<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let desc = MaxPool2dWithIndicesBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
output_grad.into_ir(),
|
||||
indices.into_ir(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
let out = client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::MaxPool2dWithIndicesBackward(desc),
|
||||
))
|
||||
.output();
|
||||
|
||||
MaxPool2dBackward::new(out)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::AdaptiveAvgPool1dBackward(desc),
|
||||
))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::AdaptiveAvgPool2dBackward(desc),
|
||||
))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
x: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Interpolate(desc)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn interpolate_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = InterpolateBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
grad.into_ir(),
|
||||
output_size,
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::InterpolateBackward(
|
||||
desc,
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let desc = DeformConv2dOpIr::create(
|
||||
x.into_ir(),
|
||||
offset.into_ir(),
|
||||
weight.into_ir(),
|
||||
mask.map(|mask| mask.into_ir()),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::DeformableConv2d(
|
||||
Box::new(desc),
|
||||
)))
|
||||
.output()
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
let client = x.client.clone();
|
||||
let has_bias = bias.is_some();
|
||||
let has_mask = mask.is_some();
|
||||
|
||||
let desc = DeformConv2dBackwardOpIr::create(
|
||||
x.into_ir(),
|
||||
offset.into_ir(),
|
||||
weight.into_ir(),
|
||||
mask.map(|mask| mask.into_ir()),
|
||||
bias.map(|bias| bias.into_ir()),
|
||||
output_grad.into_ir(),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
let mut outputs = client
|
||||
.register(OperationIr::Module(
|
||||
ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)),
|
||||
))
|
||||
.into_iter();
|
||||
|
||||
// When the number of outputs is variable, the order is important
|
||||
let input_grad = outputs.next().unwrap();
|
||||
let offset_grad = outputs.next().unwrap();
|
||||
let weight_grad = outputs.next().unwrap();
|
||||
let mask_grad = has_mask.then(|| outputs.next().unwrap());
|
||||
let bias_grad = has_bias.then(|| outputs.next().unwrap());
|
||||
|
||||
DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
|
||||
}
|
||||
|
||||
fn attention(
|
||||
query: FloatTensor<Self>,
|
||||
key: FloatTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
mask: Option<BoolTensor<Self>>,
|
||||
attn_bias: Option<FloatTensor<Self>>,
|
||||
options: AttentionModuleOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = query.client.clone();
|
||||
let desc = AttentionOpIr::create(
|
||||
query.into_ir(),
|
||||
key.into_ir(),
|
||||
value.into_ir(),
|
||||
mask.map(|m: BoolTensor<Self>| m.into_ir()),
|
||||
attn_bias.map(|ab| ab.into_ir()),
|
||||
options.into(),
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(OperationIr::Module(ModuleOperationIr::Attention(desc)))
|
||||
.output()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
use burn_backend::{
|
||||
ExecutionError, Shape, Slice, TensorData,
|
||||
ops::QTensorOps,
|
||||
quantization::{QuantScheme, QuantizationParametersPrimitive},
|
||||
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
|
||||
use crate::{BackendRouter, RunnerChannel};
|
||||
|
||||
// Quantized-tensor operations for the router backend.
//
// None of these operations is implemented: every method below panics through
// `unimplemented!()` as soon as it is called (for `q_into_data`, when the returned future is
// polled). Arguments are prefixed with `_` because they are never read.
impl<R: RunnerChannel> QTensorOps<Self> for BackendRouter<R> {
    // Not implemented: panics when called.
    fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn quantize(
        _tensor: FloatTensor<Self>,
        _scheme: &QuantScheme,
        _qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn quantize_dynamic(
        _tensor: FloatTensor<Self>,
        _scheme: &QuantScheme,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_to_device(
        _tensor: QuantizedTensor<Self>,
        _device: &Device<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: the returned future panics when polled.
    async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_swap_dims(
        _tensor: QuantizedTensor<Self>,
        _dim1: usize,
        _dim2: usize,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Not implemented: panics when called.
    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,5 @@
|
||||
use burn_backend::ops::TransactionOps;
|
||||
|
||||
use crate::{BackendRouter, RunnerChannel};
|
||||
|
||||
// Empty impl: the router backend overrides nothing and uses whatever `TransactionOps` provides
// by default.
impl<R: RunnerChannel> TransactionOps<Self> for BackendRouter<R> {}
|
||||
@@ -0,0 +1,155 @@
|
||||
/// Runs a float tensor-with-scalar operation described by `$desc`.
///
/// Fetches the float tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs.into())` and registers the result as a float tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs.into());

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Runs a float tensor operation whose second argument is passed through unchanged.
///
/// Fetches the float tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs)` (no conversion is applied to `rhs`) and registers the result as
/// a float tensor under `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Runs a float reduction along an axis.
///
/// Fetches the float tensor referenced by `$desc.input` from `$handles`, calls
/// `$ops(tensor, $desc.axis)` and registers the result as a float tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_float_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.input);
        let reduced = $ops(tensor, $desc.axis);

        $handles.register_float_tensor::<B>(&$desc.out.id, reduced);
    }};
}
|
||||
|
||||
/// Runs a float reduction along an axis whose result is an int tensor.
///
/// Fetches the float tensor referenced by `$desc.input` from `$handles`, calls
/// `$ops(tensor, $desc.axis)` and registers the result as an int tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_float2int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.input);
        let reduced = $ops(tensor, $desc.axis);

        $handles.register_int_tensor::<B>(&$desc.out.id, reduced);
    }};
}
|
||||
|
||||
/// Runs an int reduction along an axis.
///
/// Fetches the int tensor referenced by `$desc.input` from `$handles`, calls
/// `$ops(tensor, $desc.axis)` and registers the result as an int tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.input);
        let reduced = $ops(tensor, $desc.axis);

        $handles.register_int_tensor::<B>(&$desc.out.id, reduced);
    }};
}
|
||||
|
||||
/// Runs a float tensor operation with a second argument whose result is an int tensor.
///
/// Fetches the float tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs)` (no conversion is applied to `rhs`) and registers the result as
/// an int tensor under `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float2int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Runs a float tensor-with-scalar comparison.
///
/// Fetches the float tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs.into())` and registers the result as a bool tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let verdict = $ops(tensor, $desc.rhs.into());

        $handles.register_bool_tensor::<B>(&$desc.out.id, verdict);
    }};
}
|
||||
|
||||
/// Runs a unary float tensor operation.
///
/// Fetches the float tensor referenced by `$desc.input` from `$handles`, calls
/// `$ops(tensor)` and registers the result as a float tensor under `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let input = $handles.get_float_tensor::<B>(&$desc.input);
        let result = $ops(input);

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Runs an int tensor-with-scalar operation described by `$desc`.
///
/// Fetches the int tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs.into())` and registers the result as an int tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs.into());

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Runs an int tensor operation whose second argument is passed through unchanged.
///
/// Fetches the int tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs)` (no conversion is applied to `rhs`) and registers the result as
/// an int tensor under `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Runs an int tensor-with-scalar comparison.
///
/// Fetches the int tensor referenced by `$desc.lhs` from `$handles`, calls
/// `$ops(tensor, $desc.rhs.into())` and registers the result as a bool tensor under
/// `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let verdict = $ops(tensor, $desc.rhs.into());

        $handles.register_bool_tensor::<B>(&$desc.out.id, verdict);
    }};
}
|
||||
|
||||
/// Runs a unary int tensor operation.
///
/// Fetches the int tensor referenced by `$desc.input` from `$handles`, calls `$ops(tensor)`
/// and registers the result as an int tensor under `$desc.out.id`.
///
/// `B` is not a macro argument: a type with that name must be visible where the macro is
/// invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let input = $handles.get_int_tensor::<B>(&$desc.input);
        let result = $ops(input);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
1575
crates/stable-diffusion-burn/burn-crates/burn-router/src/runner.rs
Normal file
1575
crates/stable-diffusion-burn/burn-crates/burn-router/src/runner.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,142 @@
|
||||
use core::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
use alloc::format;
|
||||
use alloc::{sync::Arc, vec::Vec};
|
||||
|
||||
use super::RunnerClient;
|
||||
use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};
|
||||
use burn_ir::{TensorId, TensorIr, TensorStatus};
|
||||
|
||||
/// Tensor primitive for the [router backend](crate::BackendRouter).
///
/// The tensor carries only metadata plus the client it belongs to. Clones share one atomic
/// counter, which `into_ir` and `Drop` read to decide between the `ReadWrite` and `ReadOnly`
/// statuses.
pub struct RouterTensor<C: RunnerClient> {
    /// Identifier copied into every `TensorIr` built for this tensor.
    pub(crate) id: TensorId,
    /// Shape returned by `TensorMetadata::shape`; moved out (left empty) by `into_ir` and by
    /// the final drop.
    pub(crate) shape: Shape,
    /// Element type returned by `TensorMetadata::dtype`.
    pub(crate) dtype: DType,
    /// The client that has this tensor
    pub client: C,
    /// Counter shared by all clones: starts at 1 in `new`, incremented by `clone`,
    /// decremented by `drop`.
    pub(crate) count: Arc<AtomicU32>,
}
|
||||
|
||||
impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
|
||||
fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn shape(&self) -> Shape {
|
||||
self.shape.clone()
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.shape.num_dims()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: RunnerClient> RouterTensor<C> {
|
||||
/// Create a new router tensor.
|
||||
pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {
|
||||
Self {
|
||||
id,
|
||||
shape,
|
||||
dtype,
|
||||
client,
|
||||
count: Arc::new(AtomicU32::new(1)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {
|
||||
self.client.clone().read_tensor_async(self.into_ir()).await
|
||||
}
|
||||
|
||||
/// Get the ir for this tensor
|
||||
pub fn into_ir(mut self) -> TensorIr {
|
||||
let count = self.count.load(Ordering::Relaxed);
|
||||
let status = self.status(count);
|
||||
let mut shape_out = Shape::from(Vec::<usize>::new());
|
||||
core::mem::swap(&mut self.shape, &mut shape_out);
|
||||
|
||||
if let TensorStatus::ReadWrite = status {
|
||||
// Avoids an unwanted drop on the same thread.
|
||||
//
|
||||
// Since `drop` is called after `into_ir`, we must not register a drop if the tensor
|
||||
// was consumed with a `ReadWrite` status.
|
||||
self.count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
TensorIr {
|
||||
status,
|
||||
shape: shape_out,
|
||||
id: self.id,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_ir_out(&self) -> TensorIr {
|
||||
TensorIr {
|
||||
status: TensorStatus::NotInit,
|
||||
shape: self.shape.clone(),
|
||||
id: self.id,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn status(&self, count: u32) -> TensorStatus {
|
||||
if count <= 1 {
|
||||
TensorStatus::ReadWrite
|
||||
} else {
|
||||
TensorStatus::ReadOnly
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debug output has the form `{ id: .., shape: .., dtype: .., device: .. }`, where the device
// is the one reported by the tensor's client.
impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let device = self.client.device().clone();
        let rendered = format!(
            "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}",
            self.id, self.shape, self.dtype, device,
        );

        f.write_str(&rendered)
    }
}
|
||||
|
||||
impl<C: RunnerClient> Clone for RouterTensor<C> {
|
||||
fn clone(&self) -> Self {
|
||||
self.count.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
Self {
|
||||
id: self.id,
|
||||
shape: self.shape.clone(),
|
||||
client: self.client.clone(),
|
||||
dtype: self.dtype,
|
||||
count: self.count.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: RunnerClient> Drop for RouterTensor<C> {
|
||||
fn drop(&mut self) {
|
||||
let count = self.count.fetch_sub(1, Ordering::Relaxed);
|
||||
|
||||
match self.status(count) {
|
||||
TensorStatus::ReadWrite => {
|
||||
let id = self.id;
|
||||
let mut shape = Shape::from(Vec::<usize>::new());
|
||||
core::mem::swap(&mut shape, &mut self.shape);
|
||||
|
||||
let ir = TensorIr {
|
||||
id,
|
||||
shape,
|
||||
status: TensorStatus::ReadWrite,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
self.client.register_op(burn_ir::OperationIr::Drop(ir));
|
||||
}
|
||||
TensorStatus::ReadOnly => {}
|
||||
TensorStatus::NotInit => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,386 @@
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use burn_backend::{
|
||||
DType, Shape, TensorData,
|
||||
backend::{Backend, DeviceId, DeviceOps, ExecutionError},
|
||||
try_read_sync,
|
||||
};
|
||||
use burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr};
|
||||
use burn_std::future::DynFut;
|
||||
|
||||
use crate::{
|
||||
ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel,
|
||||
RunnerClient,
|
||||
};
|
||||
|
||||
/// Implement multi backend types, with enums having one variant per backend.
///
/// Parameters:
/// - `$module_name`: name of the public module that receives the generated items.
/// - `$DefaultBackend`: identifier of the first backend. It is the variant returned by
///   `MultiDevice::default()` and supplies the channel's float/int/bool element types.
/// - `$OtherBackend`: one or more further backend identifiers.
///
/// Every identifier is used both as a generic type parameter and as an enum variant name.
/// The expansion calls `multi_backend_match!`, so that macro must be resolvable where this
/// one is invoked.
macro_rules! impl_multi_backend_types {
    // Match the default backend and at least one other backend, with rest being optional
    ($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        /// Module containing the essential types for multi-backend operations.
        ///
        /// - `Handle`: the type used to point to a tensor (defined for all backends).
        /// - `MultiRunnerClient`: a client for multiple runners (each responsible to execute tensor operations on a given backend).
        /// - `DirectChannel`: a local channel with direct connection to the backend runner clients.
        /// - `ByteBridge`: a simple multi-backend bridge that transfers tensors via the underlying [tensor data](burn_backend::TensorData).
        ///
        /// Each enum type is defined with backend identifiers as variant names (e.g., `B1` and `B2` for dual backends).
        pub mod $module_name {
            use super::*;

            /// The type that can be used to point to a tensor of any kind.
            /// Each backend has its own variant.
            pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Handle),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Handle),
                )+
            }

            /// The device type used by a backend.
            /// Each backend has its own variant.
            #[derive(Clone, Debug)]
            pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Device),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Device),
                )+
            }

            // Devices are equal only when they belong to the same backend variant and the
            // wrapped devices compare equal.
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn eq(&self, other: &Self) -> bool {
                    match (self, other) {
                        (Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs,
                        $(
                            (Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs,
                        )+
                        _ => false,
                    }
                }
            }

            // Default implementation always returns the first backend's device
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn default() -> Self {
                    Self::$DefaultBackend($DefaultBackend::Device::default())
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn from_id(_device_id: DeviceId) -> Self {
                    // TODO: Should be fixed with the new router backend.
                    // Until then the id is ignored and the default device is returned.
                    Default::default()
                }

                fn to_id(&self) -> DeviceId {
                    match self {
                        Self::$DefaultBackend(device) => device.id(),
                        $(
                            Self::$OtherBackend(device) => device.id(),
                        )+
                    }
                }

                fn device_count(_type_id: u16) -> usize {
                    1
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {}

            /// A local client with multiple runners (each responsible to execute tensor operations on a given backend).
            #[derive(Clone)]
            pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend(Runner<$DefaultBackend>),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend(Runner<$OtherBackend>),
                )+
            }

            // Every method forwards to the runner stored in the active variant.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>
            {
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn register_op(&self, op: OperationIr) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.register_op(op),
                        $(
                            Self::$OtherBackend(runner) => runner.register_op(op),
                        )+
                    }
                }

                fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor),
                        $(
                            Self::$OtherBackend(runner) => runner.read_tensor_async(tensor),
                        )+
                    }
                }

                fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
                    match self {
                        Self::$DefaultBackend(runner) => {
                            let desc = runner.register_tensor_data_desc(data);
                            RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                        }
                        $(
                            Self::$OtherBackend(runner) => {
                                let desc = runner.register_tensor_data_desc(data);
                                RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                            }
                        )+
                    }
                }

                fn device(&self) -> Self::Device {
                    match self {
                        Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()),
                        $(
                            Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()),
                        )+
                    }
                }

                fn sync(&self) -> Result<(), ExecutionError> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.sync(),
                        $(
                            Self::$OtherBackend(runner) => runner.sync(),
                        )+
                    }
                }

                fn seed(&self, seed: u64) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.seed(seed),
                        $(
                            Self::$OtherBackend(runner) => runner.seed(seed),
                        )+
                    }
                }

                fn create_empty_handle(&self) -> TensorId {
                    match self {
                        Self::$DefaultBackend(runner) => runner.create_empty_handle(),
                        $(
                            Self::$OtherBackend(runner) => runner.create_empty_handle(),
                        )+
                    }
                }

                fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
                    match self {
                        Self::$DefaultBackend(runner) => runner.dtype_usage(dtype),
                        $(
                            Self::$OtherBackend(runner) => runner.dtype_usage(dtype),
                        )+
                    }
                }
            }

            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br>
            where
                Br: MultiBackendBridge<TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>,
            {
                type Device = Br::Device;

                type Bridge = Br;

                // Element types are taken from the default (first) backend.
                type FloatElem = $DefaultBackend::FloatElem;
                type IntElem = $DefaultBackend::IntElem;
                type BoolElem = $DefaultBackend::BoolElem;

                type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>;

                fn init_client(device: &Self::Device) -> Self::Client {
                    match device {
                        MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())),
                        $(
                            MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())),
                        )+
                    }
                }

                fn get_tensor_handle(
                    tensor: &TensorIr,
                    client: &Self::Client,
                ) -> <Self::Bridge as MultiBackendBridge>::TensorHandle {
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)),
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)),
                        )+
                    }
                }

                fn register_tensor(
                    client: &Self::Client,
                    handle: <Self::Bridge as MultiBackendBridge>::TensorHandle,
                    shape: Shape,
                    dtype: DType,
                ) -> RouterTensor<Self::Client> {
                    // The handle variant must match the client variant; any other
                    // combination hits the `unreachable!` arm and panics.
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => match handle {
                            Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                            _ => unreachable!("Can't register tensor handle for another backend."),
                        },
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => match handle {
                                Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                                _ => unreachable!("Can't register tensor handle for another backend."),
                            },
                        )+
                    }
                }

                fn name(_device: &Self::Device) -> String {
                    // The name is assembled from each backend's *default* device;
                    // the `_device` argument is not consulted.
                    // The first backend's name is already a `String`, so it is used as the
                    // accumulator directly rather than being re-formatted.
                    let mut name = $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default());
                    $(
                        name.push_str(", ");
                        name.push_str(&$OtherBackend::name(&<$OtherBackend::Device as Default>::default()));
                    )+
                    format!("direct<({})>", name)
                }
            }

            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> {
                type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>;
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn change_backend_float(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

                fn change_backend_int(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    // NOTE(review): same expansion as `change_backend_float`, which goes through
                    // the `float_*` calls in `bridge!` — confirm this is intended for int tensors.
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

                fn change_backend_bool(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    // NOTE(review): same expansion as `change_backend_float`, which goes through
                    // the `float_*` calls in `bridge!` — confirm this is intended for bool tensors.
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }
            }
        }
    };
}
|
||||
|
||||
/// Moves a tensor handle to a target device and wraps the result in `Handle`.
///
/// Two forms are accepted:
/// - `bridge!(Backend, handle, device, shape)`: source and target are the same backend;
///   the tensor is rebuilt from the handle and passed to `float_to_device`.
/// - `bridge!(BackendA, BackendB, handle, device, shape)`: the tensor data is read from
///   `BackendA` synchronously and a new tensor is created from it on `BackendB`.
///
/// Both forms go through the backend's `float_*` functions and evaluate to
/// `Handle::<Backend>(..)`, so `Handle`, `TensorHandle` and `try_read_sync` must be in
/// scope at the expansion site.
///
/// # Panics
///
/// The two-backend form panics when the data cannot be read synchronously, or when the
/// read itself returns an error.
macro_rules! bridge {
    ($Backend:ident, $handle:expr, $device:expr, $shape:expr) => {{
        // Bridge for the same backend
        let tensor = $Backend::float_tensor(TensorHandle {
            handle: $handle,
            shape: $shape,
        });
        let tensor = $Backend::float_to_device(tensor, $device);
        let handle = $Backend::float_tensor_handle(tensor);
        Handle::$Backend(handle)
    }};
    ($BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr) => {{
        // Byte bridge between two backends
        let tensor = $BackendA::float_tensor(TensorHandle { handle: $handle, shape: $shape });
        let data = try_read_sync($BackendA::float_into_data(tensor))
            // Outer layer: the future could not be resolved synchronously.
            .expect(
                "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM."
            )
            // Inner layer: the read ran but reported an error.
            .expect("Failed to read tensor data: the source backend returned an execution error.");
        let tensor = $BackendB::float_from_data(data, $device);
        let handle = $BackendB::float_tensor_handle(tensor);
        Handle::$BackendB(handle)
    }};
}
|
||||
|
||||
/// Expands to a `match` on `(handle, device)` with one arm per
/// (source backend, target backend) combination, each arm delegating to `bridge!`.
///
/// Invocation: `multi_backend_match!(shape, (handle, device) : First, Second, ...)`.
/// The `@step` rules are internal: they accumulate arms while walking the backend list.
macro_rules! multi_backend_match {
    // Entry point. Emits the same-backend arm for the first backend, then for every
    // remaining backend the two cross arms with the first backend plus its own
    // same-backend arm. The remaining backends are handed to `@step`.
    ($shape:expr, ($handle:expr, $device:expr) : $First:ident, $($Rest:ident),+) => {
        multi_backend_match! (
            @step
            $shape,
            ($handle, $device);
            {
                (Handle::$First(handle), MultiDevice::$First(device)) => bridge!($First, handle, device, $shape),
                $(
                    (Handle::$First(handle), MultiDevice::$Rest(device)) => bridge!($First, $Rest, handle, device, $shape),
                    (Handle::$Rest(handle), MultiDevice::$First(device)) => bridge!($Rest, $First, handle, device, $shape),
                    (Handle::$Rest(handle), MultiDevice::$Rest(device)) => bridge!($Rest, handle, device, $shape),
                )+
            };
            $($Rest),+
        )
    };

    // Recursive step, taken while at least two backends remain: pair the head of the
    // list with every backend after it (both directions), then recurse on the tail.
    (@step
        $shape:expr,
        $scrutinee:tt;
        { $($acc:tt)* };
        $Head:ident,
        $($Tail:ident),+
    ) => {
        multi_backend_match! (
            @step
            $shape,
            $scrutinee;
            {
                $($acc)*
                $(
                    (Handle::$Head(handle), MultiDevice::$Tail(device)) => bridge!($Head, $Tail, handle, device, $shape),
                    (Handle::$Tail(handle), MultiDevice::$Head(device)) => bridge!($Tail, $Head, handle, device, $shape),
                )*
            };
            $($Tail),*
        )
    };

    // Terminal step, taken when zero or one backend remains: every pair has been
    // generated, so emit the final `match` with the accumulated arms.
    (@step
        $shape:expr,
        ($handle:expr, $device:expr);
        { $($acc:tt)* };
        $($Last:ident)?
    ) => {
        match ($handle, $device) {
            $($acc)*
        }
    };
}
|
||||
|
||||
// Implement multi-backend types and byte bridge for up to 4 backends.
// Each invocation generates one public module: `duo` (2 backends), `trio` (3 backends)
// and `quad` (4 backends), using `B1`..`B4` as the generic parameter and variant names.
|
||||
impl_multi_backend_types!(duo, B1, B2);
|
||||
impl_multi_backend_types!(trio, B1, B2, B3);
|
||||
impl_multi_backend_types!(quad, B1, B2, B3, B4);
|
||||
|
||||
#[cfg(all(test, not(target_os = "windows")))] // cannot find a wgpu adapter on windows CI
mod tests {
    use burn_tensor::{Tensor, backend::Backend};

    use super::*;
    use crate::tests::{TestBackend, TestBackend1, TestBackend2};

    /// Moves a tensor from each `duo` device to the other one and checks that the
    /// data read back is equal to the data of the tensor it was moved from.
    #[test]
    fn should_support_dual_byte_bridge() {
        let first_device = duo::MultiDevice::B1(<TestBackend1 as Backend>::Device::default());
        let second_device = duo::MultiDevice::B2(<TestBackend2 as Backend>::Device::default());

        let on_first = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &first_device);
        let on_second = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &second_device);

        // First device -> second device.
        let moved_to_second = on_first.clone().to_device(&second_device);
        on_first
            .into_data()
            .assert_eq(&moved_to_second.into_data(), true);

        // Second device -> first device.
        let moved_to_first = on_second.clone().to_device(&first_device);
        on_second
            .into_data()
            .assert_eq(&moved_to_first.into_data(), true);
    }
}
|
||||
Reference in New Issue
Block a user