feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,240 @@
use crate::{
FusionTensor,
client::GlobalFusionClient,
stream::{Context, OrderedExecution},
};
use burn_backend::{
Backend, DType, DeviceOps, Element, ExecutionError,
tensor::{BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor},
};
use burn_ir::{BackendIr, OperationIr, TensorHandle};
use serde::{Serialize, de::DeserializeOwned};
use std::marker::PhantomData;
/// Get the client for the given device.
pub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
GlobalFusionClient::load(device)
}
/// Backend decorator enabling dynamic operation fusion on any backend that implements
/// [fusion backend](crate::FusionBackend).
///
/// The type stores no data: the wrapped backend `B` is only tracked through a marker field.
#[derive(Debug, Default, Clone)]
pub struct Fusion<B: FusionBackend> {
    _backend: PhantomData<B>,
}
/// [`Backend`] implementation that decorates the inner backend `B`.
///
/// Every tensor kind uses [`FusionTensor`] as its primitive; element types, device handling
/// and memory management are forwarded to `B`.
impl<B: FusionBackend> Backend for Fusion<B> {
    type Device = B::Device;

    // All tensor kinds share the same fusion tensor primitive.
    type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
    type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
    type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
    type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;

    // Element types are those of the inner backend.
    type FloatElem = B::FloatElem;
    type IntElem = B::IntElem;
    type BoolElem = B::BoolElem;

    /// Name of the inner backend wrapped as `fusion<...>`.
    fn name(device: &Self::Device) -> String {
        let inner = B::name(device);
        format!("fusion<{}>", inner)
    }

    /// Drains the fusion client for `device`, then seeds the inner backend.
    fn seed(device: &B::Device, seed: u64) {
        let fusion_client = GlobalFusionClient::<B::FusionRuntime>::load(device);
        // Flush pending work first so that it is not affected by the new seed.
        fusion_client.drain();
        B::seed(device, seed);
    }

    /// Drains the fusion client for `device`, then synchronizes the inner backend.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        let fusion_client = GlobalFusionClient::<B::FusionRuntime>::load(device);
        fusion_client.drain();
        B::sync(device)
    }

    /// Always `false` for the fusion decorator.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Forwarded to the inner backend.
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        B::memory_persistent_allocations(device, input, func)
    }

    /// Forwarded to the inner backend.
    fn memory_cleanup(device: &Self::Device) {
        B::memory_cleanup(device);
    }

    /// Forwarded to the inner backend.
    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
    {
        B::staging(data, device);
    }

    /// Forwarded to the inner backend.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        B::supports_dtype(device, dtype)
    }

    /// Forwarded to the inner backend.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        B::dtype_usage(device, dtype)
    }
}
/// Whether a [fuser](OperationFuser) can still accept operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FuserStatus {
    /// The fuser cannot fuse any more operations.
    Closed,
    /// The fuser can still fuse more operations.
    Open,
}
/// Properties reported by a [fuser](OperationFuser) about its current optimization.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct FuserProperties {
    /// Score of the optimization; a higher score is better.
    pub score: u64,
    /// Whether the operation is ready to be executed.
    pub ready: bool,
}
/// Abstraction that lets implementations fuse many [tensor operations](OperationIr) into a
/// single optimization of type `O`, improving the performance of the backend.
///
/// # Notes
///
/// Implementations are free to execute the registered operations in whatever way improves the
/// speed and efficiency of the computational graph. Not every registered operation has to be
/// fused; it only means that another way of executing them is more efficient.
///
/// It is important to report [`FuserStatus::Closed`] once registering further operations can
/// no longer improve performance.
pub trait OperationFuser<O>: Send {
    /// Register a new [tensor operation](OperationIr).
    fn fuse(&mut self, operation: &OperationIr);

    /// Finish the optimization and create the fused operation.
    fn finish(&mut self) -> O;

    /// Reset the state of the fuser.
    fn reset(&mut self);

    /// Current [status](FuserStatus) of the fuser.
    fn status(&self) -> FuserStatus;

    /// Current [properties](FuserProperties) of the fuser.
    fn properties(&self) -> FuserProperties;

    /// Number of operations fused so far.
    fn len(&self) -> usize;

    /// Whether no operation has been fused.
    fn is_empty(&self) -> bool {
        let fused = self.len();
        fused == 0
    }

    /// Clone the fuser behind a trait object.
    fn clone_dyn(&self) -> Box<dyn OperationFuser<O>>;
}
/// Exposes how many operations a data structure contains.
pub trait NumOperations: core::fmt::Debug {
    /// Number of registered operations.
    fn len(&self) -> usize;

    /// Whether the structure contains no operation at all.
    fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }
}
/// An optimization produced by a [fuser](OperationFuser) for the runtime `R`.
pub trait Optimization<R>: NumOperations + Send
where
    R: FusionRuntime,
{
    /// Execute the optimization using the given context and ordered execution.
    fn execute(
        &mut self,
        context: &mut Context<'_, R::FusionHandle>,
        execution: &OrderedExecution<R>,
    );

    /// Return the serializable state of this optimization.
    fn to_state(&self) -> R::OptimizationState;

    /// Rebuild an optimization for `device` from a previously obtained state.
    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
}
/// Shorthand for the device type of a fusion runtime: `<R as FusionRuntime>::FusionDevice`.
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
/// Shorthand for the handle type of a fusion runtime: `<R as FusionRuntime>::FusionHandle`.
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
/// Alias of [`GlobalFusionClient`], the client type for the runtime `R`.
pub type Client<R> = GlobalFusionClient<R>;
/// Trait defining a runtime that benefits from fused operations.
pub trait FusionRuntime: core::fmt::Debug + Sized + Send + Sync + 'static {
    /// Serializable state of an optimization.
    type OptimizationState: DeserializeOwned + Serialize;

    /// Optimization type used by the backend.
    type Optimization: Optimization<Self>;

    /// Handle used to store tensors dynamically.
    type FusionHandle: Send + Clone;

    /// Device used by the runtime.
    type FusionDevice: DeviceOps;

    /// Type used to represent booleans on the backend.
    type BoolRepr: Element;

    /// The list of fusers used to optimize the computational graph on `device`.
    fn fusers(device: Self::FusionDevice) -> Vec<Box<dyn OperationFuser<Self::Optimization>>>;
}
/// Trait allowing an existing [backend](Backend) to specify graph optimizations through
/// [operation fusers](crate::OperationFuser).
///
/// The backend's handle and device types must be those of its [`FusionRuntime`].
pub trait FusionBackend:
    BackendIr<Device = FusionDevice<Self::FusionRuntime>, Handle = FusionHandle<Self::FusionRuntime>>
{
    /// The runtime used by this backend.
    type FusionRuntime: FusionRuntime;

    /// The full precision counterpart of this fusion backend, sharing the same runtime.
    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;

    /// Cast a float tensor to `dtype` and return the resulting handle.
    fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle;
}
// `Fusion` implements `BackendIr` so that it can be used with the router backend.
//
// The handle type is the fusion tensor itself, so every conversion below is an identity:
// `*_tensor` unwraps the tensor stored in the handle, `*_tensor_handle` returns it unchanged.
impl<B: FusionBackend> BackendIr for Fusion<B> {
    type Handle = FusionTensor<B::FusionRuntime>;

    fn float_tensor(tensor_handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
        tensor_handle.handle
    }

    fn int_tensor(tensor_handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
        tensor_handle.handle
    }

    fn bool_tensor(tensor_handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
        tensor_handle.handle
    }

    fn quantized_tensor(tensor_handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
        tensor_handle.handle
    }

    fn float_tensor_handle(primitive: FloatTensor<Self>) -> Self::Handle {
        primitive
    }

    fn int_tensor_handle(primitive: IntTensor<Self>) -> Self::Handle {
        primitive
    }

    fn bool_tensor_handle(primitive: BoolTensor<Self>) -> Self::Handle {
        primitive
    }

    fn quantized_tensor_handle(primitive: QuantizedTensor<Self>) -> Self::Handle {
        primitive
    }
}

View File

@@ -0,0 +1,307 @@
use crate::{
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,
stream::{OperationStreams, StreamId, execution::Operation},
};
use burn_backend::{Device, DeviceContext, DeviceId, DeviceState};
use burn_backend::{TensorData, backend::ExecutionError};
use burn_ir::{OperationIr, TensorId, TensorIr};
use std::sync::Arc;
/// Client used to communicate with the fusion server of a device.
///
/// The server is reached through a lockable [`DeviceContext`], so every interaction with it
/// happens while holding that lock.
pub struct GlobalFusionClient<R: FusionRuntime> {
    // Lockable context holding the fusion server.
    server: DeviceContext<FusionServer<R>>,
    // Device this client was created for.
    device: FusionDevice<R>,
}
/// Lets a [`FusionServer`] be created as device state from a device id.
impl<R: FusionRuntime> DeviceState for FusionServer<R> {
    /// Rebuilds the device from `device_id` and creates a fresh server for it.
    fn init(device_id: DeviceId) -> Self {
        Self::new(FusionDevice::<R>::from_id(device_id))
    }
}
// Implemented by hand so that cloning a client does not require `R: Clone`.
impl<R: FusionRuntime> Clone for GlobalFusionClient<R> {
    /// Clones both the server context and the device.
    fn clone(&self) -> Self {
        let Self { server, device } = self;
        Self {
            server: server.clone(),
            device: device.clone(),
        }
    }
}
impl<R: FusionRuntime + 'static> GlobalFusionClient<R> {
    /// Loads the client for the given device.
    ///
    /// The device is cloned into the client and the server context is located from it.
    pub fn load(device: &FusionDevice<R>) -> Self {
        let owned_device = device.clone();
        let server = DeviceContext::locate(device);
        Self {
            device: owned_device,
            server,
        }
    }
}
impl<R> GlobalFusionClient<R>
where
    R: FusionRuntime + 'static,
{
    /// Create a new client for the given [device](FusionRuntime::FusionDevice).
    pub fn new(device: FusionDevice<R>) -> Self {
        // Locate the server while `device` is still borrowed, then move the device into the
        // client: the device is already owned, so no clone is needed.
        let server = DeviceContext::locate(&device);
        Self { device, server }
    }

    /// Register a new [tensor operation intermediate representation](OperationIr).
    ///
    /// Returns the new (uninitialized) output tensor(s) generated by the registered operation.
    /// The output tensors are created on the current stream before the operation is handed
    /// to the server.
    pub fn register<O>(
        &self,
        streams: OperationStreams,
        repr: OperationIr,
        operation: O,
    ) -> Vec<FusionTensor<R>>
    where
        O: Operation<R> + 'static,
    {
        // Create output tensors returned by this operation.
        let outputs = repr
            .outputs()
            .map(|output| {
                FusionTensor::new(
                    output.id,
                    output.shape.clone(),
                    output.dtype,
                    self.clone(),
                    StreamId::current(),
                )
            })
            .collect();

        self.server
            .lock()
            .register(streams, repr, Arc::new(operation));

        outputs
    }

    /// Drain the current stream on the server.
    pub fn drain(&self) {
        let id = StreamId::current();
        self.server.lock().drain_stream(id);
    }

    /// Create a new (uninitialized) empty tensor handle and returns its corresponding [tensor id](TensorId).
    pub fn create_empty_handle(&self) -> TensorId {
        self.server.lock().create_empty_handle()
    }

    /// Get the current device used by all operations handled by this client.
    pub fn device(&self) -> &FusionDevice<R> {
        &self.device
    }

    /// Create a tensor with the given handle and returns its corresponding [tensor id](TensorId).
    pub fn register_tensor_handle(&self, handle: FusionHandle<R>) -> TensorId {
        // A single lock covers both the id creation and the handle registration.
        let mut server = self.server.lock();
        let id = server.create_empty_handle();
        server.handles.register_handle(id, handle);
        core::mem::drop(server);

        id
    }

    /// Read the values contained by a float tensor.
    pub fn read_tensor_float<B>(
        self,
        tensor: TensorIr,
        stream: StreamId,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        self.server.lock().read_float::<B>(tensor, stream)
    }

    /// Read the values contained by an int tensor.
    pub fn read_tensor_int<B>(
        self,
        tensor: TensorIr,
        id: StreamId,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        self.server.lock().read_int::<B>(tensor, id)
    }

    /// Read the values contained by a bool tensor.
    pub fn read_tensor_bool<B>(
        self,
        tensor: TensorIr,
        stream: StreamId,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        self.server.lock().read_bool::<B>(tensor, stream)
    }

    /// Read the values contained by a quantized tensor.
    pub fn read_tensor_quantized<B>(
        self,
        tensor: TensorIr,
        stream: StreamId,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        self.server.lock().read_quantized::<B>(tensor, stream)
    }

    /// Change the client of the given float tensor.
    ///
    /// Drains `stream` on the current server, then moves the tensor to the server of `client`.
    /// The returned tensor belongs to `client` and is created on the current stream.
    // NOTE(review): both servers are locked at the same time; this presumably requires
    // `client` to target a different server than `self` — confirm against callers.
    pub fn change_client_float<B>(
        &self,
        tensor: TensorIr,
        client: Self,
        stream: StreamId,
    ) -> FusionTensor<R>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let guard = self.server.lock_device_kind();
        let mut server_current = self.server.lock();
        server_current.drain_stream(stream);

        let mut server_other = client.server.lock();
        let id = server_current.change_server_float::<B>(
            &tensor,
            stream,
            &client.device,
            &mut server_other,
        );

        // Release in reverse acquisition order, like the other `change_client_*` methods.
        core::mem::drop(server_other);
        core::mem::drop(server_current);
        core::mem::drop(guard);

        FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
    }

    /// Change the client of the given int tensor.
    ///
    /// Drains `stream` on the current server, then moves the tensor to the server of `client`.
    pub fn change_client_int<B>(
        &self,
        tensor: TensorIr,
        client: Self,
        stream: StreamId,
    ) -> FusionTensor<R>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let guard = self.server.lock_device_kind();
        let mut server_current = self.server.lock();
        server_current.drain_stream(stream);

        let mut server_other = client.server.lock();
        let id = server_current.change_server_int::<B>(
            &tensor,
            stream,
            &client.device,
            &mut server_other,
        );

        // Release in reverse acquisition order.
        core::mem::drop(server_other);
        core::mem::drop(server_current);
        core::mem::drop(guard);

        FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
    }

    /// Change the client of the given bool tensor.
    ///
    /// Drains `stream` on the current server, then moves the tensor to the server of `client`.
    pub fn change_client_bool<B>(
        &self,
        tensor: TensorIr,
        client: Self,
        stream: StreamId,
    ) -> FusionTensor<R>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let guard = self.server.lock_device_kind();
        let mut server_current = self.server.lock();
        server_current.drain_stream(stream);

        let mut server_other = client.server.lock();
        let id = server_current.change_server_bool::<B>(
            &tensor,
            stream,
            &client.device,
            &mut server_other,
        );

        // Release in reverse acquisition order.
        core::mem::drop(server_other);
        core::mem::drop(server_current);
        core::mem::drop(guard);

        FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
    }

    /// Change the client of the given quantized tensor.
    ///
    /// Drains `stream` on the current server, then moves the tensor to the server of `client`.
    /// Unlike the other variants, the stream is not forwarded to the server-side move.
    pub fn change_client_quantized<B>(
        &self,
        tensor: TensorIr,
        client: Self,
        stream: StreamId,
    ) -> FusionTensor<R>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let guard = self.server.lock_device_kind();
        let mut server_current = self.server.lock();
        server_current.drain_stream(stream);

        let mut server_other = client.server.lock();
        let id =
            server_current.change_server_quantized::<B>(&tensor, &client.device, &mut server_other);

        // Release in reverse acquisition order.
        core::mem::drop(server_other);
        core::mem::drop(server_current);
        core::mem::drop(guard);

        FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
    }

    /// Resolve the given float tensor to a primitive tensor.
    ///
    /// The tensor's stream is drained before the primitive is resolved.
    pub fn resolve_tensor_float<B>(&self, tensor: FusionTensor<R>) -> B::FloatTensorPrimitive
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let mut server = self.server.lock();
        server.drain_stream(tensor.stream);
        server.resolve_server_float::<B>(&tensor.into_ir())
    }

    /// Resolve the given int tensor to a primitive tensor.
    ///
    /// The tensor's stream is drained before the primitive is resolved.
    pub fn resolve_tensor_int<B>(&self, tensor: FusionTensor<R>) -> B::IntTensorPrimitive
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let mut server = self.server.lock();
        server.drain_stream(tensor.stream);
        server.resolve_server_int::<B>(&tensor.into_ir())
    }

    /// Resolve the given bool tensor to a primitive tensor.
    ///
    /// The tensor's stream is drained before the primitive is resolved.
    pub fn resolve_tensor_bool<B>(&self, tensor: FusionTensor<R>) -> B::BoolTensorPrimitive
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let mut server = self.server.lock();
        server.drain_stream(tensor.stream);
        server.resolve_server_bool::<B>(&tensor.into_ir())
    }
}

View File

@@ -0,0 +1,29 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # Burn Fusion
//!
//! This library is a part of the Burn project. It is a standalone crate that
//! can be used to perform automatic operation fusion on backends that support it.
#[macro_use]
extern crate derive_new;
/// Client module exposing types to communicate with the fusion server.
pub mod client;
/// Stream module exposing all tensor operations that can be optimized.
pub mod stream;
/// Search module for stream optimizations.
pub(crate) mod search;
mod backend;
mod ops;
mod server;
mod tensor;
pub(crate) use server::*;
pub use backend::*;
pub use ops::NoOp;
pub use tensor::*;

View File

@@ -0,0 +1,4 @@
use crate::{Fusion, FusionBackend};
use burn_backend::ops::ActivationOps;
// The impl body is intentionally empty: `Fusion` relies on the default methods provided by
// `ActivationOps` and overrides none of them.
impl<B: FusionBackend> ActivationOps<Self> for Fusion<B> {}

View File

@@ -0,0 +1,15 @@
use crate::{FusionBackend, stream::Operation};
use burn_ir::HandleContainer;
use std::marker::PhantomData;
/// Placeholder operation for the fusion backend that performs no work.
///
/// `NoOp` implements [`Operation`] with an empty `execute`; the backend `B` is only tracked
/// through a marker field.
#[derive(Clone, Debug, new)]
pub struct NoOp<B: FusionBackend> {
    _b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for NoOp<B> {
    /// Does nothing: the handle container is left untouched.
    fn execute(&self, _: &mut HandleContainer<B::Handle>) {
        // Intentionally empty.
    }
}

View File

@@ -0,0 +1,117 @@
// Declares a private operation struct `$name` for a binary float operation.
//
// When executed, both float operands of the stored `BinaryOpIr` are fetched from the handle
// container, `$ops` is applied to them, and the result is registered as a float tensor under
// the output id.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    ($name:ident, $ops:expr) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self { desc, _b: PhantomData }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let left = handles.get_float_tensor::<B>(&self.desc.lhs);
                let right = handles.get_float_tensor::<B>(&self.desc.rhs);
                let result = $ops(left, right);

                handles.register_float_tensor::<B>(&self.desc.out.id, result);
            }
        }
    };
}
// Declares a private operation struct `$name` for a binary float comparison.
//
// When executed, both float operands of the stored `BinaryOpIr` are fetched from the handle
// container, `$ops` is applied to them, and the result is registered as a bool tensor under
// the output id.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    ($name:ident, $ops:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let left = handles.get_float_tensor::<B>(&self.desc.lhs);
                let right = handles.get_float_tensor::<B>(&self.desc.rhs);
                let result = $ops(left, right);

                handles.register_bool_tensor::<B>(&self.desc.out.id, result);
            }
        }
    };
}
// Declares a private operation struct `$name` for a binary int comparison.
//
// When executed, both int operands of the stored `BinaryOpIr` are fetched from the handle
// container, `$ops` is applied to them, and the result is registered as a bool tensor under
// the output id.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    ($name:ident, $ops:expr) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self { desc, _b: PhantomData }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let left = handles.get_int_tensor::<B>(&self.desc.lhs);
                let right = handles.get_int_tensor::<B>(&self.desc.rhs);
                let result = $ops(left, right);

                handles.register_bool_tensor::<B>(&self.desc.out.id, result);
            }
        }
    };
}
// Declares a private operation struct `$name` for a binary int operation.
//
// When executed, both int operands of the stored `BinaryOpIr` are fetched from the handle
// container, `$ops` is applied to them, and the result is registered as an int tensor under
// the output id.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    ($name:ident, $ops:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let left = handles.get_int_tensor::<B>(&self.desc.lhs);
                let right = handles.get_int_tensor::<B>(&self.desc.rhs);
                let result = $ops(left, right);

                handles.register_int_tensor::<B>(&self.desc.out.id, result);
            }
        }
    };
}

View File

@@ -0,0 +1,856 @@
use crate::{
Fusion, FusionBackend, get_client,
stream::{OperationStreams, execution::Operation},
};
use burn_backend::{
Element, ExecutionError, Scalar, Shape, Slice, TensorData,
ops::BoolTensorOps,
tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
};
use burn_ir::{
BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr,
OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr,
SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
};
use std::marker::PhantomData;
use super::NoOp;
impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
/// Registers a lazy `Empty` creation operation for a bool tensor of `shape` on `device`.
///
/// When executed, the operation calls `B::bool_empty` and registers the result under the
/// output tensor id.
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct EmptyOps<B: FusionBackend> {
        desc: TensorIr,
        device: Device<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let created = B::bool_empty(self.desc.shape.clone(), &self.device);
            handles.register_bool_tensor::<B>(&self.desc.id, created);
        }
    }

    let fusion_client = get_client::<B>(device);
    let creation = CreationOpIr::create(shape, B::BoolElem::dtype(), || {
        fusion_client.create_empty_handle()
    });

    let streams = OperationStreams::default();
    let repr = OperationIr::BaseBool(BaseOperationIr::Empty(creation.clone()));
    let operation = EmptyOps::<B>::new(creation.out, device.clone());

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy `Zeros` creation operation for a bool tensor of `shape` on `device`.
///
/// When executed, the operation calls `B::bool_zeros` and registers the result under the
/// output tensor id.
fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct ZerosOps<B: FusionBackend> {
        desc: TensorIr,
        device: Device<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let created = B::bool_zeros(self.desc.shape.clone(), &self.device);
            handles.register_bool_tensor::<B>(&self.desc.id, created);
        }
    }

    let fusion_client = get_client::<B>(device);
    let creation = CreationOpIr::create(shape, B::BoolElem::dtype(), || {
        fusion_client.create_empty_handle()
    });

    let streams = OperationStreams::default();
    let repr = OperationIr::BaseBool(BaseOperationIr::Zeros(creation.clone()));
    let operation = ZerosOps::<B>::new(creation.out, device.clone());

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy `Ones` creation operation for a bool tensor of `shape` on `device`.
///
/// When executed, the operation calls `B::bool_ones` and registers the result under the
/// output tensor id.
fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct OnesOps<B: FusionBackend> {
        desc: TensorIr,
        device: Device<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let created = B::bool_ones(self.desc.shape.clone(), &self.device);
            handles.register_bool_tensor::<B>(&self.desc.id, created);
        }
    }

    let fusion_client = get_client::<B>(device);
    let creation = CreationOpIr::create(shape, B::BoolElem::dtype(), || {
        fusion_client.create_empty_handle()
    });

    let streams = OperationStreams::default();
    let repr = OperationIr::BaseBool(BaseOperationIr::Ones(creation.clone()));
    let operation = OnesOps::<B>::new(creation.out, device.clone());

    fusion_client.register(streams, repr, operation).output()
}
/// Reads the data of a bool tensor by delegating to the tensor's own `bool_into_data`.
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
    let pending = tensor.bool_into_data::<B>();
    pending.await
}
/// Creates a bool tensor from `data` on `device`.
///
/// The tensor is built eagerly on the inner backend, its handle is registered with the
/// client, and an `Init` operation backed by a [`NoOp`] is registered for it.
fn bool_from_data(data: burn_backend::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
    let fusion_client = get_client::<B>(device);

    let primitive = B::bool_from_data(data, device);
    let shape = burn_backend::TensorMetadata::shape(&primitive);
    let handle = B::bool_tensor_handle(primitive);

    // The handle is moved into the closure and registered when the output id is requested.
    let init = InitOperationIr::create(shape, B::BoolElem::dtype(), || {
        fusion_client.register_tensor_handle(handle)
    });

    let streams = OperationStreams::default();
    fusion_client
        .register(streams, OperationIr::Init(init), NoOp::<B>::new())
        .output()
}
/// Registers a lazy conversion of a bool tensor into an int tensor.
///
/// When executed, the operation calls `B::bool_into_int` on the input and registers the
/// result as an int tensor under the output id.
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
    #[derive(new, Debug)]
    struct IntoIntOps<B: FusionBackend> {
        desc: CastOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let converted = B::bool_into_int(source);
            handles.register_int_tensor::<B>(&self.desc.out.id, converted);
        }
    }

    // Streams and client must be captured before the tensor is consumed by `into_ir`.
    let streams = OperationStreams::with_inputs([&tensor]);
    let fusion_client = tensor.client.clone();
    let cast = CastOpIr::create(tensor.into_ir(), B::IntElem::dtype(), || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::Bool(BoolOperationIr::IntoInt(cast.clone()));
    let operation = IntoIntOps::<B>::new(cast);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy conversion of a bool tensor into a float tensor.
///
/// When executed, the operation calls `B::bool_into_float` on the input and registers the
/// result as a float tensor under the output id.
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
    #[derive(new, Debug)]
    struct IntoFloatOps<B: FusionBackend> {
        desc: CastOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let converted = B::bool_into_float(source);
            handles.register_float_tensor::<B>(&self.desc.out.id, converted);
        }
    }

    // Streams and client must be captured before the tensor is consumed by `into_ir`.
    let streams = OperationStreams::with_inputs([&tensor]);
    let fusion_client = tensor.client.clone();
    let cast = CastOpIr::create(tensor.into_ir(), B::FloatElem::dtype(), || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::Bool(BoolOperationIr::IntoFloat(cast.clone()));
    let operation = IntoFloatOps::<B>::new(cast);

    fusion_client.register(streams, repr, operation).output()
}
/// Returns a clone of the device held by the tensor's client.
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
    let device = tensor.client.device();
    device.clone()
}
/// Moves a bool tensor to `device`.
///
/// Returns the tensor unchanged when it already lives on `device`; otherwise the tensor is
/// handed over to the client of the target device through `change_client_bool`.
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
    let device_original: &B::Device = tensor.client.device();

    if device_original == device {
        return tensor;
    }

    let id = tensor.stream;
    let client_target = get_client::<B>(device);
    // The client must be cloned out of the tensor before `into_ir` consumes it. A single
    // clone is enough because `change_client_bool` only borrows the client.
    let client_original = tensor.client.clone();

    client_original.change_client_bool::<B>(tensor.into_ir(), client_target, id)
}
/// Registers a lazy reshape of a bool tensor to `shape`.
///
/// Returns the tensor as-is when it already has the requested shape. Otherwise, the executed
/// operation calls `B::bool_reshape` with the shape of the output tensor.
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
    // Nothing to do when the shape is already the requested one.
    if tensor.shape == shape {
        return tensor;
    }

    #[derive(new, Debug)]
    struct ReshapeDimsOps<B: FusionBackend> {
        desc: ShapeOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let reshaped = B::bool_reshape(source, self.desc.out.shape.clone());
            handles.register_bool_tensor::<B>(&self.desc.out.id, reshaped);
        }
    }

    let streams = OperationStreams::with_inputs([&tensor]);
    let fusion_client = tensor.client.clone();
    let reshape = ShapeOpIr::reshape(tensor.into_ir(), shape, || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::BaseBool(BaseOperationIr::Reshape(reshape.clone()));
    let operation = ReshapeDimsOps::<B>::new(reshape);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy slice of a bool tensor over `slices`.
///
/// When executed, the operation calls `B::bool_slice` with the ranges stored in the
/// operation description.
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct SliceOps<B: FusionBackend> {
        desc: SliceOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let sliced = B::bool_slice(source, self.desc.ranges.as_slice());
            handles.register_bool_tensor::<B>(&self.desc.out.id, sliced);
        }
    }

    let streams = OperationStreams::with_inputs([&tensor]);
    let fusion_client = tensor.client.clone();
    let slice = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::BaseBool(BaseOperationIr::Slice(slice.clone()));
    let operation = SliceOps::<B>::new(slice);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy assignment of `value` into the `slices` region of a bool tensor.
///
/// The operation is registered on the client of `tensor`. When executed, it calls
/// `B::bool_slice_assign` with the ranges stored in the operation description.
fn bool_slice_assign(
    tensor: BoolTensor<Self>,
    slices: &[Slice],
    value: BoolTensor<Self>,
) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct SliceAssignOps<B: FusionBackend> {
        desc: SliceAssignOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let target = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let patch = handles.get_bool_tensor::<B>(&self.desc.value);
            let assigned = B::bool_slice_assign(target, self.desc.ranges.as_slice(), patch);
            handles.register_bool_tensor::<B>(&self.desc.out.id, assigned);
        }
    }

    let streams = OperationStreams::with_inputs([&tensor, &value]);
    let fusion_client = tensor.client.clone();
    let assign =
        SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
            fusion_client.create_empty_handle()
        });

    let repr = OperationIr::BaseBool(BaseOperationIr::SliceAssign(assign.clone()));
    let operation = SliceAssignOps::<B>::new(assign);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy concatenation of bool tensors along `dim`.
///
/// The operation is registered on the client of the first tensor.
///
/// # Panics
///
/// Panics if `tensors` is empty.
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct CatOps<B: FusionBackend> {
        desc: CatOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let inputs = self
                .desc
                .tensors
                .iter()
                .map(|ir| handles.get_bool_tensor::<B>(ir))
                .collect();
            let concatenated = B::bool_cat(inputs, self.desc.dim);
            handles.register_bool_tensor::<B>(&self.desc.out.id, concatenated);
        }
    }

    let streams = OperationStreams::with_inputs(&tensors);
    // The first tensor decides which client receives the operation.
    let fusion_client = tensors.first().unwrap().client.clone();
    let inputs = tensors.into_iter().map(|tensor| tensor.into_ir()).collect();
    let cat = CatOpIr::create(inputs, dim, || fusion_client.create_empty_handle());

    let repr = OperationIr::BaseBool(BaseOperationIr::Cat(cat.clone()));
    let operation = CatOps::<B>::new(cat);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy element-wise equality between two bool tensors.
///
/// The operation is registered on the client of `lhs`. When executed, it calls
/// `B::bool_equal` and registers the result as a bool tensor under the output id.
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct EqualOps<B: FusionBackend> {
        desc: BinaryOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let left = handles.get_bool_tensor::<B>(&self.desc.lhs);
            let right = handles.get_bool_tensor::<B>(&self.desc.rhs);
            let result = B::bool_equal(left, right);
            handles.register_bool_tensor::<B>(&self.desc.out.id, result);
        }
    }

    let streams = OperationStreams::with_inputs([&lhs, &rhs]);
    let fusion_client = lhs.client.clone();
    let binary = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::BaseBool(BaseOperationIr::Equal(binary.clone()));
    let operation = EqualOps::<B>::new(binary);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy logical negation of a bool tensor.
///
/// When executed, the operation calls `B::bool_not` on the input and registers the result
/// under the output id.
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct NotOps<B: FusionBackend> {
        desc: UnaryOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let negated = B::bool_not(source);
            handles.register_bool_tensor::<B>(&self.desc.out.id, negated);
        }
    }

    let streams = OperationStreams::with_inputs([&tensor]);
    let fusion_client = tensor.client.clone();
    let unary = UnaryOpIr::create(tensor.into_ir(), || fusion_client.create_empty_handle());

    let repr = OperationIr::Bool(BoolOperationIr::Not(unary.clone()));
    let operation = NotOps::<B>::new(unary);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy element-wise logical AND between two bool tensors.
///
/// The operation is registered on the client of `lhs`. When executed, it calls `B::bool_and`
/// and registers the result under the output id.
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct AndOps<B: FusionBackend> {
        desc: BinaryOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let left = handles.get_bool_tensor::<B>(&self.desc.lhs);
            let right = handles.get_bool_tensor::<B>(&self.desc.rhs);
            let result = B::bool_and(left, right);
            handles.register_bool_tensor::<B>(&self.desc.out.id, result);
        }
    }

    let streams = OperationStreams::with_inputs([&lhs, &rhs]);
    let fusion_client = lhs.client.clone();
    let binary = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::Bool(BoolOperationIr::And(binary.clone()));
    let operation = AndOps::<B>::new(binary);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy element-wise logical OR between two bool tensors.
///
/// The operation is registered on the client of `lhs`. When executed, it calls `B::bool_or`
/// and registers the result under the output id.
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct OrOps<B: FusionBackend> {
        desc: BinaryOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let left = handles.get_bool_tensor::<B>(&self.desc.lhs);
            let right = handles.get_bool_tensor::<B>(&self.desc.rhs);
            let result = B::bool_or(left, right);
            handles.register_bool_tensor::<B>(&self.desc.out.id, result);
        }
    }

    let streams = OperationStreams::with_inputs([&lhs, &rhs]);
    let fusion_client = lhs.client.clone();
    let binary = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::Bool(BoolOperationIr::Or(binary.clone()));
    let operation = OrOps::<B>::new(binary);

    fusion_client.register(streams, repr, operation).output()
}
/// Registers a lazy swap of dimensions `dim1` and `dim2` of a bool tensor.
///
/// When executed, the operation calls `B::bool_swap_dims` with the dimensions stored in the
/// operation description.
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
    #[derive(new, Debug)]
    struct SwapDimsOps<B: FusionBackend> {
        desc: SwapDimsOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let swapped = B::bool_swap_dims(source, self.desc.dim1, self.desc.dim2);
            handles.register_bool_tensor::<B>(&self.desc.out.id, swapped);
        }
    }

    let streams = OperationStreams::with_inputs([&tensor]);
    let fusion_client = tensor.client.clone();
    let swap = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
        fusion_client.create_empty_handle()
    });

    let repr = OperationIr::BaseBool(BaseOperationIr::SwapDims(swap.clone()));
    let operation = SwapDimsOps::<B>::new(swap);

    fusion_client.register(streams, repr, operation).output()
}
/// Permutes the dimensions of a boolean tensor according to `axes`.
///
/// Registers a [`PermuteOpIr`] with the tensor's fusion client; `B::bool_permute` is called
/// from the registered operation's `execute`.
///
/// The operation is tagged as `OperationIr::BaseBool`, matching the other boolean base
/// operations in this impl (swap dims, expand, flip, ...). It used to be tagged as
/// `OperationIr::BaseInt`, which mislabelled a boolean operation as an integer one in the IR.
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
    /// Operation that runs the permutation on the inner backend `B`.
    #[derive(new, Debug)]
    struct PermuteDimsOps<B: FusionBackend> {
        desc: PermuteOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let input = handles.get_bool_tensor::<B>(&self.desc.input);
            let output = B::bool_permute(input, self.desc.axes.as_slice());
            handles.register_bool_tensor::<B>(&self.desc.out.id, output);
        }
    }

    let streams = OperationStreams::with_inputs([&tensor]);
    let client = tensor.client.clone();
    let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
        client.create_empty_handle()
    });

    client
        .register(
            streams,
            // Boolean tensor => boolean base operation.
            OperationIr::BaseBool(BaseOperationIr::Permute(desc.clone())),
            PermuteDimsOps::<B>::new(desc),
        )
        .output()
}
/// Expands a boolean tensor to `shape`.
///
/// Registers a [`ShapeOpIr`] built with `ShapeOpIr::expand` as a `BaseBool` operation;
/// `B::bool_expand` is called from the registered operation's `execute` with the shape stored
/// in the output descriptor.
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
    /// Operation that runs the expansion on the inner backend `B`.
    #[derive(new, Debug)]
    struct ExpandOps<B: FusionBackend> {
        desc: ShapeOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let target_shape = self.desc.out.shape.clone();
            let expanded = B::bool_expand(source, target_shape);
            handles.register_bool_tensor::<B>(&self.desc.out.id, expanded);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
    let ir = OperationIr::BaseBool(BaseOperationIr::Expand(op_desc.clone()));
    let op = ExpandOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Flips a boolean tensor along the given `axes`.
///
/// Registers a [`FlipOpIr`] as a `BaseBool` operation with the tensor's fusion client;
/// `B::bool_flip` is called from the registered operation's `execute`.
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
    /// Operation that runs the flip on the inner backend `B`.
    #[derive(new, Debug)]
    struct FlipOps<B: FusionBackend> {
        desc: FlipOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let flipped = B::bool_flip(source, self.desc.axes.as_slice());
            handles.register_bool_tensor::<B>(&self.desc.out.id, flipped);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseBool(BaseOperationIr::Flip(op_desc.clone()));
    let op = FlipOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Repeats a boolean tensor `times` times along dimension `dim`.
///
/// Registers a [`RepeatDimOpIr`] as a `BaseBool` operation with the tensor's fusion client;
/// `B::bool_repeat_dim` is called from the registered operation's `execute`.
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
    /// Operation that runs the repetition on the inner backend `B`.
    #[derive(new, Debug)]
    struct RepeatDimOps<B: FusionBackend> {
        desc: RepeatDimOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let repeated = B::bool_repeat_dim(source, self.desc.dim, self.desc.times);
            handles.register_bool_tensor::<B>(&self.desc.out.id, repeated);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseBool(BaseOperationIr::RepeatDim(op_desc.clone()));
    let op = RepeatDimOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Unfolds a boolean tensor along `dim` using the given window `size` and `step`.
///
/// Registers an [`UnfoldOpIr`] as a `BaseBool` operation with the tensor's fusion client;
/// `B::bool_unfold` is called from the registered operation's `execute`.
fn bool_unfold(
    tensor: BoolTensor<Self>,
    dim: usize,
    size: usize,
    step: usize,
) -> BoolTensor<Self> {
    /// Operation that runs the unfold on the inner backend `B`.
    #[derive(new, Debug)]
    struct UnfoldOps<B: FusionBackend> {
        desc: UnfoldOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.input);
            let windows = B::bool_unfold(source, self.desc.dim, self.desc.size, self.desc.step);
            handles.register_bool_tensor::<B>(&self.desc.out.id, windows);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseBool(BaseOperationIr::Unfold(op_desc.clone()));
    let op = UnfoldOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Combines `tensor` and `value` under the boolean `mask`.
///
/// Registers a [`MaskWhereOpIr`] over the three inputs as a `BaseBool` operation with the
/// fusion client of `tensor`; `B::bool_mask_where` is called from the registered operation's
/// `execute`.
fn bool_mask_where(
    tensor: BoolTensor<Self>,
    mask: BoolTensor<Self>,
    value: BoolTensor<Self>,
) -> BoolTensor<Self> {
    /// Operation that runs the masked selection on the inner backend `B`.
    #[derive(new, Debug)]
    struct MaskWhereOps<B: FusionBackend> {
        desc: MaskWhereOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            // Handles are fetched in the order tensor, value, mask.
            let base = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let replacement = handles.get_bool_tensor::<B>(&self.desc.value);
            let selector = handles.get_bool_tensor::<B>(&self.desc.mask);
            let result = B::bool_mask_where(base, selector, replacement);
            handles.register_bool_tensor::<B>(&self.desc.out.id, result);
        }
    }

    // Grab the client and the input streams before the tensors are handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
    let op_desc = MaskWhereOpIr::create(
        tensor.into_ir(),
        mask.into_ir(),
        value.into_ir(),
        || client.create_empty_handle(),
    );
    let ir = OperationIr::BaseBool(BaseOperationIr::MaskWhere(op_desc.clone()));
    let op = MaskWhereOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Fills a boolean tensor with the scalar `value` under the boolean `mask`.
///
/// Registers a [`MaskFillOpIr`] as a `BaseBool` operation with the fusion client of `tensor`;
/// `B::bool_mask_fill` is called from the registered operation's `execute`, with the stored
/// scalar converted via `.into()`.
fn bool_mask_fill(
    tensor: BoolTensor<Self>,
    mask: BoolTensor<Self>,
    value: Scalar,
) -> BoolTensor<Self> {
    /// Operation that runs the masked fill on the inner backend `B`.
    #[derive(new, Debug)]
    struct MaskFillOps<B: FusionBackend> {
        desc: MaskFillOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let base = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let selector = handles.get_bool_tensor::<B>(&self.desc.mask);
            let filled = B::bool_mask_fill(base, selector, self.desc.value.into());
            handles.register_bool_tensor::<B>(&self.desc.out.id, filled);
        }
    }

    // Grab the client and the input streams before the tensors are handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor, &mask]);
    // Convert the scalar into the representation expected by the IR descriptor.
    let fill = value.into();
    let op_desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), fill, || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseBool(BaseOperationIr::MaskFill(op_desc.clone()));
    let op = MaskFillOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Gathers elements of a boolean tensor along `dim` using integer `indices`.
///
/// Registers a [`GatherOpIr`] as a `BaseBool` operation with the fusion client of `tensor`.
/// Both `tensor` and `indices` contribute to the operation streams. `B::bool_gather` is called
/// from the registered operation's `execute`.
fn bool_gather(
    dim: usize,
    tensor: BoolTensor<Self>,
    indices: IntTensor<Self>,
) -> BoolTensor<Self> {
    /// Operation that runs the gather on the inner backend `B`.
    #[derive(new, Debug)]
    struct GatherOps<B: FusionBackend> {
        desc: GatherOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let index = handles.get_int_tensor::<B>(&self.desc.indices);
            let gathered = B::bool_gather(self.desc.dim, source, index);
            handles.register_bool_tensor::<B>(&self.desc.out.id, gathered);
        }
    }

    // Grab the client and the input streams before the tensors are handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor, &indices]);
    let op_desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseBool(BaseOperationIr::Gather(op_desc.clone()));
    let op = GatherOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Scatters `value` into a boolean tensor along `dim` at the given `indices`.
///
/// Registers a [`ScatterOpIr`] as a `BaseBool` operation with the fusion client of `tensor`;
/// `B::bool_scatter_or` is called from the registered operation's `execute`.
fn bool_scatter_or(
    dim: usize,
    tensor: BoolTensor<Self>,
    indices: IntTensor<Self>,
    value: BoolTensor<Self>,
) -> BoolTensor<Self> {
    /// Operation that runs the scatter on the inner backend `B`.
    #[derive(new, Debug)]
    struct ScatterOps<B: FusionBackend> {
        desc: ScatterOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let base = handles.get_bool_tensor::<B>(&self.desc.tensor);
            let index = handles.get_int_tensor::<B>(&self.desc.indices);
            let updates = handles.get_bool_tensor::<B>(&self.desc.value);
            let scattered = B::bool_scatter_or(self.desc.dim, base, index, updates);
            handles.register_bool_tensor::<B>(&self.desc.out.id, scattered);
        }
    }

    // Grab the client and the input streams before the tensors are handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
    // NOTE(review): the IR is tagged with `IndexingUpdateOp::Add` while the backend call is
    // `bool_scatter_or` — confirm this is the intended tag for the boolean OR update.
    let op_desc = ScatterOpIr::create(
        tensor.into_ir(),
        dim,
        indices.into_ir(),
        value.into_ir(),
        IndexingUpdateOp::Add,
        || client.create_empty_handle(),
    );
    let ir = OperationIr::BaseBool(BaseOperationIr::Scatter(op_desc.clone()));
    let op = ScatterOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Compares every element of a boolean tensor with the scalar `rhs` for equality.
///
/// Registers a [`ScalarOpIr`] built with `create_comparison` (output dtype
/// `B::BoolElem::dtype()`) as a `BaseBool` operation; `B::bool_equal_elem` is called from the
/// registered operation's `execute`.
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    /// Operation that runs the comparison on the inner backend `B`.
    #[derive(new, Debug)]
    struct EqualElemOps<B: FusionBackend> {
        desc: ScalarOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualElemOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let left = handles.get_bool_tensor::<B>(&self.desc.lhs);
            let result = B::bool_equal_elem(left, self.desc.rhs.into());
            handles.register_bool_tensor::<B>(&self.desc.out.id, result);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = lhs.client.clone();
    let streams = OperationStreams::with_inputs([&lhs]);
    // Convert the scalar into the representation expected by the IR descriptor.
    let scalar = rhs.into();
    let op_desc =
        ScalarOpIr::create_comparison(lhs.into_ir(), scalar, B::BoolElem::dtype(), || {
            client.create_empty_handle()
        });
    let ir = OperationIr::BaseBool(BaseOperationIr::EqualElem(op_desc.clone()));
    let op = EqualElemOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,12 @@
mod activation;
mod binary;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;
mod unary;
mod base;
pub use base::NoOp;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,501 @@
use std::marker::PhantomData;
use burn_backend::{
DType, Element, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,
ops::QTensorOps,
quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
};
use burn_ir::{
BaseOperationIr, DequantizeOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer,
InitOperationIr, MatmulOpIr, OperationIr, OperationOutput, PermuteOpIr,
QuantizationParametersIr, QuantizeOpIr, SelectOpIr, ShapeOpIr, SliceOpIr, SwapDimsOpIr,
};
use crate::{
Fusion, FusionBackend, get_client,
stream::{OperationStreams, execution::Operation},
};
use super::NoOp;
impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
/// Creates a quantized fusion tensor from raw `data` on `device`.
///
/// The inner backend builds the tensor right away via `B::q_from_data`. Its handle is then
/// registered with the fusion client, and an `Init` operation backed by a [`NoOp`] is
/// registered on a default (empty) set of operation streams.
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
    let client = get_client::<B>(device);
    // Read the dtype before `data` is moved into the inner backend.
    let dtype = data.dtype;

    let inner = B::q_from_data(data, device);
    let shape = burn_backend::TensorMetadata::shape(&inner);
    let handle = B::quantized_tensor_handle(inner);

    let init = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
    let ir = OperationIr::Init(init);

    client
        .register(OperationStreams::default(), ir, NoOp::<B>::new())
        .output()
}
/// Quantizes a float tensor with the given `scheme` and quantization parameters.
///
/// Registers a [`QuantizeOpIr`] as a `Float` operation (keyed by the input tensor's dtype) with
/// the fusion client of `tensor`. Both `tensor` and `qparams.scales` contribute to the
/// operation streams. `B::quantize` is called from the registered operation's `execute`.
fn quantize(
    tensor: FloatTensor<Self>,
    scheme: &QuantScheme,
    qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
    /// Operation that runs the quantization on the inner backend `B`.
    #[derive(new, Debug)]
    struct QuantizeOp<B: FusionBackend> {
        desc: QuantizeOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_float_tensor::<B>(&self.desc.tensor);
            let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);
            let params = QuantizationParametersPrimitive { scales };
            let quantized = B::quantize(source, &self.desc.scheme, params);
            handles.register_quantized_tensor::<B>(&self.desc.out.id, quantized);
        }
    }

    // Grab the client and the input streams before the tensors are handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor, &qparams.scales]);
    let qparams_ir = QuantizationParametersIr {
        scales: qparams.scales.into_ir(),
    };
    let op_desc = QuantizeOpIr::create(tensor.into_ir(), qparams_ir, *scheme, || {
        client.create_empty_handle()
    });
    let ir = OperationIr::Float(
        op_desc.tensor.dtype,
        FloatOperationIr::Quantize(op_desc.clone()),
    );
    let op = QuantizeOp::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Converts a quantized tensor back to a float tensor.
///
/// Registers a [`DequantizeOpIr`] as a `Float` operation with dtype `B::FloatElem::dtype()`;
/// `B::dequantize` is called from the registered operation's `execute`.
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
    /// Operation that runs the dequantization on the inner backend `B`.
    #[derive(new, Debug)]
    struct DequantizeOp<B: FusionBackend> {
        desc: DequantizeOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let quantized = handles.get_quantized_tensor::<B>(&self.desc.input);
            let floats = B::dequantize(quantized);
            handles.register_float_tensor::<B>(&self.desc.out.id, floats);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let out_dtype = B::FloatElem::dtype();
    let op_desc = DequantizeOpIr::create(tensor.into_ir(), out_dtype, || {
        client.create_empty_handle()
    });
    let ir = OperationIr::Float(out_dtype, FloatOperationIr::Dequantize(op_desc.clone()));
    let op = DequantizeOp::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Returns a clone of the device held by the tensor's fusion client.
fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
    let device = tensor.client.device();
    device.clone()
}
/// Moves a quantized tensor to `device`.
///
/// Returns the tensor unchanged when its client already targets `device`. Otherwise the
/// tensor is handed to `change_client_quantized` on its current client, together with the
/// client of the target device and the tensor's stream.
fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
    let target: B::Device = device.clone();
    let current: &B::Device = tensor.client.device();

    // Nothing to do when the tensor already lives on the requested device.
    if current == &target {
        return tensor;
    }

    let stream_id = tensor.stream;
    let target_client = get_client::<B>(&target);
    // Clone the current client because the tensor itself is handed to `into_ir` below.
    let source_client = tensor.client.clone();

    source_client.change_client_quantized::<B>(tensor.into_ir(), target_client, stream_id)
}
/// Reshapes a quantized tensor to `shape`.
///
/// Returns the tensor unchanged when it already has the requested shape. Otherwise registers a
/// [`ShapeOpIr`] built with `ShapeOpIr::reshape` under `OperationIr::BaseFloat`; `B::q_reshape`
/// is called from the registered operation's `execute`.
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
    // Same shape: no operation needs to be registered.
    if tensor.shape == shape {
        return tensor;
    }

    /// Operation that runs the reshape on the inner backend `B`.
    #[derive(new, Debug)]
    struct ReshapeDimsOps<B: FusionBackend> {
        desc: ShapeOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_quantized_tensor::<B>(&self.desc.input);
            let target_shape = self.desc.out.shape.clone();
            let reshaped = B::q_reshape(source, target_shape);
            handles.register_quantized_tensor::<B>(&self.desc.out.id, reshaped);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
    let ir = OperationIr::BaseFloat(BaseOperationIr::Reshape(op_desc.clone()));
    let op = ReshapeDimsOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Reads the content of a quantized tensor as [`TensorData`].
///
/// Delegates to the tensor's own `q_into_data::<B>()` and awaits its result.
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
    let pending = tensor.q_into_data::<B>();
    pending.await
}
/// Swaps dimensions `dim1` and `dim2` of a quantized tensor.
///
/// Registers a [`SwapDimsOpIr`] under `OperationIr::BaseFloat` with the tensor's fusion client;
/// `B::q_swap_dims` is called from the registered operation's `execute`.
fn q_swap_dims(
    tensor: QuantizedTensor<Self>,
    dim1: usize,
    dim2: usize,
) -> QuantizedTensor<Self> {
    /// Operation that runs the dimension swap on the inner backend `B`.
    #[derive(new, Debug)]
    struct SwapDimsOps<B: FusionBackend> {
        desc: SwapDimsOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_quantized_tensor::<B>(&self.desc.input);
            let swapped = B::q_swap_dims(source, self.desc.dim1, self.desc.dim2);
            handles.register_quantized_tensor::<B>(&self.desc.out.id, swapped);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseFloat(BaseOperationIr::SwapDims(op_desc.clone()));
    let op = SwapDimsOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Permutes the dimensions of a quantized tensor according to `axes`.
///
/// Registers a [`PermuteOpIr`] under `OperationIr::BaseFloat` with the tensor's fusion client;
/// `B::q_permute` is called from the registered operation's `execute`.
fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
    /// Operation that runs the permutation on the inner backend `B`.
    #[derive(new, Debug)]
    struct PermuteDimsOps<B: FusionBackend> {
        desc: PermuteOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_quantized_tensor::<B>(&self.desc.input);
            let permuted = B::q_permute(source, self.desc.axes.as_slice());
            handles.register_quantized_tensor::<B>(&self.desc.out.id, permuted);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseFloat(BaseOperationIr::Permute(op_desc.clone()));
    let op = PermuteDimsOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Flips a quantized tensor along the given `axes`.
///
/// Registers a [`FlipOpIr`] under `OperationIr::BaseFloat` with the tensor's fusion client;
/// `B::q_flip` is called from the registered operation's `execute`.
fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
    /// Operation that runs the flip on the inner backend `B`.
    #[derive(new, Debug)]
    struct FlipOps<B: FusionBackend> {
        desc: FlipOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_quantized_tensor::<B>(&self.desc.input);
            let flipped = B::q_flip(source, self.desc.axes.as_slice());
            handles.register_quantized_tensor::<B>(&self.desc.out.id, flipped);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseFloat(BaseOperationIr::Flip(op_desc.clone()));
    let op = FlipOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Gathers elements of a quantized tensor along `dim` using integer `indices`.
///
/// Registers a [`GatherOpIr`] under `OperationIr::BaseFloat` with the fusion client of
/// `tensor`; `B::q_gather` is called from the registered operation's `execute`.
///
/// Both `tensor` and `indices` are inputs of the operation, so both are passed to
/// `OperationStreams::with_inputs`. Previously only `tensor` was listed, so the stream of
/// `indices` was left out of the registration.
fn q_gather(
    dim: usize,
    tensor: QuantizedTensor<Self>,
    indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
    /// Operation that runs the gather on the inner backend `B`.
    #[derive(new, Debug)]
    struct GatherOps<B: FusionBackend> {
        desc: GatherOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
            let indices = handles.get_int_tensor::<B>(&self.desc.indices);
            let output = B::q_gather(self.desc.dim, tensor, indices);
            handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
        }
    }

    // Every input tensor must be part of the operation streams, including the indices.
    let streams = OperationStreams::with_inputs([&tensor, &indices]);
    let client = tensor.client.clone();
    let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
        client.create_empty_handle()
    });

    client
        .register(
            streams,
            OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),
            GatherOps::<B>::new(desc),
        )
        .output()
}
/// Selects slices of a quantized tensor along `dim` at the given integer `indices`.
///
/// Registers a [`SelectOpIr`] under `OperationIr::BaseFloat` with the fusion client of
/// `tensor`; `B::q_select` is called from the registered operation's `execute`.
///
/// Both `tensor` and `indices` are inputs of the operation, so both are passed to
/// `OperationStreams::with_inputs`. Previously only `tensor` was listed, so the stream of
/// `indices` was left out of the registration.
fn q_select(
    tensor: QuantizedTensor<Self>,
    dim: usize,
    indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
    /// Operation that runs the select on the inner backend `B`.
    #[derive(new, Debug)]
    struct SelectOps<B: FusionBackend> {
        desc: SelectOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
            let indices = handles.get_int_tensor::<B>(&self.desc.indices);
            let output = B::q_select(tensor, self.desc.dim, indices);
            handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
        }
    }

    // Every input tensor must be part of the operation streams, including the indices.
    let streams = OperationStreams::with_inputs([&tensor, &indices]);
    let client = tensor.client.clone();
    let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
        client.create_empty_handle()
    });

    client
        .register(
            streams,
            OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),
            SelectOps::<B>::new(desc),
        )
        .output()
}
/// Slices a quantized tensor with the given `slices`.
///
/// Registers a [`SliceOpIr`] under `OperationIr::BaseFloat` with the tensor's fusion client;
/// `B::q_slice` is called from the registered operation's `execute` with the ranges stored in
/// the descriptor.
fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
    /// Operation that runs the slicing on the inner backend `B`.
    #[derive(new, Debug)]
    struct SliceOps<B: FusionBackend> {
        desc: SliceOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_quantized_tensor::<B>(&self.desc.tensor);
            let sliced = B::q_slice(source, self.desc.ranges.as_slice());
            handles.register_quantized_tensor::<B>(&self.desc.out.id, sliced);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
        client.create_empty_handle()
    });
    let ir = OperationIr::BaseFloat(BaseOperationIr::Slice(op_desc.clone()));
    let op = SliceOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Expands a quantized tensor to `shape`.
///
/// Registers a [`ShapeOpIr`] built with `ShapeOpIr::expand` under `OperationIr::BaseFloat`;
/// `B::q_expand` is called from the registered operation's `execute` with the shape stored in
/// the output descriptor.
fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
    /// Operation that runs the expansion on the inner backend `B`.
    #[derive(new, Debug)]
    struct ExpandOps<B: FusionBackend> {
        desc: ShapeOpIr,
        _b: PhantomData<B>,
    }

    impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
        fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
            let source = handles.get_quantized_tensor::<B>(&self.desc.input);
            let target_shape = self.desc.out.shape.clone();
            let expanded = B::q_expand(source, target_shape);
            handles.register_quantized_tensor::<B>(&self.desc.out.id, expanded);
        }
    }

    // Grab the client and the input stream before the tensor is handed to `into_ir`.
    let client = tensor.client.clone();
    let streams = OperationStreams::with_inputs([&tensor]);
    let op_desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
    let ir = OperationIr::BaseFloat(BaseOperationIr::Expand(op_desc.clone()));
    let op = ExpandOps::<B>::new(op_desc);

    client.register(streams, ir, op).output()
}
/// Matrix multiplication where either operand may be a float or a quantized tensor.
///
/// Registers a [`MatmulOpIr`] built with `create_mixed` as a `Float` operation with the
/// client of `lhs`. The result is returned as `TensorPrimitive::QFloat` when the selected
/// propagation is `Propagate` and as `TensorPrimitive::Float` when it is `Inhibit`.
fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
// Operation that replays the matmul on the inner backend. The two flags record which
// kind of handle (quantized or float) has to be fetched for each operand.
#[derive(new, Debug)]
struct MatmulOps<B: FusionBackend> {
desc: MatmulOpIr,
lhs_quantized: bool,
rhs_quantized: bool,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
// Rebuild each operand with the same variant it had when the operation was registered.
let lhs = match self.lhs_quantized {
true => {
TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))
}
false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),
};
let rhs = match self.rhs_quantized {
true => {
TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))
}
false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),
};
let output = B::q_matmul(lhs, rhs);
// Register the result with the handle kind matching the variant the backend returned.
match output {
TensorPrimitive::Float(output) => {
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
TensorPrimitive::QFloat(output) => {
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
}
}
}
}
// Defaults used when neither operand is quantized.
let mut propagation = QuantPropagation::Inhibit;
let mut scheme = QuantScheme::default();
let mut streams = OperationStreams::default();
let mut lhs_quantized = false;
let mut rhs_quantized = false;
// Register the lhs stream; a quantized lhs also provides the propagation mode and scheme.
match &lhs {
TensorPrimitive::QFloat(lhs) => {
propagation = lhs.propagation();
scheme = *lhs.scheme();
lhs_quantized = true;
streams.tensor(lhs);
}
TensorPrimitive::Float(lhs) => {
streams.tensor(lhs);
}
}
// Same for the rhs. When both operands are quantized, the rhs propagation and scheme
// overwrite the values taken from the lhs.
match &rhs {
TensorPrimitive::QFloat(rhs) => {
propagation = rhs.propagation();
scheme = *rhs.scheme();
rhs_quantized = true;
streams.tensor(rhs);
}
TensorPrimitive::Float(rhs) => {
streams.tensor(rhs);
}
}
// Output dtype: quantized with the selected scheme when propagating, otherwise the
// backend's float element type.
let dtype = match propagation {
QuantPropagation::Propagate => DType::QFloat(scheme),
QuantPropagation::Inhibit => B::FloatElem::dtype(),
};
// The operation is always registered with the client of the lhs operand.
let client = match &lhs {
TensorPrimitive::Float(lhs) => lhs.client.clone(),
TensorPrimitive::QFloat(lhs) => lhs.client.clone(),
};
let lhs = match lhs {
TensorPrimitive::Float(lhs) => lhs.into_ir(),
TensorPrimitive::QFloat(lhs) => lhs.into_ir(),
};
let rhs = match rhs {
TensorPrimitive::Float(rhs) => rhs.into_ir(),
TensorPrimitive::QFloat(rhs) => rhs.into_ir(),
};
let desc = MatmulOpIr::create_mixed(lhs, rhs, dtype, || client.create_empty_handle());
let out = client
.register(
streams,
OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),
)
.output();
// Wrap the output in the variant that matches the dtype chosen above.
match propagation {
QuantPropagation::Propagate => TensorPrimitive::QFloat(out),
QuantPropagation::Inhibit => TensorPrimitive::Float(out),
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,36 @@
use burn_backend::{
backend::ExecutionError,
ops::{TransactionOps, TransactionPrimitive},
};
use crate::{Fusion, FusionBackend};
/// Transaction support for the fusion backend, delegating to the inner backend `B`.
impl<B: FusionBackend> TransactionOps<Fusion<B>> for Fusion<B> {
    /// Executes a read transaction on the inner backend.
    ///
    /// Every float, int and bool fusion tensor of the transaction is resolved through its own
    /// client (`resolve_tensor_float` / `resolve_tensor_int` / `resolve_tensor_bool`). The
    /// resulting transaction is then passed to `B::tr_execute`.
    ///
    /// # Panics
    ///
    /// Panics through `todo!` if the transaction contains any quantized tensor read.
    async fn tr_execute(
        transaction: TransactionPrimitive<Self>,
    ) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
        let execution = B::tr_execute(TransactionPrimitive::new(
            // Float reads.
            transaction
                .read_floats
                .into_iter()
                .map(|tensor| tensor.client.clone().resolve_tensor_float::<B>(tensor))
                .collect(),
            // Quantized reads are not handled yet: the closure only runs (and panics)
            // when at least one quantized tensor is present.
            transaction
                .read_qfloats
                .into_iter()
                .map(|_t| todo!("Quantization not supported yet"))
                .collect(),
            // Int reads.
            transaction
                .read_ints
                .into_iter()
                .map(|tensor| tensor.client.clone().resolve_tensor_int::<B>(tensor))
                .collect(),
            // Bool reads.
            transaction
                .read_bools
                .into_iter()
                .map(|tensor| tensor.client.clone().resolve_tensor_bool::<B>(tensor))
                .collect(),
        ));

        execution.await
    }
}

View File

@@ -0,0 +1,319 @@
/// Generates a struct `$name<B>` wrapping a `ScalarOpIr` and implements `Operation` for it.
///
/// The generated `execute` reads the float tensor `desc.lhs`, calls `$ops` with it and the
/// scalar `desc.rhs`, and registers the result as a float tensor under `desc.out.id`.
///
/// - `scalar_float_ops!(Name, ops)` converts the scalar with `.into()` before the call.
/// - `scalar_float_ops!(Name, ops, noconvert)` passes `desc.rhs` to `$ops` as stored.
///
/// The expansion refers to `ScalarOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_ops {
// Default arm: the scalar is converted with `.into()`.
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs.into());
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}
};
// `noconvert` arm: the scalar is forwarded without conversion.
(
$name:ident,
$ops:expr,
noconvert
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ReduceDimOpIr` and implements `Operation` for it.
///
/// The generated `execute` reads the float tensor `desc.input`, calls `$ops(input, desc.axis)`
/// and registers the result as a float tensor under `desc.out.id`.
///
/// The expansion refers to `ReduceDimOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_float_ops {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ReduceDimOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_float_tensor::<B>(&self.desc.input);
let output = $ops(input, self.desc.axis);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ReduceDimOpIr` and implements `Operation` for it,
/// for reductions that take a float tensor and produce an int tensor.
///
/// The generated `execute` reads the float tensor `desc.input`, calls `$ops(input, desc.axis)`
/// and registers the result as an int tensor under `desc.out.id`.
///
/// The expansion refers to `ReduceDimOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_float2int_ops {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ReduceDimOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_float_tensor::<B>(&self.desc.input);
let output = $ops(input, self.desc.axis);
// Float input, int output.
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ReduceDimOpIr` and implements `Operation` for it.
///
/// The generated `execute` reads the int tensor `desc.input`, calls `$ops(input, desc.axis)`
/// and registers the result as an int tensor under `desc.out.id`.
///
/// The expansion refers to `ReduceDimOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! reduce_int_ops {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ReduceDimOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_int_tensor::<B>(&self.desc.input);
let output = $ops(input, self.desc.axis);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ScalarOpIr` and implements `Operation` for it,
/// for scalar operations that take a float tensor and produce an int tensor.
///
/// The generated `execute` reads the float tensor `desc.lhs`, calls `$ops` with it and a clone
/// of `desc.rhs`, and registers the result as an int tensor under `desc.out.id`.
///
/// Unlike the sibling macros, the pattern ends with a comma after `$ops`, so invocations must
/// be written as `scalar_float2int_ops!(Name, ops,)`.
///
/// The expansion refers to `ScalarOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float2int_ops {
(
$name:ident,
$ops:expr,
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
// The scalar is cloned and forwarded without conversion.
let output = $ops(lhs, self.desc.rhs.clone());
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `UnaryOpIr` and implements `Operation` for it.
///
/// The generated `execute` reads the float tensor `desc.input`, calls `$ops(input)` and
/// registers the result as a float tensor under `desc.out.id`.
///
/// Two invocation forms are accepted, `unary_float_ops!(Name, ops)` and
/// `unary_float_ops!(Name, ops, reduce)`; both currently expand to the same code.
///
/// The expansion refers to `UnaryOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_float_ops {
// Default arm.
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: UnaryOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_float_tensor::<B>(&self.desc.input);
let output = $ops(input);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}
};
// `reduce` arm: same expansion as the default arm.
(
$name:ident,
$ops:expr,
reduce
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: UnaryOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_float_tensor::<B>(&self.desc.input);
let output = $ops(input);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `UnaryOpIr` and implements `Operation` for it.
///
/// The generated `execute` reads the int tensor `desc.input`, calls `$ops(input)` and
/// registers the result as an int tensor under `desc.out.id`.
///
/// Two invocation forms are accepted, `unary_int_ops!(Name, ops)` and
/// `unary_int_ops!(Name, ops, reduce)`; both currently expand to the same code.
///
/// The expansion refers to `UnaryOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_int_ops {
// Default arm.
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: UnaryOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_int_tensor::<B>(&self.desc.input);
let output = $ops(input);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
// `reduce` arm: same expansion as the default arm.
(
$name:ident,
$ops:expr,
reduce
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: UnaryOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_int_tensor::<B>(&self.desc.input);
let output = $ops(input);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ScalarOpIr` and implements `Operation` for it,
/// for comparisons between a float tensor and a scalar.
///
/// The generated `execute` reads the float tensor `desc.lhs`, calls `$ops` with it and
/// `desc.rhs.into()`, and registers the result as a bool tensor under `desc.out.id`.
///
/// The expansion refers to `ScalarOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_cmp_ops {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs.into());
// Comparison result: float input, bool output.
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ScalarOpIr` and implements `Operation` for it,
/// for comparisons between an int tensor and a scalar.
///
/// The generated `execute` reads the int tensor `desc.lhs`, calls `$ops` with it and
/// `desc.rhs.into()`, and registers the result as a bool tensor under `desc.out.id`.
///
/// The expansion refers to `ScalarOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_cmp_ops {
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs.into());
// Comparison result: int input, bool output.
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
}
}
};
}
/// Generates a struct `$name<B>` wrapping a `ScalarOpIr` and implements `Operation` for it.
///
/// The generated `execute` reads the int tensor `desc.lhs`, calls `$ops` with it and the
/// scalar `desc.rhs`, and registers the result as an int tensor under `desc.out.id`.
///
/// - `scalar_int_ops!(Name, ops)` converts the scalar with `.into()` before the call.
/// - `scalar_int_ops!(Name, ops, noconvert)` passes `desc.rhs` to `$ops` as stored.
///
/// The expansion refers to `ScalarOpIr`, `PhantomData`, `Operation`, `HandleContainer`,
/// `FusionBackend` and the `new` derive by name, so those must be resolvable where the macro
/// is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_ops {
// Default arm: the scalar is converted with `.into()`.
(
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs.into());
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
// `noconvert` arm: the scalar is forwarded without conversion.
(
$name:ident,
$ops:expr,
noconvert
) => {
#[derive(new, Debug)]
struct $name<B: FusionBackend> {
desc: ScalarOpIr,
_b: PhantomData<B>,
}
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}
};
}

View File

@@ -0,0 +1,271 @@
use crate::{FuserStatus, NumOperations, OperationFuser, stream::store::ExecutionStrategy};
use burn_ir::{OperationIr, TensorId, TensorIr};
use std::{collections::HashSet, sync::Arc};
/// A block represents a list of operations, not necessarily in the same order as the execution
/// stream.
///
/// The start and end position of the relative execution stream are tracked in the block alongside
/// the ordering.
pub struct Block<O> {
/// Fusers attached to this block. `Block::new` clones them from the provided slice and
/// `Block::optimize` asks the selected one to `finish()` its optimization.
builders: Vec<Box<dyn OperationFuser<O>>>,
/// Operations held by the block.
operations: Vec<OperationIr>,
/// Tensor ids looked up by `Block::contains_tensors`.
// NOTE(review): presumably the ids of the tensors used by `operations`; it is filled by
// code outside this view — confirm.
ids: HashSet<TensorId>,
/// Ordering handed to the execution strategy built by `Block::optimize`.
// NOTE(review): presumably the position of each operation in the relative execution
// stream — confirm against the registration code.
ordering: Vec<usize>,
/// The start position in the relative execution stream.
pub start_pos: usize,
/// The end position in the relative execution stream.
pub end_pos: usize,
}
/// The result of [registering](Block::register) an [operation](OperationIr) in a [block](Block).
pub enum RegistrationResult {
/// The [operation](OperationIr) was registered in the block.
Accepted,
/// The [operation](OperationIr) isn't part of the graph.
///
/// In this case the operation isn't registered and the block is left without it.
NotPartOfTheGraph,
}
/// The optimization found for a [block](Block).
///
/// The `new(strategy, ordering)` constructor is generated by the `new` derive.
#[derive(Debug, new)]
pub struct BlockOptimization<O> {
/// The [execution strategy](ExecutionStrategy) to be used to execute the [block](Block).
pub strategy: ExecutionStrategy<O>,
/// The ordering of each operation in the relative execution stream.
pub ordering: Vec<usize>,
}
impl<O: NumOperations> Block<O> {
    /// Create an empty block that owns its own copy of each provided
    /// [operation fuser](OperationFuser).
    ///
    /// The position range starts out inverted (`start_pos == usize::MAX`,
    /// `end_pos == usize::MIN`) so that the first registered operation defines both bounds.
    pub fn new(builders: &[Box<dyn OperationFuser<O>>]) -> Self {
        let builders: Vec<_> = builders.iter().map(|fuser| fuser.clone_dyn()).collect();

        Self {
            builders,
            operations: Vec::new(),
            ids: HashSet::new(),
            ordering: Vec::new(),
            start_pos: usize::MAX,
            end_pos: usize::MIN,
        }
    }

    /// Sort the [blocks](Block) by ascending start position.
    pub fn sort(blocks: &mut [Self]) {
        blocks.sort_by_key(|block| block.start_pos);
    }

    /// Consume the block and return the best optimization found for it.
    ///
    /// When no fuser is ready, the block falls back on executing its operations one by one.
    /// When the selected optimization covers fewer operations than the block contains, only the
    /// covered prefix of the ordering is kept.
    pub fn optimize(mut self) -> BlockOptimization<O> {
        let Some(best) = find_best_optimization_index(&mut self.builders) else {
            let strategy = ExecutionStrategy::Operations {
                ordering: Arc::new(self.ordering.clone()),
            };
            return BlockOptimization::new(strategy, self.ordering);
        };

        let opt = self.builders[best].finish();
        let covered = opt.len();
        if covered < self.operations.len() {
            // `ordering` and `operations` always have the same length, so this keeps exactly
            // the operations included in the optimization.
            self.ordering.truncate(covered);
        }

        let strategy = ExecutionStrategy::Optimization {
            ordering: Arc::new(self.ordering.clone()),
            opt,
        };
        BlockOptimization::new(strategy, self.ordering)
    }

    /// Returns if the block contains any of the provided [tensors](TensorIr).
    pub fn contains_tensors(&self, tensors: &[&TensorIr]) -> bool {
        tensors.iter().any(|tensor| self.ids.contains(&tensor.id))
    }

    /// Merge the other block into the current one and return if the operation is successful.
    ///
    /// Every operation of `other` is force-registered at its original position.
    ///
    /// # Warning
    ///
    /// This will modify the current block even if the other block isn't correctly merged.
    pub fn merge(&mut self, other: &Block<O>) -> bool {
        let replay = other.operations.iter().zip(other.ordering.iter().copied());
        for (operation, position) in replay {
            self.register(operation, position, true);
        }

        // The merge is successful if the current block can still be optimized.
        self.still_optimizing()
    }

    /// Register an [operation](OperationIr) in the current block.
    ///
    /// `order` is the position of the operation in the relative execution stream.
    ///
    /// An empty block accepts any operation. Otherwise the operation is accepted when it shares
    /// at least one tensor with the block, or when `force` is true; if neither holds, nothing is
    /// registered and the operation is reported as
    /// [not part of the graph](RegistrationResult::NotPartOfTheGraph).
    ///
    /// Forcing is useful to fuse operations that are part of different graphs, but included
    /// in the same optimization.
    pub fn register(
        &mut self,
        operation: &OperationIr,
        order: usize,
        force: bool,
    ) -> RegistrationResult {
        let connected = self.ids.is_empty()
            || operation
                .nodes()
                .into_iter()
                .any(|node| self.ids.contains(&node.id));

        if !(connected || force) {
            return RegistrationResult::NotPartOfTheGraph;
        }

        self.register_op(operation, order);
        RegistrationResult::Accepted
    }

    /// Returns if at least one fuser of the block isn't closed yet.
    pub fn still_optimizing(&self) -> bool {
        self.builders
            .iter()
            .any(|fuser| !matches!(fuser.status(), FuserStatus::Closed))
    }

    /// Records the operation, widens the position range, and feeds the operation to every fuser.
    fn register_op(&mut self, operation: &OperationIr, pos: usize) {
        self.operations.push(operation.clone());
        self.ordering.push(pos);

        self.start_pos = self.start_pos.min(pos);
        // The end position is exclusive.
        self.end_pos = self.end_pos.max(pos + 1);

        for fuser in &mut self.builders {
            fuser.fuse(operation);
        }

        self.ids
            .extend(operation.nodes().into_iter().map(|node| node.id));
    }
}
impl<O> BlockOptimization<O> {
    /// Replaces every position `p` of the ordering by `mapping[p]`, both in this block
    /// optimization and in its execution strategy.
    ///
    /// # Panics
    ///
    /// Panics if a position is out of bounds for `mapping`.
    pub fn map_ordering(&mut self, mapping: &[usize]) {
        self.ordering
            .iter_mut()
            .for_each(|position| *position = mapping[*position]);

        self.strategy.map_ordering(mapping);
    }
}
impl<O> ExecutionStrategy<O> {
    /// Replaces every position `p` stored in the strategy by `mapping[p]`.
    ///
    /// Composed strategies are remapped recursively.
    ///
    /// # Panics
    ///
    /// Panics if a position is out of bounds for `mapping`.
    pub fn map_ordering(&mut self, mapping: &[usize]) {
        match self {
            // Both leaf variants store their ordering the same way; a fresh `Arc` is created so
            // that other holders of the previous ordering are left untouched.
            ExecutionStrategy::Optimization { ordering, .. }
            | ExecutionStrategy::Operations { ordering } => {
                let remapped = ordering
                    .iter()
                    .map(|&position| mapping[position])
                    .collect::<Vec<_>>();
                *ordering = Arc::new(remapped);
            }
            ExecutionStrategy::Composed(items) => {
                items
                    .iter_mut()
                    .for_each(|item| item.map_ordering(mapping));
            }
        }
    }
}
/// Returns the index of the ready fuser with the highest score, if any.
///
/// When several ready fusers share the highest score, the last one wins.
fn find_best_optimization_index<O>(
    optimizations: &mut [Box<dyn OperationFuser<O>>],
) -> Option<usize> {
    let mut winner = None;
    let mut top_score = 0;

    for (index, fuser) in optimizations.iter().enumerate() {
        let props = fuser.properties();
        if !props.ready || props.score < top_score {
            continue;
        }
        winner = Some(index);
        top_score = props.score;
    }

    winner
}
impl<O> PartialEq for Block<O> {
    /// Two blocks are equal when they contain the same positions, regardless of their order.
    fn eq(&self, other: &Self) -> bool {
        // Positions act as operation ids, so comparing them sorted compares block contents.
        let sorted_positions = |block: &Self| {
            let mut positions = block.ordering.clone();
            positions.sort();
            positions
        };

        sorted_positions(self) == sorted_positions(other)
    }
}
impl<O> core::fmt::Debug for Block<O> {
    /// Prints the start position, the end position and the number of registered operations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            start_pos,
            end_pos,
            ordering,
            ..
        } = self;

        write!(
            f,
            "Block {{ pos: [{:?}, {:?}; {:?}] }}",
            start_pos,
            end_pos,
            ordering.len(),
        )
    }
}
impl<O> Clone for Block<O> {
    /// Deep-copies the block; fusers are duplicated through `clone_dyn` since trait objects
    /// can't be cloned directly.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to `Block` forces this impl to be updated.
        let Self {
            builders,
            operations,
            ids,
            ordering,
            start_pos,
            end_pos,
        } = self;

        Self {
            builders: builders.iter().map(|fuser| fuser.clone_dyn()).collect(),
            operations: operations.clone(),
            ids: ids.clone(),
            ordering: ordering.clone(),
            start_pos: *start_pos,
            end_pos: *end_pos,
        }
    }
}

View File

@@ -0,0 +1,441 @@
use super::Block;
use crate::NumOperations;
#[derive(Debug, PartialEq)]
/// The result of [merging](merge_blocks) [blocks](Block).
pub enum MergeBlocksResult<O> {
/// All [blocks](Block) merged into one.
Full(Block<O>),
/// Some [blocks](Block) merged and some failed.
Partial {
/// The blocks resulting from successful merges.
merged: Vec<Block<O>>,
/// The blocks that couldn't be merged into another one.
failed: Vec<Block<O>>,
},
/// All [blocks](Block) failed to merge.
Fail,
}
/// Merge multiple [blocks](Block) together.
///
/// When `sorted` is true, the blocks of a [partial](MergeBlocksResult::Partial) result are sorted
/// by start position; otherwise their order isn't guaranteed. Sorting is mostly useful for
/// testing.
///
/// No block yields [MergeBlocksResult::Fail], a single block is returned as is, and two blocks
/// are merged directly in whichever direction works.
///
/// # Strategy
///
/// With three blocks or more, the merging strategy is in two steps:
///
/// 1. The first step is to recursively try to merge adjacent blocks. This has the advantage of
///    trying multiple block orderings, therefore trying multiple permutations of the blocks.
///    However, it has the downside of not trying to merge blocks that are further away in the
///    list of blocks. Since trying all possible combinations is exponential, therefore not
///    feasible, we fall back on the second step.
/// 2. The second step is to reduce blocks by setting an accumulator block, then sequentially
///    trying to merge the remaining blocks. We try some permutations based on the result from
///    step 1.
pub fn merge_blocks<O: NumOperations>(blocks: &[&Block<O>], sorted: bool) -> MergeBlocksResult<O> {
    match blocks {
        [] => return MergeBlocksResult::Fail,
        [single] => return MergeBlocksResult::Full((*single).clone()),
        [first, second] => {
            return match merge_two(first, second) {
                Some(block) => MergeBlocksResult::Full(block),
                None => MergeBlocksResult::Fail,
            };
        }
        _ => {}
    }

    let mut step1 = merge_blocks_step1(blocks);

    // Step 1 alone may already have produced a single block.
    let only_one_full =
        step1.full.len() == 1 && step1.failed.is_empty() && step1.partial.is_empty();
    if only_one_full {
        return MergeBlocksResult::Full(step1.full.remove(0));
    }
    let only_one_partial =
        step1.partial.len() == 1 && step1.failed.is_empty() && step1.full.is_empty();
    if only_one_partial {
        return MergeBlocksResult::Full(step1.partial.remove(0));
    }

    match merge_blocks_step2(step1) {
        MergeBlocksResult::Partial {
            mut merged,
            mut failed,
        } if sorted => {
            Block::sort(&mut merged);
            Block::sort(&mut failed);
            MergeBlocksResult::Partial { merged, failed }
        }
        // Nothing to sort for the other outcomes, or sorting wasn't requested.
        other => other,
    }
}
/// The blocks produced by the first merging step, grouped by the outcome of the merge of the
/// group of adjacent blocks they come from.
struct MergeBlockStep1<O> {
/// Blocks from groups that fully merged into a single block.
full: Vec<Block<O>>,
/// Merged blocks from groups that only partially merged.
partial: Vec<Block<O>>,
/// Blocks that failed to merge within their group.
failed: Vec<Block<O>>,
}
// Implemented by hand so that no `O: Default` bound is required.
impl<O> Default for MergeBlockStep1<O> {
    /// Starts with three empty lists.
    fn default() -> Self {
        Self {
            full: Vec::new(),
            partial: Vec::new(),
            failed: Vec::new(),
        }
    }
}
/// First merging step: split the blocks into groups of adjacent blocks and merge each group
/// recursively with [merge_blocks].
///
/// Groups contain half of the blocks (rounded down, but at least one); the last group holds the
/// remainder. The outcome of each group is sorted into the `full`, `partial` and `failed` lists
/// of the returned [MergeBlockStep1].
fn merge_blocks_step1<O: NumOperations>(blocks: &[&Block<O>]) -> MergeBlockStep1<O> {
    // `chunks` yields the same `ceil(len / step_size)` groups without going through lossy float
    // arithmetic. The step is clamped to one because a zero step is invalid.
    let step_size = usize::max(blocks.len() / 2, 1);
    let mut result = MergeBlockStep1::default();

    for group in blocks.chunks(step_size) {
        match merge_blocks(group, false) {
            MergeBlocksResult::Full(block) => {
                result.full.push(block);
            }
            MergeBlocksResult::Partial {
                mut merged,
                mut failed,
            } => {
                result.partial.append(&mut merged);
                result.failed.append(&mut failed);
            }
            MergeBlocksResult::Fail => {
                // Nothing merged: every block of the group is kept as is for the second step.
                result
                    .failed
                    .extend(group.iter().map(|block| (*block).clone()));
            }
        }
    }

    result
}
fn merge_blocks_step2<O: NumOperations>(mut step1: MergeBlockStep1<O>) -> MergeBlocksResult<O> {
// First let's try to merge partial graphs.
if step1.partial.len() > 1 {
match merge_accumulator(&step1.partial[0], &step1.partial[1..]) {
MergeBlocksResult::Full(block) => {
step1.partial = vec![block];
}
MergeBlocksResult::Partial { merged, mut failed } => {
step1.partial = merged;
step1.failed.append(&mut failed);
}
MergeBlocksResult::Fail => {}
}
}
// Then let's try to merge partial graphs with failed merges.
if !step1.failed.is_empty() {
step1.partial.append(&mut step1.failed);
match merge_accumulator(&step1.partial[0], &step1.partial[1..]) {
MergeBlocksResult::Full(block) => {
step1.partial = vec![block];
}
MergeBlocksResult::Partial { merged, mut failed } => {
step1.partial = merged;
step1.failed.append(&mut failed);
}
MergeBlocksResult::Fail => {}
}
}
// Then let's try to merge full graphs.
if step1.full.len() > 1 {
match merge_accumulator(&step1.full[0], &step1.full[1..]) {
MergeBlocksResult::Full(block) => {
step1.full = vec![block];
}
MergeBlocksResult::Partial { merged, mut failed } => {
step1.full = merged;
step1.failed.append(&mut failed);
}
MergeBlocksResult::Fail => {}
}
}
// Then let's try to merge full graphs with failed graphs.
if !step1.full.is_empty() {
step1.full.append(&mut step1.failed);
match merge_accumulator(&step1.full[0], &step1.full[1..]) {
MergeBlocksResult::Full(block) => {
step1.full = vec![block];
}
MergeBlocksResult::Partial { merged, mut failed } => {
step1.full = merged;
step1.failed.append(&mut failed);
}
MergeBlocksResult::Fail => {}
}
}
// Then let's try to merge full graphs with partial graphs.
if !step1.full.is_empty() || !step1.partial.is_empty() {
step1.full.append(&mut step1.partial);
match merge_accumulator(&step1.full[0], &step1.full[1..]) {
MergeBlocksResult::Full(block) => {
step1.full = vec![block];
}
MergeBlocksResult::Partial { merged, mut failed } => {
step1.full = merged;
step1.failed.append(&mut failed);
}
MergeBlocksResult::Fail => {
// We do nothing.
}
}
}
if step1.full.is_empty() {
MergeBlocksResult::Fail
} else if step1.failed.is_empty() {
if step1.full.len() == 1 {
MergeBlocksResult::Full(step1.full.remove(0))
} else {
MergeBlocksResult::Partial {
merged: step1.full,
failed: vec![],
}
}
} else {
MergeBlocksResult::Partial {
merged: step1.full,
failed: step1.failed,
}
}
}
/// Sequentially tries to merge each block into an accumulator initialized with `base`.
///
/// Every attempt is made on a copy of the accumulator, since [Block::merge] modifies its
/// receiver even when the merge fails. A block that merges replaces the accumulator; a block that
/// doesn't is collected as failed.
///
/// Returns [MergeBlocksResult::Fail] when no block merged (including when `blocks` is empty).
fn merge_accumulator<O: NumOperations>(
    base: &Block<O>,
    blocks: &[Block<O>],
) -> MergeBlocksResult<O> {
    let mut accumulator = base.clone();
    let mut rejected = Vec::<Block<O>>::new();
    let mut absorbed_any = false;

    for candidate in blocks {
        let mut attempt = accumulator.clone();
        if attempt.merge(candidate) {
            accumulator = attempt;
            absorbed_any = true;
        } else {
            rejected.push(candidate.clone());
        }
    }

    match (absorbed_any, rejected.is_empty()) {
        (false, _) => MergeBlocksResult::Fail,
        (true, true) => MergeBlocksResult::Full(accumulator),
        (true, false) => MergeBlocksResult::Partial {
            merged: vec![accumulator],
            failed: rejected,
        },
    }
}
/// Tries to merge two blocks, first `b` into `a`, then `a` into `b`.
///
/// Returns the first merge that succeeds, or `None` when both directions fail.
fn merge_two<O: NumOperations>(a: &Block<O>, b: &Block<O>) -> Option<Block<O>> {
    [(a, b), (b, a)].into_iter().find_map(|(base, other)| {
        // Merging mutates the receiver even on failure, hence the copy.
        let mut merged = base.clone();
        merged.merge(other).then_some(merged)
    })
}
// Tests for `merge_blocks`.
//
// Blocks are compared through `PartialEq for Block`, which only looks at the set of registered
// positions; the expected blocks are therefore built by registering the same positions.
#[cfg(test)]
mod tests {
use super::*;
pub use crate::stream::execution::tests::{TestOptimization, TestOptimizationBuilder};
use crate::{
OperationFuser,
stream::tests::{operation_1, operation_2, operation_3},
};
// Merging an empty list of blocks fails.
#[test]
fn test_merge_blocks_no_block() {
let actual = merge_blocks::<TestOptimization>(&[], true);
assert_eq!(actual, MergeBlocksResult::Fail);
}
// A single block is returned unchanged as a full merge.
#[test]
fn test_merge_blocks_single() {
let builders = builders();
let block = Block::new(&builders);
let actual = merge_blocks::<TestOptimization>(&[&block], true);
assert_eq!(actual, MergeBlocksResult::Full(block));
}
// Two blocks holding only `operation_1` merge into a single block with all four positions.
#[test]
fn test_merge_blocks_two_blocks() {
let builders = builders();
let mut block1 = Block::new(&builders);
let mut block2 = Block::new(&builders);
block1.register(&operation_1(), 0, false);
block1.register(&operation_1(), 1, false);
block2.register(&operation_1(), 2, false);
block2.register(&operation_1(), 3, false);
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2], true);
let mut expected = Block::new(&builders);
expected.register(&operation_1(), 0, false);
expected.register(&operation_1(), 1, false);
expected.register(&operation_1(), 2, false);
expected.register(&operation_1(), 3, false);
assert_eq!(actual, MergeBlocksResult::Full(expected));
}
// Three blocks holding only `operation_1` go through the two-step strategy and still merge
// into a single block.
#[test]
fn test_merge_blocks_three_blocks() {
let builders = builders();
let mut block1 = Block::new(&builders);
let mut block2 = Block::new(&builders);
let mut block3 = Block::new(&builders);
block1.register(&operation_1(), 0, false);
block2.register(&operation_1(), 1, false);
block3.register(&operation_1(), 2, false);
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3], true);
let mut expected = Block::new(&builders);
expected.register(&operation_1(), 0, false);
expected.register(&operation_1(), 1, false);
expected.register(&operation_1(), 2, false);
assert_eq!(actual, MergeBlocksResult::Full(expected));
}
// The two `operation_1` blocks (positions 0 and 2) merge together while the `operation_2`
// block stays on its own; both end up in `merged` and nothing is reported as failed.
#[test]
fn test_merge_blocks_three_blocks_partial() {
let builders = builders();
let mut block1 = Block::new(&builders);
let mut block2 = Block::new(&builders);
let mut block3 = Block::new(&builders);
block1.register(&operation_1(), 0, false);
block2.register(&operation_2(), 1, false);
block3.register(&operation_1(), 2, false);
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3], true);
let mut expected1 = Block::new(&builders);
let mut expected2 = Block::new(&builders);
expected1.register(&operation_1(), 0, false);
expected1.register(&operation_1(), 2, false);
expected2.register(&operation_2(), 1, false);
assert_eq!(
actual,
MergeBlocksResult::Partial {
merged: vec![expected1, expected2],
failed: vec![]
}
);
}
// With an extra `operation_3` block, only the `operation_1` blocks merge; the `operation_2`
// and `operation_3` blocks are both reported as failed.
#[test]
fn test_merge_blocks_four_blocks_partial_with_failure() {
let builders = builders();
let mut block1 = Block::new(&builders);
let mut block2 = Block::new(&builders);
let mut block3 = Block::new(&builders);
let mut block4 = Block::new(&builders);
block1.register(&operation_1(), 0, false);
block2.register(&operation_2(), 1, false);
block3.register(&operation_1(), 2, false);
block4.register(&operation_3(), 3, false);
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3, &block4], true);
let mut expected1 = Block::new(&builders);
let mut expected2 = Block::new(&builders);
let mut failed = Block::new(&builders);
expected1.register(&operation_1(), 0, false);
expected1.register(&operation_1(), 2, false);
expected2.register(&operation_2(), 1, false);
failed.register(&operation_3(), 3, false);
assert_eq!(
actual,
MergeBlocksResult::Partial {
merged: vec![expected1],
failed: vec![expected2, failed]
}
);
}
// Five blocks: the `operation_1` blocks (0, 2) and the `operation_2` blocks (1, 4) each merge
// into one block, and the `operation_3` block is the only failure.
#[test]
fn test_merge_blocks_five_blocks_partial_with_failure() {
let builders = builders();
let mut block1 = Block::new(&builders);
let mut block2 = Block::new(&builders);
let mut block3 = Block::new(&builders);
let mut block4 = Block::new(&builders);
let mut block5 = Block::new(&builders);
block1.register(&operation_1(), 0, false);
block2.register(&operation_2(), 1, false);
block3.register(&operation_1(), 2, false);
block4.register(&operation_3(), 3, false);
block5.register(&operation_2(), 4, false);
let actual =
merge_blocks::<TestOptimization>(&[&block1, &block2, &block3, &block4, &block5], true);
let mut expected1 = Block::new(&builders);
let mut expected2 = Block::new(&builders);
let mut failed = Block::new(&builders);
expected1.register(&operation_1(), 0, false);
expected1.register(&operation_1(), 2, false);
expected2.register(&operation_2(), 1, false);
expected2.register(&operation_2(), 4, false);
failed.register(&operation_3(), 3, false);
assert_eq!(
actual,
MergeBlocksResult::Partial {
merged: vec![expected1, expected2],
failed: vec![failed]
}
);
}
// The fusers shared by every test.
// NOTE(review): presumably builder 0 only fuses runs of `operation_1` and builder 1 only runs of
// `operation_2`, with nothing fusing `operation_3` — confirm in `TestOptimizationBuilder`.
fn builders() -> Vec<Box<dyn OperationFuser<TestOptimization>>> {
let builder_1 = TestOptimizationBuilder::new(0, vec![operation_1(); 10]);
let builder_2 = TestOptimizationBuilder::new(1, vec![operation_2(); 10]);
vec![Box::new(builder_1), Box::new(builder_2)]
}
}

View File

@@ -0,0 +1,7 @@
mod block;
mod optimization;
pub(super) mod merging;
pub(super) use block::*;
pub use optimization::*;

View File

@@ -0,0 +1,232 @@
use std::sync::Arc;
use crate::{
NumOperations,
search::{
Block, BlockOptimization,
merging::{MergeBlocksResult, merge_blocks},
},
stream::store::ExecutionStrategy,
};
/// Try to optimize a list of [blocks](Block) into a [block optimization](BlockOptimization).
///
/// # Notes
///
/// What we know here is that every block is independent at that time and can be executed
/// in any order.
///
/// The contract is that the length of operations executed must include all operations. If we don't
/// find an optimization that can be executed with that constraint, we return a
/// [BlocksOptimizerResult::WithHoles].
pub struct BlocksOptimizer<O> {
/// The blocks left to optimize.
blocks: Vec<Block<O>>,
/// One flag per position of the relative execution stream, set once the position is covered
/// by an optimization or reported as a hole.
resolved: Vec<bool>,
/// Every position strictly below this index has been checked; it advances over contiguous
/// resolved positions.
last_checked: usize,
}
/// The result of [optimizing](BlocksOptimizer::optimize) a list of [blocks](Block).
pub enum BlocksOptimizerResult<O> {
/// When an optimization fills the whole stream.
Full(BlockOptimization<O>),
/// The optimization found with the holes indices.
WithHoles {
/// The strategies found so far, in execution order.
strategies: Vec<Box<ExecutionStrategy<O>>>,
/// The positions of the operations covered by `strategies`.
ordering: Vec<usize>,
/// The positions that still need to be handled because no strategy covers them.
holes: Vec<usize>,
},
}
/// The outcome of optimizing a single [block](Block) inside [BlocksOptimizer::optimize].
enum BlockOptimizationStep<O> {
/// The block's strategy leaves no unresolved position behind; the search continues with the
/// next block.
Contiguous {
strategy: ExecutionStrategy<O>,
},
/// Only happen when we fallback on executing a single operation.
Operation {
strategy: ExecutionStrategy<O>,
},
/// The block's strategy skips over the given positions, which have to be handled separately.
WithHoles {
strategy: ExecutionStrategy<O>,
holes: Vec<usize>,
},
/// The block isn't included and the search stops.
Stop,
}
impl<O: NumOperations> BlocksOptimizer<O> {
    /// Create a new optimizer with the given blocks.
    ///
    /// The number of tracked positions is the largest end position among the blocks.
    ///
    /// # Panics
    ///
    /// Panics if `blocks` is empty.
    pub fn new(blocks: Vec<Block<O>>) -> Self {
        let num_ops = blocks.iter().map(|block| block.end_pos).max().unwrap();
        let resolved = vec![false; num_ops];

        Self {
            blocks,
            resolved,
            last_checked: 0,
        }
    }

    /// Optimizes the blocks.
    ///
    /// The strategy is quite simple. We try to merge as many [blocks](Block) together as we can,
    /// then we iterate over them in order, composing optimizations with the remaining blocks, all
    /// while minimizing fallback operations to avoid having holes in the optimization stream.
    pub fn optimize(mut self) -> BlocksOptimizerResult<O> {
        self = self.merging_pass();

        let blocks = core::mem::take(&mut self.blocks);
        let mut strategies = Vec::with_capacity(blocks.len());
        let mut ordering = Vec::new();

        for block in blocks {
            let step = self.optimize_block(block, &mut ordering);

            match step {
                BlockOptimizationStep::Contiguous { strategy } => {
                    strategies.push(Box::new(strategy));
                }
                BlockOptimizationStep::Operation { strategy } => {
                    strategies.push(Box::new(strategy));
                    break;
                }
                BlockOptimizationStep::WithHoles { strategy, holes } => {
                    strategies.push(Box::new(strategy));
                    return BlocksOptimizerResult::WithHoles {
                        strategies,
                        ordering,
                        holes,
                    };
                }
                BlockOptimizationStep::Stop => break,
            }
        }

        // A single strategy is returned as is instead of being wrapped in a composition.
        let strategy = if strategies.len() > 1 {
            ExecutionStrategy::Composed(strategies)
        } else {
            *strategies.remove(0)
        };

        BlocksOptimizerResult::Full(BlockOptimization { strategy, ordering })
    }

    /// Optimize a single block.
    ///
    /// `ordering` holds the positions already covered by the previous blocks and is extended when
    /// this block follows them without leaving any gap.
    fn optimize_block(
        &mut self,
        block: Block<O>,
        ordering: &mut Vec<usize>,
    ) -> BlockOptimizationStep<O> {
        let last_index = block.end_pos;
        let mut block_optimization = block.optimize();

        for &pos in &block_optimization.ordering {
            self.update_check(pos);
        }

        // No gap means that every position up to `last_checked` is covered by the previous
        // blocks plus this one.
        let covered = ordering.len() + block_optimization.ordering.len();
        if self.last_checked == covered {
            ordering.append(&mut block_optimization.ordering);
            return BlockOptimizationStep::Contiguous {
                strategy: block_optimization.strategy,
            };
        }

        if ordering.is_empty() {
            self.optimize_holes(block_optimization, last_index, ordering)
        } else {
            // Don't include that block and need further exploring.
            BlockOptimizationStep::Stop
        }
    }

    /// The provided optimization has holes.
    ///
    /// A real optimization is kept and the positions it skips before `last_index` are reported
    /// as holes. A fallback on plain operations is reduced to its single smallest position.
    fn optimize_holes(
        &mut self,
        optimization: BlockOptimization<O>,
        last_index: usize,
        ordering_global: &mut Vec<usize>,
    ) -> BlockOptimizationStep<O> {
        let BlockOptimization {
            strategy,
            ordering: mut block_ordering,
        } = optimization;

        match strategy {
            ExecutionStrategy::Optimization { opt, ordering } => {
                ordering_global.append(&mut block_ordering);
                let holes = self.find_holes(last_index);
                let strategy = ExecutionStrategy::Optimization { opt, ordering };

                if holes.is_empty() {
                    BlockOptimizationStep::Contiguous { strategy }
                } else {
                    BlockOptimizationStep::WithHoles { strategy, holes }
                }
            }
            ExecutionStrategy::Operations { ordering } => {
                let first = *ordering.iter().min().unwrap();
                ordering_global.push(first);

                BlockOptimizationStep::Operation {
                    strategy: ExecutionStrategy::Operations {
                        ordering: Arc::new(vec![first]),
                    },
                }
            }
            _ => unreachable!(),
        }
    }

    /// Marks `pos` as resolved, then advances `last_checked` past every contiguous resolved
    /// position.
    fn update_check(&mut self, pos: usize) {
        self.resolved[pos] = true;

        while self.last_checked < self.resolved.len() && self.resolved[self.last_checked] {
            self.last_checked += 1;
        }
    }

    /// Collects the unresolved positions between `last_checked` and `last` (exclusive), marks
    /// them as resolved and moves `last_checked` forward to `last`.
    fn find_holes(&mut self, last: usize) -> Vec<usize> {
        let start = self.last_checked;
        let mut holes = Vec::new();

        for index in start..last {
            let slot = &mut self.resolved[index];
            if !*slot {
                *slot = true;
                holes.push(index);
            }
        }
        // `last_checked` never moves backward, even when `last` is behind it.
        self.last_checked += last.saturating_sub(start);

        holes
    }

    /// Try to merge blocks together.
    ///
    /// The blocks are left sorted by start position; when nothing can be merged they are kept
    /// as they are.
    fn merging_pass(mut self) -> Self {
        if self.blocks.len() == 1 {
            return self;
        }

        Block::sort(&mut self.blocks);

        let outcome = {
            let candidates: Vec<&Block<O>> = self.blocks.iter().collect();
            merge_blocks(&candidates, false)
        };

        match outcome {
            MergeBlocksResult::Full(block) => {
                self.blocks = vec![block];
            }
            MergeBlocksResult::Partial { mut merged, failed } => {
                merged.extend(failed);
                self.blocks = merged;
                Block::sort(&mut self.blocks);
            }
            MergeBlocksResult::Fail => {}
        }

        self
    }
}

View File

@@ -0,0 +1,4 @@
mod blocks;
mod stream;
pub use stream::*;

View File

@@ -0,0 +1,277 @@
use super::blocks::BlocksOptimizer;
use crate::{
NumOperations, OperationFuser,
search::{
Block, BlockOptimization, RegistrationResult,
merging::{MergeBlocksResult, merge_blocks},
optimization::blocks::BlocksOptimizerResult,
},
stream::store::ExecutionStrategy,
};
use burn_ir::OperationIr;
/// Optimize a stream of [operations](OperationIr) using a list of [fusers](OperationFuser).
pub struct StreamOptimizer<O> {
/// The fusers used as a template for every new block.
builders: Vec<Box<dyn OperationFuser<O>>>,
/// The blocks currently being built from the registered operations.
blocks: Vec<Block<O>>,
/// The number of operations registered so far; also the position given to the next one.
length: usize,
/// When true, the optimizer ignores any further operation.
stopped: bool,
/// The maximum number of blocks that can be explored at the same time, if any.
max_blocks: Option<usize>,
}
impl<O: NumOperations> StreamOptimizer<O> {
/// Create a new stream optimizer.
pub fn new(builders: Vec<Box<dyn OperationFuser<O>>>) -> Self {
Self {
builders,
blocks: Vec::new(),
length: 0,
stopped: false,
// Too high and it may breaks the fusion cache always retriggering explorations.
max_blocks: Some(5),
}
}
/// Register a new [operation](OperationIr) in the optimizer.
///
/// You can use the function [Self::still_optimizing] to know if the operations are actually
/// being registered.
pub fn register(&mut self, operation: &OperationIr) {
if self.stopped {
return;
}
if self.blocks.is_empty() {
self.on_new_block(operation);
self.length += 1;
return;
}
match self.merge_blocks(operation, false) {
MergeBlockStep::Full | MergeBlockStep::NoNeed => {}
MergeBlockStep::Fail | MergeBlockStep::Partial => {
// With the given operation, blocks are no longer independent.
self.stopped = true;
return;
}
}
if let Some(max_blocks) = self.max_blocks {
if self.register_max_block(operation, max_blocks) {
self.length += 1;
} else {
self.stopped = true;
}
return;
}
let added_count = self.register_inner(operation, false);
if added_count == 0 {
self.on_new_block(operation);
}
self.length += 1;
}
/// Optimize the current stream on the given [operations](OperationIr).
///
/// # Notes
///
/// The operations provided are the same as the ones used in the [register](Self::register)
/// method, this simply remove the need for the current type to also keep track of the list of
/// operations.
pub fn optimize(&self, operations: &[OperationIr]) -> BlockOptimization<O> {
let result = BlocksOptimizer::new(self.blocks.clone()).optimize();
match result {
BlocksOptimizerResult::Full(block_optimization) => block_optimization,
BlocksOptimizerResult::WithHoles {
mut strategies,
mut ordering,
mut holes,
} => {
loop {
let mut search = self.new_empty_search();
let mut operations_holes = Vec::with_capacity(holes.len());
for index in holes.iter() {
let op = &operations[*index];
operations_holes.push(op.clone());
search.register(op);
}
let mut optimization_of_holes = search.optimize(&operations_holes);
optimization_of_holes.map_ordering(&holes);
strategies.push(Box::new(optimization_of_holes.strategy));
holes.drain(0..optimization_of_holes.ordering.len());
ordering.append(&mut optimization_of_holes.ordering);
if holes.is_empty() {
break;
}
}
BlockOptimization::new(ExecutionStrategy::Composed(strategies), ordering)
}
}
}
/// Reset the state of the optimizer.
pub fn reset(&mut self) {
self.builders.iter_mut().for_each(|b| b.reset());
self.length = 0;
self.blocks.clear();
self.stopped = false;
}
/// Returns if some optimizations are still possible within the stream.
pub fn still_optimizing(&self) -> bool {
if self.stopped {
return false;
}
if self.blocks.is_empty() {
return true;
}
let mut num_stopped = 0;
for block in self.blocks.iter() {
if !block.still_optimizing() {
num_stopped += 1
}
}
num_stopped < self.blocks.len()
}
fn register_max_block(&mut self, operation: &OperationIr, max_blocks: usize) -> bool {
if max_blocks == 1 {
// Register in the single block with a force.
self.register_inner(operation, true);
return true;
}
let added_count = self.register_inner(operation, false);
if added_count > 0 {
return true;
}
if added_count == 0 && self.blocks.len() < max_blocks {
self.on_new_block(operation);
return true;
}
self.merge_blocks(operation, true);
if self.blocks.len() >= max_blocks {
self.stopped = true;
return false;
}
let added_count = self.register_inner(operation, false);
if added_count == 0 {
self.on_new_block(operation);
}
true
}
fn register_inner(&mut self, operation: &OperationIr, force: bool) -> usize {
let mut added_count = 0;
for block in self.blocks.iter_mut() {
match block.register(operation, self.length, force) {
RegistrationResult::Accepted => {
added_count += 1;
}
RegistrationResult::NotPartOfTheGraph => {}
}
}
added_count
}
fn new_empty_search(&self) -> Self {
Self::new(
self.builders
.iter()
.map(|b| {
let mut b = b.clone_dyn();
b.reset();
b
})
.collect(),
)
}
fn merge_blocks(&mut self, operation: &OperationIr, all: bool) -> MergeBlockStep {
let nodes = operation.nodes();
let mut block_merges = Vec::new();
for (i, block) in self.blocks.iter().enumerate() {
if all || block.contains_tensors(&nodes) {
block_merges.push(i);
}
}
if block_merges.len() <= 1 {
return MergeBlockStep::NoNeed;
}
let blocks_to_merge = self
.blocks
.iter()
.enumerate()
.filter_map(|(i, g)| match block_merges.contains(&i) {
true => Some(g),
false => None,
})
.collect::<Vec<_>>();
let merged = merge_blocks(&blocks_to_merge, false);
let mut clear_blocks = || {
let mut indices = block_merges.to_vec();
indices.sort();
for g in indices.into_iter().rev() {
self.blocks.remove(g);
}
};
match merged {
MergeBlocksResult::Full(block) => {
clear_blocks();
self.blocks.push(block);
Block::sort(&mut self.blocks);
MergeBlockStep::Full
}
MergeBlocksResult::Partial {
mut merged,
mut failed,
} => {
clear_blocks();
self.blocks.append(&mut merged);
self.blocks.append(&mut failed);
Block::sort(&mut self.blocks);
MergeBlockStep::Partial
}
MergeBlocksResult::Fail => MergeBlockStep::Fail,
}
}
fn on_new_block(&mut self, operation: &OperationIr) {
let mut block = Block::new(&self.builders);
block.register(operation, self.length, true);
self.blocks.push(block);
}
}
/// The outcome of merging the blocks of a [StreamOptimizer] for a new operation.
enum MergeBlockStep {
/// The selected blocks merged into a single one.
Full,
/// Only some of the selected blocks merged.
Partial,
/// None of the selected blocks merged.
Fail,
/// At most one block was selected, so there was nothing to merge.
NoNeed,
}

View File

@@ -0,0 +1,215 @@
use std::sync::Arc;
use crate::{
FusionBackend, FusionRuntime,
stream::{MultiStream, OperationStreams, StreamId, execution::Operation},
};
use burn_backend::{TensorData, backend::ExecutionError};
use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr};
/// Holds the operation streams of a fusion runtime together with the tensor handles they
/// operate on.
pub struct FusionServer<R: FusionRuntime> {
/// The streams in which operations are registered.
streams: MultiStream<R>,
/// The container of tensor handles read and written by the operations.
pub(crate) handles: HandleContainer<R::FusionHandle>,
}
impl<R> FusionServer<R>
where
R: FusionRuntime,
{
/// Creates a server for the given device with an empty handle container.
pub fn new(device: R::FusionDevice) -> Self {
    Self {
        // The device is owned and not used afterwards, so it is moved instead of cloned.
        streams: MultiStream::new(device),
        handles: HandleContainer::new(),
    }
}
/// Registers an operation, given both as its representation and as its executable form, in
/// the provided streams.
///
/// The server's handle container is passed along with the operation.
pub fn register(
&mut self,
streams: OperationStreams,
repr: OperationIr,
operation: Arc<dyn Operation<R>>,
) {
self.streams
.register(streams, repr, operation, &mut self.handles)
}
/// Drains the stream with the given id.
///
/// The read methods of this server call it to make sure all registered operations are executed.
pub fn drain_stream(&mut self, id: StreamId) {
self.streams.drain(&mut self.handles, id)
}
/// Creates an uninitialized tensor in the handle container and returns its id.
pub fn create_empty_handle(&mut self) -> TensorId {
self.handles.create_tensor_uninit()
}
/// Reads the data of a float tensor.
///
/// The stream `id` is drained first, then the tensor is fetched from the handle container and
/// marked as read on that stream. The returned future resolves to the tensor data or to an
/// execution error.
pub fn read_float<B>(
&mut self,
tensor: TensorIr,
id: StreamId,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
where
B: FusionBackend<FusionRuntime = R>,
{
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.drain_stream(id);
let tensor_float = self.handles.get_float_tensor::<B>(&tensor);
// NOTE(review): presumably updates the stream's bookkeeping for this tensor — confirm in
// `MultiStream::mark_read`.
self.streams.mark_read(id, &tensor, &self.handles);
B::float_into_data(tensor_float)
}
/// Reads the data of an int tensor.
///
/// The stream `id` is drained first, then the tensor is fetched from the handle container and
/// marked as read on that stream. The returned future resolves to the tensor data or to an
/// execution error.
pub fn read_int<B>(
&mut self,
tensor: TensorIr,
id: StreamId,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
where
B: FusionBackend<FusionRuntime = R>,
{
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.drain_stream(id);
let tensor_int = self.handles.get_int_tensor::<B>(&tensor);
// NOTE(review): presumably updates the stream's bookkeeping for this tensor — confirm in
// `MultiStream::mark_read`.
self.streams.mark_read(id, &tensor, &self.handles);
B::int_into_data(tensor_int)
}
/// Reads the data of a bool tensor.
///
/// The stream `id` is drained first, then the tensor is fetched from the handle container and
/// marked as read on that stream. The returned future resolves to the tensor data or to an
/// execution error.
pub fn read_bool<B>(
&mut self,
tensor: TensorIr,
id: StreamId,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
where
B: FusionBackend<FusionRuntime = R>,
{
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.drain_stream(id);
let tensor_bool = self.handles.get_bool_tensor::<B>(&tensor);
// NOTE(review): presumably updates the stream's bookkeeping for this tensor — confirm in
// `MultiStream::mark_read`.
self.streams.mark_read(id, &tensor, &self.handles);
B::bool_into_data(tensor_bool)
}
/// Reads the data of a quantized tensor.
///
/// The stream `id` is drained first, then the tensor is fetched from the handle container and
/// marked as read on that stream. The returned future resolves to the tensor data or to an
/// execution error.
pub fn read_quantized<B>(
&mut self,
tensor: TensorIr,
id: StreamId,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
where
B: FusionBackend<FusionRuntime = R>,
{
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.drain_stream(id);
let tensor_q = self.handles.get_quantized_tensor::<B>(&tensor);
// NOTE(review): presumably updates the stream's bookkeeping for this tensor — confirm in
// `MultiStream::mark_read`.
self.streams.mark_read(id, &tensor, &self.handles);
B::q_into_data(tensor_q)
}
/// Moves a float tensor from this server to `server_device`.
///
/// The tensor is fetched from this server's handles and marked as read on `stream_tensor`, then
/// transferred to `device` and registered in `server_device` under a newly created handle.
///
/// Returns the id of the tensor on `server_device`.
pub fn change_server_float<B>(
    &mut self,
    tensor: &TensorIr,
    stream_tensor: StreamId,
    device: &R::FusionDevice,
    server_device: &mut Self,
) -> TensorId
where
    B: FusionBackend<FusionRuntime = R>,
{
    let tensor_float = self.handles.get_float_tensor::<B>(tensor);
    self.streams.mark_read(stream_tensor, tensor, &self.handles);

    let tensor = B::float_to_device(tensor_float, device);
    let id = server_device.create_empty_handle();
    // The transferred tensor isn't used afterwards, so it is moved instead of cloned.
    server_device
        .handles
        .register_float_tensor::<B>(&id, tensor);

    id
}
/// Resolve the float tensor primitive described by `tensor` from this server's
/// handle container.
pub fn resolve_server_float<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
where
B: FusionBackend<FusionRuntime = R>,
{
self.handles.get_float_tensor::<B>(tensor)
}
/// Resolve the int tensor primitive described by `tensor` from this server's
/// handle container.
pub fn resolve_server_int<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
where
B: FusionBackend<FusionRuntime = R>,
{
self.handles.get_int_tensor::<B>(tensor)
}
/// Resolve the bool tensor primitive described by `tensor` from this server's
/// handle container.
pub fn resolve_server_bool<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
where
B: FusionBackend<FusionRuntime = R>,
{
self.handles.get_bool_tensor::<B>(tensor)
}
/// Move an int tensor from this server to the server of another device.
///
/// The tensor is resolved from this server's handles, marked as read on
/// `stream_tensor`, transferred with `B::int_to_device`, and registered under a
/// freshly created handle on `server_device`.
///
/// Returns the id of the handle created on `server_device`.
pub fn change_server_int<B>(
    &mut self,
    tensor: &TensorIr,
    stream_tensor: StreamId,
    device: &R::FusionDevice,
    server_device: &mut Self,
) -> TensorId
where
    B: FusionBackend<FusionRuntime = R>,
{
    let tensor_int = self.handles.get_int_tensor::<B>(tensor);
    self.streams.mark_read(stream_tensor, tensor, &self.handles);

    let tensor = B::int_to_device(tensor_int, device);
    let id = server_device.create_empty_handle();

    // The transferred primitive is not used afterwards, so it is moved into the
    // container rather than cloned.
    server_device.handles.register_int_tensor::<B>(&id, tensor);

    id
}
/// Move a bool tensor from this server to the server of another device.
///
/// The tensor is resolved from this server's handles, marked as read on
/// `stream_tensor`, transferred with `B::bool_to_device`, and registered under a
/// freshly created handle on `server_device`.
///
/// Returns the id of the handle created on `server_device`.
pub fn change_server_bool<B>(
    &mut self,
    tensor: &TensorIr,
    stream_tensor: StreamId,
    device: &R::FusionDevice,
    server_device: &mut Self,
) -> TensorId
where
    B: FusionBackend<FusionRuntime = R>,
{
    let tensor_bool = self.handles.get_bool_tensor::<B>(tensor);
    self.streams.mark_read(stream_tensor, tensor, &self.handles);

    let tensor = B::bool_to_device(tensor_bool, device);
    let id = server_device.create_empty_handle();

    // The transferred primitive is not used afterwards, so it is moved into the
    // container rather than cloned.
    server_device.handles.register_bool_tensor::<B>(&id, tensor);

    id
}
/// Move a quantized tensor from this server to the server of another device.
///
/// The tensor is resolved from this server's handles, transferred with
/// `B::q_to_device`, and registered under a freshly created handle on
/// `server_device`, whose id is returned.
pub fn change_server_quantized<B>(
    &mut self,
    tensor: &TensorIr,
    device: &R::FusionDevice,
    server_device: &mut Self,
) -> TensorId
where
    B: FusionBackend<FusionRuntime = R>,
{
    let local = self.handles.get_quantized_tensor::<B>(tensor);
    let transferred = B::q_to_device(local, device);

    let new_id = server_device.create_empty_handle();
    server_device
        .handles
        .register_quantized_tensor::<B>(&new_id, transferred);

    new_id
}
}

View File

@@ -0,0 +1 @@
pub use burn_backend::StreamId;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
use burn_ir::HandleContainer;
use crate::FusionRuntime;
/// The mode in which the execution is done.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ExecutionMode {
/// Execution may be postponed while more operations are collected.
Lazy,
/// Pending operations have to be executed now; postponing is not allowed.
Sync,
}
/// General trait to abstract how a single operation is executed.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait Operation<R: FusionRuntime>: Send + Sync + core::fmt::Debug {
/// Execute the operation.
///
/// `handles` is the handle container the operation is given mutable access to.
fn execute(&self, handles: &mut HandleContainer<R::FusionHandle>);
}

View File

@@ -0,0 +1,91 @@
use burn_ir::OperationIr;
use super::ExecutionMode;
use crate::{
NumOperations, OperationFuser,
search::{BlockOptimization, StreamOptimizer},
};
/// Explore and create new optimization.
pub struct Explorer<O> {
/// The optimizer every explored operation is registered with.
optimizer: StreamOptimizer<O>,
/// Number of operations that were added but not yet registered with the optimizer.
num_deferred: usize,
/// Number of operations registered with the optimizer since the last reset.
num_explored: usize,
/// Last value reported by the optimizer's `still_optimizing`; `true` after a reset.
is_still_optimizing: bool,
}
/// The result of an exploration done by the [explorer](Explorer).
pub enum ExplorationAction<O> {
/// Found a new optimization; carries the block optimization produced by the optimizer.
Completed(BlockOptimization<O>),
/// We should continue exploring before arriving at a conclusion.
Continue,
}
impl<O: NumOperations> Explorer<O> {
    /// Build an explorer searching with the provided fusers.
    pub(crate) fn new(optimizations: Vec<Box<dyn OperationFuser<O>>>) -> Self {
        let optimizer = StreamOptimizer::new(optimizations);

        Self {
            optimizer,
            num_deferred: 0,
            num_explored: 0,
            is_still_optimizing: true,
        }
    }

    /// Record that one more operation was added and still has to be explored.
    pub(crate) fn on_new_operation(&mut self) {
        self.num_deferred += 1;
    }

    /// Whether every added operation has already been explored.
    pub(crate) fn is_up_to_date(&self) -> bool {
        self.num_deferred == 0
    }

    /// Explore the provided operations.
    ///
    /// Deferred operations are registered first. In lazy mode, while the optimizer is
    /// still optimizing, the exploration continues; otherwise the optimizer is asked
    /// for its optimization, which is returned as completed.
    pub(crate) fn explore(
        &mut self,
        operations: &[OperationIr],
        mode: ExecutionMode,
    ) -> ExplorationAction<O> {
        self.update(operations);

        // Exploration can only go on when the execution isn't synchronous.
        let keep_exploring = matches!(mode, ExecutionMode::Lazy) && self.is_still_optimizing;

        if keep_exploring {
            ExplorationAction::Continue
        } else {
            ExplorationAction::Completed(self.optimizer.optimize(operations))
        }
    }

    /// Reset the explorer so that all the provided operations are considered deferred.
    pub(crate) fn reset(&mut self, operations: &[OperationIr]) {
        self.num_deferred = operations.len();
        self.num_explored = 0;
        self.is_still_optimizing = true;
        self.optimizer.reset();
    }

    /// Register the deferred operations, i.e. the last `num_deferred` entries of
    /// `operations`, oldest first, stopping as soon as the optimizer is done.
    fn update(&mut self, operations: &[OperationIr]) {
        let mut remaining = self.num_deferred;

        while remaining > 0 && self.is_still_optimizing {
            let position = operations.len() - remaining;

            self.optimizer.register(&operations[position]);
            self.num_explored += 1;
            self.is_still_optimizing = self.optimizer.still_optimizing();

            remaining -= 1;
        }

        // Whatever was left unregistered is dropped from the deferred count as well.
        self.num_deferred = 0;
    }
}

View File

@@ -0,0 +1,17 @@
pub(crate) mod validator;
mod base;
mod explorer;
mod ordering;
mod policy;
mod processor;
pub use base::*;
pub use ordering::*;
pub(crate) use explorer::*;
pub(crate) use policy::*;
pub(crate) use processor::*;
#[cfg(test)]
pub(crate) mod tests;

View File

@@ -0,0 +1,71 @@
use std::sync::Arc;
use burn_ir::HandleContainer;
use crate::{FusionRuntime, NumOperations, Optimization, stream::Context};
use super::Operation;
/// Manage the execution of potentially multiple optimizations and operations out of order.
pub struct OrderedExecution<R: FusionRuntime> {
/// The operations available for execution.
operations: Vec<Arc<dyn Operation<R>>>,
/// Running count of executed operations.
num_executed: usize,
/// Ordering of the optimization executed last; `None` until an optimization is executed.
ordering: Option<Arc<Vec<usize>>>,
}
impl<R: FusionRuntime> OrderedExecution<R> {
    /// Returns the operation that can be executed without impacting the state of the execution.
    ///
    /// This is useful to implement fallback for optimizations.
    ///
    /// `index` is a position within the current ordering; it is translated through that
    /// ordering to locate the operation.
    ///
    /// # Panics
    ///
    /// Panics if no ordering was provided yet (no optimization has been executed), or if
    /// `index` or the position it maps to is out of bounds.
    pub fn operation_within_optimization(&self, index: usize) -> Arc<dyn Operation<R>> {
        let ordering = self.ordering.as_ref().expect("No ordering provided");
        let position = ordering[index];

        Arc::clone(&self.operations[position])
    }

    /// Create an execution over the given operations, with nothing executed yet.
    pub(crate) fn new(operations: Vec<Arc<dyn Operation<R>>>) -> Self {
        Self {
            operations,
            num_executed: 0,
            ordering: None,
        }
    }

    /// Consume the execution.
    ///
    /// The first `num_executed` operations are removed; the remaining ones are returned
    /// together with the number of executed operations.
    pub(crate) fn finish(mut self) -> (Vec<Arc<dyn Operation<R>>>, usize) {
        self.operations.drain(0..self.num_executed);
        (self.operations, self.num_executed)
    }

    /// Execute an optimization with the given ordering.
    ///
    /// The ordering is stored so that the optimization can query operations through
    /// [`operation_within_optimization`](Self::operation_within_optimization) while it runs.
    /// The executed count grows by the length of the optimization.
    ///
    /// # Panics
    ///
    /// Panics if the ordering has more entries than there are operations.
    pub(crate) fn execute_optimization(
        &mut self,
        optimization: &mut R::Optimization,
        context: &mut Context<'_, R::FusionHandle>,
        ordering: Arc<Vec<usize>>,
    ) {
        if ordering.len() > self.operations.len() {
            panic!("Ordering is bigger than operations");
        }

        self.ordering = Some(ordering);

        // Read the length before executing, since `execute` takes the optimization mutably.
        let num_drained = optimization.len();
        optimization.execute(context, self);
        self.num_executed += num_drained;
    }

    /// Execute the operations found at the positions listed in `ordering`, in that order.
    ///
    /// The executed count grows by the length of `ordering`.
    pub(crate) fn execute_operations(
        &mut self,
        handles: &mut HandleContainer<R::FusionHandle>,
        ordering: &[usize],
    ) {
        self.num_executed += ordering.len();

        for &position in ordering {
            self.operations[position].execute(handles);
        }
    }
}

View File

@@ -0,0 +1,572 @@
use burn_ir::OperationIr;
use super::ExecutionMode;
use super::validator::{
ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator,
ValidatorState,
};
use crate::stream::execution::validator::OperationsValidator;
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery};
use std::marker::PhantomData;
/// The policy keeps track of all possible execution plans for the current operations.
///
/// # Details
///
/// We keep track of each new operation added and invalidate potential execution plans
/// when we see a different operation is added.
///
/// Therefore, the overhead is very minimal, since the time-complexity of checking for existing
/// execution plans scales with the number of concurrent potential plans for the current operations,
/// which isn't supposed to be big at any time.
pub(crate) struct Policy<O> {
/// List of potential execution plans that are compatible with the current stream segment.
candidates: Vec<OperationsValidator<ExecutionPlanId>>,
/// List of candidate execution plans that have been found; we can still keep searching
/// to potentially find a better one.
availables: Vec<AvailableItem>,
/// The found execution plan that should be executed, along with the number of operations
/// in the plan.
found: Option<(ExecutionPlanId, usize)>,
/// The number of operations that have been analyzed.
num_operations: usize,
/// Marker for the optimization type `O` of the store the policy is used with; no value
/// of `O` is stored.
_item_type: PhantomData<O>,
}
/// An execution plan reported as found by its validator, together with one validator
/// per trigger of the plan.
#[derive(new)]
struct AvailableItem {
/// Id of the execution plan.
id: ExecutionPlanId,
/// Size reported by the validator when the plan was found; compared with the number of
/// operations when choosing an action.
size: usize,
/// Validators for the triggers of the plan, in the order the triggers are stored.
triggers: Vec<TriggerValidator>,
}
/// Action to be made depending on the stream.
#[derive(PartialEq, Eq, Debug)]
pub enum Action {
/// Continue exploring using the [builder](crate::OptimizationBuilder).
Explore,
/// The current policy indicates that an exploration may be possible in the future, so the
/// best action is to defer any execution.
///
/// Sometimes, it can be a false positive and a new exploration should be built from scratch.
/// Therefore it's important to keep the previous operations to rebuild the state if it
/// happens.
Defer,
/// An exploration has been found, and the best action is to execute it!
///
/// Carries the id of the execution plan to execute.
Execute(ExecutionPlanId),
}
impl<O: core::fmt::Debug> Policy<O> {
/// Create a new policy.
pub(crate) fn new() -> Self {
Self {
candidates: Vec::new(),
availables: Vec::new(),
found: None,
num_operations: 0,
_item_type: PhantomData,
}
}
/// Returns the [action](Action) that should be taken given the state of the policy.
pub fn action(
&self,
store: &ExecutionPlanStore<O>,
operations: &[OperationIr],
mode: ExecutionMode,
) -> Action {
if self.num_operations < operations.len() {
panic!(
"Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed."
);
}
if let Some((id, _length)) = self.found {
return Action::Execute(id);
}
match mode {
ExecutionMode::Lazy => self.action_lazy(operations),
ExecutionMode::Sync => self.action_sync(operations, store),
}
}
/// Update the policy state.
pub fn update(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {
// reset the candidates to contain all execution plans starting with the operation.
if self.num_operations == 0 {
self.candidates = store
.find(SearchQuery::PlansStartingWith(operation))
.into_iter()
.map(OperationsValidator::new)
.collect();
}
self.update_candidates(store, operation);
self.check_candidates(store);
self.update_availables(store, operation);
self.check_availables();
self.num_operations += 1;
}
// Reset the state of the policy.
pub fn reset(&mut self) {
self.candidates.clear();
self.availables.clear();
self.num_operations = 0;
self.found = None;
}
/// Check which candidates can be removed, and which one can go from
/// 'candidate' to 'available'
fn check_candidates(&mut self, store: &ExecutionPlanStore<O>) {
let mut candidates_to_remove = Vec::new();
for candidate in self.candidates.iter() {
match candidate.state {
ValidatorState::Found { size } => {
let item = store.get_unchecked(candidate.id);
let mut triggers = Vec::with_capacity(item.triggers.len());
for (index, trigger) in item.triggers.iter().enumerate() {
triggers.push(match trigger {
ExecutionTrigger::OnOperations(_) => TriggerValidator::OnOperations {
matching: OperationsValidator::new(index),
progress: TriggerProgress::NotInit,
},
ExecutionTrigger::OnSync => TriggerValidator::OnSync,
ExecutionTrigger::Always => TriggerValidator::Always,
});
}
self.availables
.push(AvailableItem::new(candidate.id, size, triggers));
candidates_to_remove.push(candidate.id);
}
ValidatorState::Invalidated => {
candidates_to_remove.push(candidate.id);
}
ValidatorState::Validating => {}
};
}
let mut updated_candidates = Vec::new();
core::mem::swap(&mut updated_candidates, &mut self.candidates);
self.candidates = updated_candidates
.into_iter()
.filter(|candidate| !candidates_to_remove.iter().any(|id| id == &candidate.id))
.collect();
}
fn check_availables(&mut self) {
for available in self.availables.iter() {
for trigger in available.triggers.iter() {
match trigger {
TriggerValidator::OnOperations {
matching,
progress: _,
} => {
if let ValidatorState::Found {
size: _size_of_trigger,
} = matching.state
{
self.found = Some((available.id, available.size));
return;
}
}
TriggerValidator::Always => {
self.found = Some((available.id, available.size));
return;
}
TriggerValidator::OnSync => {
// Does nothing during an update.
}
}
}
}
}
fn update_candidates(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {
let main_store = ExecutionPlanOperationsStore::new(store);
self.candidates
.iter_mut()
.for_each(|candidate| candidate.update(operation, self.num_operations, &main_store));
}
fn update_availables(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {
self.availables.iter_mut().for_each(|available| {
let store_trigger = TriggerOperationsStore::new(available.id, store);
available.triggers.iter_mut().for_each(|trigger| {
if let TriggerValidator::OnOperations { matching, progress } = trigger {
match progress {
TriggerProgress::NotInit => {
*progress = TriggerProgress::NumChecked(0);
}
TriggerProgress::NumChecked(num_check) => {
matching.update(operation, *num_check, &store_trigger);
*num_check += 1;
}
}
}
});
});
}
fn action_lazy(&self, operations: &[OperationIr]) -> Action {
if !self.candidates.is_empty() {
return Action::Defer;
}
for available in self.availables.iter() {
if available.size == operations.len() {
return Action::Defer;
}
for trigger in available.triggers.iter() {
if let TriggerValidator::OnOperations {
matching,
progress: _,
} = trigger
&& let ValidatorState::Validating = matching.state
{
return Action::Defer;
}
}
}
Action::Explore
}
fn action_sync(&self, operations: &[OperationIr], store: &ExecutionPlanStore<O>) -> Action {
for available in self.availables.iter() {
if available.size == operations.len() {
return Action::Execute(available.id);
}
}
for candidate in self.candidates.iter() {
let item = store.get_unchecked(candidate.id);
if item.operations.len() == operations.len() {
return Action::Execute(candidate.id);
}
}
Action::Explore
}
}
#[cfg(test)]
mod tests {
use burn_backend::{DType, Shape};
use burn_ir::{FloatOperationIr, TensorId, TensorIr, TensorStatus, UnaryOpIr};
use super::*;
use crate::{
search::BlockOptimization,
stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},
};
use std::ops::Range;
// With an empty store, every update must lead to an exploration.
#[test]
fn given_no_optimization_should_explore() {
let store = ExecutionPlanStore::default();
let mut policy = Policy::new();
let stream = TestStream::new(3);
stream.assert_updates(
&store,
&mut policy,
AssertUpdatesOptions::OperationsIndex(0..3),
Action::Explore,
);
}
// Two plans share the same first two operations; in sync mode with two operations,
// the plan of size two is the one executed.
#[test]
fn given_existing_optimizations_when_sync_should_execute_one_when_available() {
let mut store = ExecutionPlanStore::default();
let mut policy = Policy::new();
let stream = TestStream::new(3);
let id_1 = store.add(ExecutionPlan {
operations: stream.operations[0..2].to_vec(),
triggers: Vec::new(),
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
});
let _id_2 = store.add(ExecutionPlan {
operations: stream.operations[0..3].to_vec(),
triggers: Vec::new(),
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
});
stream.assert_updates(
&store,
&mut policy,
AssertUpdatesOptions::OperationsIndex(0..2),
Action::Defer,
);
let action = policy.action(&store, &stream.operations[0..2], ExecutionMode::Sync);
assert_eq!(action, Action::Execute(id_1));
}
// The plan covers the first two operations and is triggered by the third one.
#[test]
fn given_existing_plan_when_found_trigger_should_execute_plan() {
let mut store = ExecutionPlanStore::default();
let mut policy = Policy::new();
let stream = TestStream::new(3);
let id = store.add(ExecutionPlan {
operations: stream.operations[0..2].to_vec(),
triggers: stream.operations[2..3]
.iter()
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
.collect(),
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
});
stream.assert_updates(
&store,
&mut policy,
AssertUpdatesOptions::OperationsIndex(0..2),
Action::Defer,
);
stream.assert_updates(
&store,
&mut policy,
AssertUpdatesOptions::OperationsIndex(2..3),
Action::Execute(id),
);
}
// A single plan with two triggers: each stream ends with a different operation and
// both must end up executing the same plan.
#[test]
fn should_support_multiple_triggers() {
let mut store = ExecutionPlanStore::default();
let mut policy_1 = Policy::new();
let mut policy_2 = Policy::new();
let mut stream_1 = TestStream::new(2);
let mut stream_2 = TestStream::new(2);
// Create different end operation for each stream.
let trigger_id_1 = 5;
let trigger_id_2 = 6;
stream_1.new_ops(trigger_id_1);
stream_2.new_ops(trigger_id_2);
let id = store.add(ExecutionPlan {
operations: stream_1.operations[0..2].to_vec(),
triggers: vec![
ExecutionTrigger::OnOperations(vec![stream_1.operations[2].clone()]),
ExecutionTrigger::OnOperations(vec![stream_2.operations[2].clone()]),
],
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
});
stream_1.assert_updates(
&store,
&mut policy_1,
AssertUpdatesOptions::OperationsIndex(0..2),
Action::Defer,
);
stream_2.assert_updates(
&store,
&mut policy_2,
AssertUpdatesOptions::OperationsIndex(0..2),
Action::Defer,
);
stream_1.assert_updates(
&store,
&mut policy_1,
AssertUpdatesOptions::OperationsIndex(2..3), // First trigger.
Action::Execute(id),
);
stream_2.assert_updates(
&store,
&mut policy_2,
AssertUpdatesOptions::OperationsIndex(2..3), // Second trigger.
Action::Execute(id),
);
}
// Two streams share their first two operations and diverge afterwards; each policy
// must pick the plan built from its own stream.
#[test]
fn should_select_right_optimization() {
let mut store = ExecutionPlanStore::default();
let mut policy_1 = Policy::new();
let mut policy_2 = Policy::new();
let mut stream_1 = TestStream::new(2);
let mut stream_2 = TestStream::new(2);
// Create different streams after op 2.
stream_1.new_ops(4);
stream_1.new_ops(5);
stream_2.new_ops(5);
stream_2.new_ops(6);
let optimization_stream_1 = store.add(ExecutionPlan {
operations: stream_1.operations[0..3].to_vec(),
triggers: stream_1.operations[3..4]
.iter()
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
.collect(),
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
});
let optimization_stream_2 = store.add(ExecutionPlan {
operations: stream_2.operations[0..3].to_vec(),
triggers: stream_2.operations[3..4]
.iter()
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
.collect(),
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
});
assert_ne!(optimization_stream_1, optimization_stream_2);
stream_1.assert_updates(
&store,
&mut policy_1,
AssertUpdatesOptions::OperationsIndex(0..3),
Action::Defer,
);
stream_2.assert_updates(
&store,
&mut policy_2,
AssertUpdatesOptions::OperationsIndex(0..3),
Action::Defer,
);
stream_1.assert_updates(
&store,
&mut policy_1,
AssertUpdatesOptions::OperationsIndex(3..4),
Action::Execute(optimization_stream_1),
);
stream_2.assert_updates(
&store,
&mut policy_2,
AssertUpdatesOptions::OperationsIndex(3..4),
Action::Execute(optimization_stream_2),
);
}
// The stored plan matches the first two operations of the stream only; once the stream
// diverges, the policy must go back to exploring.
#[test]
fn should_invalidate_wrong_optimizations() {
let mut store = ExecutionPlanStore::default();
let stream_1 = TestStream::new(4);
let mut stream_2 = TestStream::new(2);
stream_2.new_ops(6);
stream_2.new_ops(7);
store.add(ExecutionPlan {
operations: stream_1.operations[0..3].to_vec(),
triggers: stream_1.operations[3..4]
.iter()
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
.collect(),
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
});
let mut policy = Policy::new();
// Same path as stream 1
stream_2.assert_updates(
&store,
&mut policy,
AssertUpdatesOptions::OperationsIndex(0..2),
Action::Defer,
);
// But is different.
stream_2.assert_updates(
&store,
&mut policy,
AssertUpdatesOptions::OperationsIndex(2..4),
Action::Explore,
);
}
/// A chain of unary float operations, along with the tensors they read and write.
#[derive(Default, Debug)]
struct TestStream {
tensors: Vec<TensorIr>,
operations: Vec<OperationIr>,
}
/// Selects which operations of a [`TestStream`] are fed to the policy.
#[derive(Debug)]
enum AssertUpdatesOptions {
/// The operations at the indices of the given range.
OperationsIndex(Range<usize>),
}
impl TestStream {
/// Create a new test stream with `num_ops` operations registered.
pub fn new(num_ops: usize) -> Self {
let mut stream = Self::default();
for id in 0..num_ops {
stream.new_ops(id as u64 + 1);
}
stream
}
/// Feed the selected operations to the policy one at a time.
///
/// After each update, the lazy-mode action computed on the operations preceding the
/// one just added must be equal to `action`.
pub fn assert_updates(
&self,
optimizations: &ExecutionPlanStore<()>,
policy: &mut Policy<()>,
options: AssertUpdatesOptions,
action: Action,
) {
match options {
AssertUpdatesOptions::OperationsIndex(range) => {
for i in range {
// The slice stops right before the operation being added.
let stream = &self.operations[0..i];
let next_ops = &self.operations[i];
policy.update(optimizations, next_ops);
let result = policy.action(optimizations, stream, ExecutionMode::Lazy);
assert_eq!(result, action);
}
}
}
}
/// Add a simple operation to the stream.
///
/// The operation is a float `Log` reading the last tensor of the stream and writing a
/// new tensor identified by `out_id`.
pub fn new_ops(&mut self, out_id: u64) {
if self.tensors.is_empty() {
// Root node.
self.new_empty_node(0);
}
// Out node.
self.new_empty_node(out_id);
self.operations.push(OperationIr::Float(
DType::F32,
FloatOperationIr::Log(self.unary_description()),
));
}
/// Push a new uninitialized `F32` tensor of shape `[32, 32, 1]` with the given id.
fn new_empty_node(&mut self, id: u64) {
self.tensors.push(TensorIr {
id: TensorId::new(id),
shape: Shape::new([32, 32, 1]),
status: TensorStatus::NotInit,
dtype: DType::F32,
});
}
/// Describe a unary operation from the second to last tensor to the last one.
fn unary_description(&self) -> UnaryOpIr {
let size = self.tensors.len();
UnaryOpIr {
input: self.tensors[size - 2].clone(),
out: self.tensors[size - 1].clone(),
}
}
}
}

View File

@@ -0,0 +1,184 @@
use burn_ir::OperationIr;
use super::{ExecutionMode, ExplorationAction, Explorer};
use crate::search::BlockOptimization;
use crate::stream::execution::{Action, Policy};
use crate::stream::store::{ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
use crate::{NumOperations, OperationFuser};
/// Process a [stream segment](StreamSegment) following a [policy](Policy).
pub(crate) struct Processor<O> {
/// Decides whether to explore, defer or execute.
policy: Policy<O>,
/// Searches for new optimizations when the policy asks for an exploration.
explorer: Explorer<O>,
}
/// A part of a stream that can be executed partially using [execution plan](ExecutionPlan).
pub(crate) trait StreamSegment<O> {
/// The operations in the segment.
fn operations(&self) -> &[OperationIr];
/// Execute part of the segment using the given plan id.
///
/// `store` is the store in which the plan identified by `id` is looked up.
/// NOTE(review): implementors presumably remove the executed operations from the
/// segment, since the processor loops until `operations` is empty — confirm.
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<O>);
}
impl<O: NumOperations> Processor<O> {
    /// Build a processor exploring with the provided fusers.
    pub fn new(optimizations: Vec<Box<dyn OperationFuser<O>>>) -> Self {
        let policy = Policy::new();
        let explorer = Explorer::new(optimizations);

        Self { policy, explorer }
    }

    /// Process the [stream segment](StreamSegment) with the provided [mode](ExecutionMode).
    ///
    /// The policy is queried repeatedly until the segment is empty, the policy defers
    /// (lazy mode only) or the explorer has nothing left to explore.
    ///
    /// # Panics
    ///
    /// Panics when the policy defers, or the explorer wants to continue, in sync mode.
    pub fn process<Segment>(
        &mut self,
        mut segment: Segment,
        store: &mut ExecutionPlanStore<O>,
        mode: ExecutionMode,
    ) where
        Segment: StreamSegment<O>,
    {
        // A call in lazy mode is assumed to always come with a newly registered operation.
        if matches!(mode, ExecutionMode::Lazy) {
            self.on_new_operation(&segment, store);
        }

        while !segment.operations().is_empty() {
            match self.policy.action(store, segment.operations(), mode) {
                Action::Explore => {
                    self.explore(&mut segment, store, mode);

                    if self.explorer.is_up_to_date() {
                        break;
                    }
                }
                Action::Defer => match mode {
                    ExecutionMode::Lazy => break,
                    ExecutionMode::Sync => panic!("Can't defer while sync"),
                },
                Action::Execute(id) => {
                    if matches!(mode, ExecutionMode::Sync) {
                        store.add_trigger(id, ExecutionTrigger::OnSync);
                    }

                    segment.execute(id, store);
                    self.reset(store, segment.operations());
                }
            }
        }
    }

    /// Notify the policy and the explorer about the last operation of the segment.
    fn on_new_operation<Segment>(&mut self, segment: &Segment, store: &mut ExecutionPlanStore<O>)
    where
        Segment: StreamSegment<O>,
    {
        let newest = segment
            .operations()
            .last()
            .expect("At least one operation in the operation list.");

        self.policy.update(store, newest);
        self.explorer.on_new_operation();
    }

    /// Run the explorer; when it completes, cache the resulting plan, execute it and
    /// reset the state with the operations left in the segment.
    fn explore<Item: StreamSegment<O>>(
        &mut self,
        item: &mut Item,
        store: &mut ExecutionPlanStore<O>,
        mode: ExecutionMode,
    ) {
        let outcome = self.explorer.explore(item.operations(), mode);

        match outcome {
            ExplorationAction::Completed(optim) => {
                let id = Self::on_exploration_completed(
                    &self.policy,
                    item.operations(),
                    store,
                    optim,
                    mode,
                );
                item.execute(id, store);
                self.reset(store, item.operations());
            }
            ExplorationAction::Continue => {
                if matches!(mode, ExecutionMode::Sync) {
                    panic!("Can't continue exploring when sync.")
                }
            }
        }
    }

    /// Reset the explorer and the policy, then replay the remaining operations into the
    /// policy so that its state matches them.
    fn reset(&mut self, store: &mut ExecutionPlanStore<O>, operations: &[OperationIr]) {
        self.explorer.reset(operations);
        self.policy.reset();

        for operation in operations {
            self.policy.update(store, operation);
        }
    }

    /// We found an optimization (i.e. a new execution plan).
    /// Cache it in the store.
    ///
    /// When the policy already knows a plan for the optimized operations, only a trigger
    /// is added to that plan; otherwise a new plan is stored. The id of the plan is
    /// returned in both cases.
    fn on_exploration_completed(
        policy: &Policy<O>,
        operations: &[OperationIr],
        store: &mut ExecutionPlanStore<O>,
        optimization: BlockOptimization<O>,
        mode: ExecutionMode,
    ) -> ExecutionPlanId {
        let num_optimized = optimization.ordering.len();
        let relative = &operations[0..num_optimized];

        let trigger = match mode {
            ExecutionMode::Lazy => {
                let next_ops = &operations[num_optimized..operations.len()];

                if next_ops.is_empty() {
                    // Happens if the next ops is included in the fused operation, and there
                    // is no way the builder can still continue fusing.
                    ExecutionTrigger::Always
                } else {
                    ExecutionTrigger::OnOperations(next_ops.to_vec())
                }
            }
            ExecutionMode::Sync => ExecutionTrigger::OnSync,
        };

        // The lookup is always done in sync mode, whatever the current mode is.
        match policy.action(store, relative, ExecutionMode::Sync) {
            Action::Execute(id) => {
                store.add_trigger(id, trigger);
                id
            }
            Action::Explore | Action::Defer => store.add(ExecutionPlan {
                operations: relative.to_vec(),
                triggers: vec![trigger],
                optimization,
            }),
        }
    }
}

View File

@@ -0,0 +1,671 @@
//! A testing module that ensures the correctness of the explorer, policy, and processor.
//!
//! The primary focus is on validating the seamless interaction between these three components to
//! execute and optimize a stream of operations accurately.
//!
//! To test these components effectively, we create mock types for the stream, optimization,
//! optimization builder, and stream segment. These mock types aid in comprehensively
//! understanding the process of optimizing streams.
use std::sync::Arc;
use burn_backend::{DType, Shape};
use burn_ir::{
BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarIr, ScalarOpIr, TensorId,
TensorIr, TensorStatus, UnaryOpIr,
};
use crate::{
FuserProperties, FuserStatus, NumOperations, OperationFuser,
search::BlockOptimization,
stream::store::{
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
},
};
use super::*;
/// A fake stream of operations for testing purpose.
pub struct TestStream {
/// The processor under test.
processor: Processor<TestOptimization>,
/// The store of execution plans used together with the processor.
store: ExecutionPlanStore<TestOptimization>,
/// Ids of the executed plans — presumably in execution order; TODO confirm.
executed: Vec<ExecutionPlanId>,
/// Operations held by the stream — presumably the ones not executed yet; TODO confirm.
operations: Vec<OperationIr>,
}
/// A fake [optimization builder](OptimizationBuilder) for testing purpose.
///
/// The optimizer tries to fuse only the `expected_operations` if they appear
/// in the operations queue.
#[derive(Clone)]
pub struct TestOptimizationBuilder {
/// Identifier of the builder, used by the tests to tell builders apart.
builder_id: usize,
/// The operations this builder tries to fuse.
expected_operations: Vec<OperationIr>,
/// Presumably the operations registered so far — TODO confirm.
actual: Vec<OperationIr>,
}
/// A fake optimization for testing purpose.
#[derive(new, Debug, PartialEq)]
pub struct TestOptimization {
/// Identifier of the builder the optimization is attributed to.
builder_id: usize,
/// Number of operations of the optimization, as reported by `NumOperations::len`.
size: usize,
}
impl NumOperations for TestOptimization {
/// The number of operations is the `size` the optimization was created with.
fn len(&self) -> usize {
self.size
}
}
/// A fake [stream segment](StreamSegment) for testing purpose.
#[derive(new)]
pub struct TestSegment<'i> {
/// Mutable borrow of a list of operations.
operations: &'i mut Vec<OperationIr>,
/// Mutable borrow of a list of execution plan ids.
executed: &'i mut Vec<ExecutionPlanId>,
}
impl<O> ExecutionStrategy<O> {
    /// Build a strategy executing `size` plain operations in their natural order,
    /// i.e. with the ordering `0..size`.
    pub fn operations(size: usize) -> Self {
        let positions = (0..size).collect();

        Self::Operations {
            ordering: Arc::new(positions),
        }
    }
}
impl ExecutionStrategy<TestOptimization> {
    /// Testing helper wrapping `opt` into an optimization strategy whose ordering is
    /// `0..opt.size`.
    pub fn optimization(opt: TestOptimization) -> Self {
        let positions = (0..opt.size).collect();

        Self::Optimization {
            ordering: Arc::new(positions),
            opt,
        }
    }
}
/// This is a substantial test case that examines a lengthy scenario with a diverse set of conditions.
///
/// While it's usually preferable to split tests into multiple independent scenarios, in this case, it is
/// crucial to verify that the stream's state is correctly updated when various cases occur consecutively.
#[test]
fn should_support_complex_stream() {
// We have 2 different optimization builders in this test case.
let builder_id_1 = 0;
let builder_id_2 = 1;
// We will have a total of 3 execution plans to execute.
let plan_id_1 = 0;
let plan_id_2 = 1;
let plan_id_3 = 2;
let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);
let builder_2 = TestOptimizationBuilder::new(builder_id_2, vec![operation_2(), operation_2()]);
let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);
// builder_1 is still waiting to see if the next op is operation_2.
// builder_2 is closed because it's not the right operation.
stream.add(operation_1());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(0);
// No optimization found for the first two operations.
stream.add(operation_1());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(1);
stream.assert_last_executed(plan_id_1);
stream.assert_plan(
plan_id_1,
ExecutionPlan {
operations: vec![operation_1(), operation_1()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
},
);
// Nothing to execute.
stream.add(operation_1());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(1);
// Now we should trigger the first optimization builder.
stream.add(operation_2());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(2);
stream.assert_last_executed(plan_id_2);
stream.assert_plan(
plan_id_2,
ExecutionPlan {
operations: vec![operation_1(), operation_2()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization::new(
ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
vec![0, 1],
),
},
);
// Nothing to execute.
stream.add(operation_2());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(2);
// Now we should trigger the second optimization builder.
stream.add(operation_2());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(3);
stream.assert_last_executed(plan_id_3);
stream.assert_plan(
plan_id_3,
ExecutionPlan {
operations: vec![operation_2(), operation_2()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization {
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_2, 2)),
ordering: vec![0, 1],
},
},
);
// Nothing to execute.
stream.add(operation_1());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(3);
// Now we should trigger the first optimization builder (second plan).
stream.add(operation_2());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(4);
stream.assert_last_executed(plan_id_2);
stream.assert_plan(
plan_id_2,
ExecutionPlan {
operations: vec![operation_1(), operation_2()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization {
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
ordering: vec![0, 1],
},
},
);
// Nothing to execute.
stream.add(operation_2());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(4);
// Now we should trigger the second optimization builder (third plan).
stream.add(operation_2());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(5);
stream.assert_last_executed(plan_id_3);
}
/// In this scenario we will never use an optimization, but we check that we reuse the execution plan stored.
#[test]
fn should_reuse_basic_operations() {
let builder_id_1 = 0;
let plan_id_1 = 0;
let plan_id_2 = 1;
let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);
let mut stream = TestStream::new(vec![Box::new(builder_1)]);
// operation_3 is not part of what the builder expects: it is executed on its own.
stream.add(operation_3());
stream.assert_last_executed(plan_id_1);
stream.assert_number_of_operations(0);
stream.assert_plan(
plan_id_1,
ExecutionPlan {
operations: vec![operation_3()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization {
strategy: ExecutionStrategy::operations(1),
ordering: vec![0],
},
},
);
// The same operation again: the first plan is the last one executed and is unchanged.
stream.add(operation_3());
stream.assert_last_executed(plan_id_1);
stream.assert_number_of_operations(0);
stream.assert_plan(
plan_id_1,
ExecutionPlan {
operations: vec![operation_3()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization {
strategy: ExecutionStrategy::operations(1),
ordering: vec![0],
},
},
);
// Lazily tries to build optimization 1.
stream.add(operation_1());
// But it is not possible.
stream.add(operation_3());
// Creates a new plan with both operations.
stream.assert_plan(
plan_id_2,
ExecutionPlan {
operations: vec![operation_1(), operation_3()],
triggers: vec![ExecutionTrigger::Always],
optimization: BlockOptimization {
strategy: ExecutionStrategy::operations(2),
ordering: vec![0],
},
},
);
stream.assert_number_of_operations(0);
stream.assert_last_executed(plan_id_2);
}
/// In this scenario we validate that we support multiple optimization builders with overlapping
/// operations.
///
/// This is a very long scenario that validates a lot of things.
#[test]
fn should_support_overlapping_optimizations() {
    // We have 2 different optimization builders in this test case, each with its own id.
    let builder_id_1 = 0;
    let builder_id_2 = 1;

    // We will have a total of 5 execution plans to execute.
    let plan_id_1 = 0;
    let plan_id_2 = 1;
    let plan_id_3 = 2;
    let plan_id_4 = 3;
    let plan_id_5 = 4;

    // The pattern of the first builder is a strict prefix of the second one.
    let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);
    let builder_2 = TestOptimizationBuilder::new(
        builder_id_2,
        vec![operation_1(), operation_2(), operation_1(), operation_1()],
    );
    let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);

    // Nothing is executed while the longer pattern can still match.
    stream.add(operation_1());
    stream.assert_number_of_operations(1);
    stream.assert_number_of_executions(0);

    stream.add(operation_2());
    stream.assert_number_of_operations(2);
    stream.assert_number_of_executions(0);

    stream.add(operation_1());
    stream.assert_number_of_operations(3);
    stream.assert_number_of_executions(0);

    // The fourth operation breaks the second pattern: the first two operations are executed
    // with the first builder and the last two stay queued.
    stream.add(operation_2());
    stream.assert_number_of_operations(2);
    stream.assert_number_of_executions(1);
    stream.assert_last_executed(plan_id_1);
    stream.assert_plan(
        plan_id_1,
        ExecutionPlan {
            operations: vec![operation_1(), operation_2()],
            triggers: vec![ExecutionTrigger::OnOperations(vec![
                operation_1(),
                operation_2(),
            ])],
            optimization: BlockOptimization {
                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
                ordering: vec![0, 1],
            },
        },
    );

    // A lone `operation_2` flushes the queue: the first plan is reused (with a new trigger)
    // and a second plan is created for the trailing operation.
    stream.add(operation_2());
    stream.assert_number_of_operations(0);
    stream.assert_number_of_executions(3);
    stream.assert_plan(
        plan_id_1,
        ExecutionPlan {
            operations: vec![operation_1(), operation_2()],
            triggers: vec![
                ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]),
                ExecutionTrigger::OnOperations(vec![operation_2()]),
            ],
            optimization: BlockOptimization {
                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
                ordering: vec![0, 1],
            },
        },
    );
    stream.assert_plan(
        plan_id_2,
        ExecutionPlan {
            operations: vec![operation_2()],
            triggers: vec![ExecutionTrigger::Always],
            optimization: BlockOptimization {
                strategy: ExecutionStrategy::operations(1),
                ordering: vec![0],
            },
        },
    );

    // This time the stream follows the full pattern of the second builder.
    stream.add(operation_1());
    stream.assert_number_of_operations(1);
    stream.assert_number_of_executions(3);

    stream.add(operation_2());
    stream.assert_number_of_operations(2);
    stream.assert_number_of_executions(3);

    stream.add(operation_1());
    stream.assert_number_of_operations(3);
    stream.assert_number_of_executions(3);

    stream.add(operation_1());
    stream.assert_number_of_operations(0);
    stream.assert_number_of_executions(4);
    stream.assert_plan(
        plan_id_3,
        ExecutionPlan {
            operations: vec![operation_1(), operation_2(), operation_1(), operation_1()],
            triggers: vec![ExecutionTrigger::Always],
            optimization: BlockOptimization {
                // The four-operation pattern belongs to the second builder.
                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_2, 4)),
                // NOTE(review): a single index for four operations; `assert_plan` does not
                // compare the optimization, so this value is never checked — confirm.
                ordering: vec![0],
            },
        },
    );

    // Start the long pattern again, but force a sync before it can complete.
    stream.add(operation_1());
    stream.assert_number_of_operations(1);
    stream.assert_number_of_executions(4);

    stream.add(operation_2());
    stream.assert_number_of_operations(2);
    stream.assert_number_of_executions(4);

    stream.add(operation_1());
    stream.assert_number_of_operations(3);
    stream.assert_number_of_executions(4);

    stream.sync();
    stream.assert_number_of_operations(0);
    stream.assert_number_of_executions(6);
    stream.assert_plan(
        plan_id_1,
        ExecutionPlan {
            operations: vec![operation_1(), operation_2()],
            triggers: vec![
                ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]),
                ExecutionTrigger::OnOperations(vec![operation_2()]),
                ExecutionTrigger::OnSync,
            ],
            optimization: BlockOptimization {
                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
                ordering: vec![0, 1],
            },
        },
    );
    stream.assert_plan(
        plan_id_4,
        ExecutionPlan {
            operations: vec![operation_1()],
            triggers: vec![ExecutionTrigger::OnSync],
            optimization: BlockOptimization {
                strategy: ExecutionStrategy::operations(1),
                ordering: vec![0],
            },
        },
    );

    // An operation that no builder knows about gets its own plan, which is then reused.
    stream.add(operation_3());
    stream.assert_last_executed(plan_id_5);
    stream.assert_plan(
        plan_id_5,
        ExecutionPlan {
            operations: vec![operation_3()],
            triggers: vec![ExecutionTrigger::Always],
            optimization: BlockOptimization {
                strategy: ExecutionStrategy::operations(1),
                ordering: vec![0],
            },
        },
    );

    stream.add(operation_3());
    stream.assert_last_executed(plan_id_5);
}
impl TestStream {
/// Create a new stream with the given optimization builders.
fn new(optimizations: Vec<Box<dyn OperationFuser<TestOptimization>>>) -> Self {
Self {
processor: Processor::<TestOptimization>::new(optimizations),
store: ExecutionPlanStore::<TestOptimization>::new(),
executed: Vec::new(),
operations: Vec::new(),
}
}
/// Add an operation to the stream.
fn add(&mut self, operation: OperationIr) {
self.operations.push(operation);
self.processor.process(
TestSegment::new(&mut self.operations, &mut self.executed),
&mut self.store,
ExecutionMode::Lazy,
);
}
/// Sync the stream.
fn sync(&mut self) {
self.processor.process(
TestSegment::new(&mut self.operations, &mut self.executed),
&mut self.store,
ExecutionMode::Sync,
);
}
/// Assert that the plan has been executed as provided.
fn assert_plan(&self, id: ExecutionPlanId, expected: ExecutionPlan<TestOptimization>) {
let actual = self.store.get_unchecked(id);
assert_eq!(actual.operations, expected.operations, "Same operations");
assert_eq!(actual.triggers, expected.triggers, "Same triggers");
}
/// Assert that the given plan id has been the last executed.
fn assert_last_executed(&self, id: ExecutionPlanId) {
match self.executed.last() {
Some(last_id) => assert_eq!(*last_id, id),
None => panic!("No plan has been executed"),
}
}
/// Assert the number of executions since the start of the stream.
fn assert_number_of_executions(&self, number: usize) {
assert_eq!(self.executed.len(), number);
}
/// Assert the number of operations queued.
fn assert_number_of_operations(&self, number: usize) {
assert_eq!(self.operations.len(), number);
}
}
impl TestOptimizationBuilder {
    /// Create a builder that recognizes exactly the given sequence of operations.
    ///
    /// No operation is registered yet, so the builder starts with an empty `actual` list.
    pub fn new(builder_id: usize, operations: Vec<OperationIr>) -> Self {
        let expected_operations = operations;
        let actual = Vec::new();

        Self {
            builder_id,
            expected_operations,
            actual,
        }
    }
}
impl OperationFuser<TestOptimization> for TestOptimizationBuilder {
    /// Record one more operation seen on the stream.
    fn fuse(&mut self, operation: &OperationIr) {
        let seen = operation.clone();
        self.actual.push(seen);
    }

    /// Produce the optimization, sized after the expected pattern.
    fn finish(&mut self) -> TestOptimization {
        let size = self.len();
        TestOptimization::new(self.builder_id, size)
    }

    /// Forget every operation recorded so far.
    fn reset(&mut self) {
        self.actual.clear();
    }

    /// The builder stays open only while the recorded operations are a strict prefix of the
    /// expected pattern; in every other case it is closed.
    fn status(&self) -> FuserStatus {
        let is_shorter = self.actual.len() < self.expected_operations.len();

        if is_shorter && self.expected_operations.starts_with(&self.actual) {
            // Still optimizing.
            FuserStatus::Open
        } else {
            // Either the pattern is complete or it will never match on that stream.
            FuserStatus::Closed
        }
    }

    /// The optimization is ready once the recorded operations start with the whole expected
    /// pattern.
    fn properties(&self) -> FuserProperties {
        // `starts_with` is false when fewer operations were recorded than expected.
        let pattern_matched = self.actual.starts_with(&self.expected_operations);

        if pattern_matched {
            // Optimization possible.
            FuserProperties {
                score: 1,
                ready: true,
            }
        } else {
            // Optimization not possible.
            FuserProperties {
                score: 0,
                ready: false,
            }
        }
    }

    // The number of operations that should be handled by the optimization.
    fn len(&self) -> usize {
        self.expected_operations.len()
    }

    fn clone_dyn(&self) -> Box<dyn OperationFuser<TestOptimization>> {
        let copy = self.clone();
        Box::new(copy)
    }
}
impl StreamSegment<TestOptimization> for TestSegment<'_> {
    /// The operations currently queued in the test stream.
    fn operations(&self) -> &[OperationIr] {
        self.operations.as_slice()
    }

    /// Apply the strategy of the stored plan `id`, then record the plan as executed.
    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<TestOptimization>) {
        let strategy = &store.get_unchecked(id).optimization.strategy;

        self.execute_strategy(strategy);
        self.executed.push(id);
    }
}
impl TestSegment<'_> {
    /// Simulate the execution of `strategy` by removing the operations it covers from the
    /// front of the queue.
    fn execute_strategy(&mut self, strategy: &ExecutionStrategy<TestOptimization>) {
        let consumed = match strategy {
            ExecutionStrategy::Optimization { opt, .. } => opt.size,
            ExecutionStrategy::Operations { ordering } => ordering.len(),
            ExecutionStrategy::Composed(strategies) => {
                // Each nested strategy removes its own share of the queue, in order.
                for inner in strategies {
                    self.execute_strategy(inner);
                }
                return;
            }
        };

        self.operations.drain(..consumed);
    }
}
/// Test fixture: an `F32` `Add` of tensors 0 and 1 into tensor 2, all of shape `[32, 32]`.
pub fn operation_1() -> OperationIr {
    let lhs = TensorIr {
        id: TensorId::new(0),
        shape: Shape::new([32, 32]),
        status: TensorStatus::ReadOnly,
        dtype: DType::F32,
    };
    let rhs = TensorIr {
        id: TensorId::new(1),
        shape: Shape::new([32, 32]),
        status: TensorStatus::ReadOnly,
        dtype: DType::F32,
    };
    let out = TensorIr {
        id: TensorId::new(2),
        shape: Shape::new([32, 32]),
        status: TensorStatus::NotInit,
        dtype: DType::F32,
    };

    let add = NumericOperationIr::Add(BinaryOpIr { lhs, rhs, out });
    OperationIr::NumericFloat(DType::F32, add)
}
/// Test fixture: an `F32` `AddScalar` adding `5.0` to tensor 0, written into tensor 2.
pub fn operation_2() -> OperationIr {
    let lhs = TensorIr {
        id: TensorId::new(0),
        shape: Shape::new([32, 32]),
        status: TensorStatus::ReadOnly,
        dtype: DType::F32,
    };
    let rhs = ScalarIr::Float(5.0);
    let out = TensorIr {
        id: TensorId::new(2),
        shape: Shape::new([32, 32]),
        status: TensorStatus::NotInit,
        dtype: DType::F32,
    };

    let add_scalar = NumericOperationIr::AddScalar(ScalarOpIr { lhs, rhs, out });
    OperationIr::NumericFloat(DType::F32, add_scalar)
}
/// Test fixture: an `F32` `Log` whose input and output descriptions both use tensor id 0.
pub fn operation_3() -> OperationIr {
    let input = TensorIr {
        id: TensorId::new(0),
        shape: Shape::new([32, 32]),
        status: TensorStatus::ReadOnly,
        dtype: DType::F32,
    };
    let out = TensorIr {
        id: TensorId::new(0),
        shape: Shape::new([32, 32]),
        status: TensorStatus::NotInit,
        dtype: DType::F32,
    };

    let log = FloatOperationIr::Log(UnaryOpIr { input, out });
    OperationIr::Float(DType::F32, log)
}

View File

@@ -0,0 +1,136 @@
use burn_ir::OperationIr;
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
/// Compare each operation in the list of operations provided by the [store](OperationsStore)
/// to verify if the newly added operations match the original list.
///
/// It is used by the [policy](crate::stream::execution::Policy) to check each candidate as well
/// as to verify if a list of operations is optimal to execute based on their triggers.
#[derive(Debug)]
pub(crate) struct OperationsValidator<ID> {
/// The ID used to retrieve the operation list.
pub(crate) id: ID,
/// The current [state](ValidatorState).
pub(crate) state: ValidatorState,
}
/// The state of the validator.
///
/// A validator starts in [`Validating`](ValidatorState::Validating); once it reaches
/// `Found` or `Invalidated`, further updates leave it unchanged.
#[derive(Debug)]
pub(crate) enum ValidatorState {
/// A matching operation list has been found; `size` is the length of that list.
Found { size: usize },
/// No matching operation list has been found.
Invalidated,
/// Potentially going to find a matching operation list when more operations are added.
Validating,
}
/// Provides a list of operations based on an Id.
pub(crate) trait OperationsStore {
/// The type used for the identifier.
type Id: Copy;
/// Retrieve the list of operations corresponding to the provided id.
fn get(&self, id: Self::Id) -> &[OperationIr];
}
impl<ID> OperationsValidator<ID> {
    /// Create a validator for the operation list identified by `id`, in the validating state.
    pub(crate) fn new(id: ID) -> Self {
        Self {
            state: ValidatorState::Validating,
            id,
        }
    }

    /// Compare the operation `added` at `added_position` with the operation stored at the same
    /// position, and update the state accordingly.
    ///
    /// Nothing happens when the validator has already reached a final state.
    pub(crate) fn update<S>(&mut self, added: &OperationIr, added_position: usize, store: &S)
    where
        S: OperationsStore<Id = ID>,
        ID: PartialEq + Copy,
    {
        if !matches!(self.state, ValidatorState::Validating) {
            return;
        }

        let expected = store.get(self.id);

        match expected.get(added_position) {
            Some(candidate) if candidate == added => {
                // The last expected operation just matched: the whole list is found.
                if expected.len() == added_position + 1 {
                    self.state = ValidatorState::Found {
                        size: expected.len(),
                    };
                }
            }
            // Either the stored list is too short or the operation differs.
            _ => self.state = ValidatorState::Invalidated,
        }
    }
}
/// [Operations store](OperationsStore) used to retrieve the list of operations for a trigger.
#[derive(new)]
pub(crate) struct TriggerOperationsStore<'a, O> {
/// The execution plan whose triggers are looked up.
id: ExecutionPlanId,
/// The store holding the execution plans.
store: &'a ExecutionPlanStore<O>,
}
/// Validates when operations match a trigger.
// NOTE(review): the variants presumably mirror those of `ExecutionTrigger` — confirm.
#[derive(Debug)]
pub(crate) enum TriggerValidator {
/// Validation of a trigger that carries a list of operations.
OnOperations {
/// Validator comparing incoming operations with the trigger's operation list.
matching: OperationsValidator<TriggerId>,
/// How far the validation has progressed.
progress: TriggerProgress,
},
/// No operation list to validate.
Always,
/// No operation list to validate.
OnSync,
}
/// The progress made in the trigger validation process.
#[derive(Debug)]
pub(crate) enum TriggerProgress {
/// The validation hasn't started yet.
NotInit,
/// The number of operations that have been checked so far.
NumChecked(usize),
}
/// An execution plan can have many triggers, so a trigger is identified by its position in
/// the plan's list of triggers.
pub(crate) type TriggerId = usize;
impl<O: core::fmt::Debug> OperationsStore for TriggerOperationsStore<'_, O> {
    type Id = TriggerId;

    /// Return the operations of the trigger at position `id` in the plan's trigger list.
    ///
    /// Triggers that carry no operations yield an empty slice.
    fn get(&self, id: Self::Id) -> &[OperationIr] {
        let plan = self.store.get_unchecked(self.id);

        match &plan.triggers[id] {
            ExecutionTrigger::OnOperations(operations) => operations,
            ExecutionTrigger::OnSync | ExecutionTrigger::Always => &[],
        }
    }
}
/// [Operations store](OperationsStore) used to retrieve the list of operations for an
/// [execution plan](crate::stream::store::ExecutionPlan).
#[derive(new)]
pub(crate) struct ExecutionPlanOperationsStore<'a, O> {
/// The store holding the execution plans.
store: &'a ExecutionPlanStore<O>,
}
impl<O: core::fmt::Debug> OperationsStore for ExecutionPlanOperationsStore<'_, O> {
    type Id = ExecutionPlanId;

    /// Return the operations of the execution plan identified by `id`.
    fn get(&self, id: Self::Id) -> &[OperationIr] {
        let plan = self.store.get_unchecked(id);
        &plan.operations
    }
}

View File

@@ -0,0 +1,249 @@
use hashbrown::HashMap;
use std::{
fmt::Display,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
mpsc::SyncSender,
},
thread::JoinHandle,
time::Duration,
};
use burn_ir::{HandleContainer, TensorId, TensorStatus};
use burn_std::id::StreamId;
use crate::FusionRuntime;
use super::Stream;
/// Memory checks struct to validate there is no memory leak with the fusion runtime.
///
/// Cloning is cheap: clones share the same channel, counter and background thread.
#[derive(Clone)]
pub(crate) struct MemoryChecks {
/// Channel used to send messages to the background thread.
sender: SyncSender<Message>,
/// Incremented before a snapshot is sent and decremented by the background thread when the
/// snapshot is received.
num_queued: Arc<AtomicU64>,
// Keeps track of its thread.
_handle: Arc<JoinHandle<()>>,
}
/// Messages handled by the background memory-check thread.
enum Message {
/// A new snapshot of the streams; it replaces the previously registered one.
Register(StreamAnalyses),
/// A request for a report, answered through the provided channel.
Check(SyncSender<MemoryReport>),
}
/// Outcome of a memory check request.
enum MemoryReport {
/// The last snapshot had no stream and no handle left.
Success,
/// Snapshots are still being registered; the caller should ask again later.
NotReady,
/// No memory check instance has been created yet.
NotStarted,
/// The check failed; the string holds the message to report.
Fail(String),
}
/// Snapshot of all the streams and of the number of handles at one point in time.
#[derive(Default)]
struct StreamAnalyses {
/// One analysis per stream still alive.
streams: HashMap<StreamId, Analysis>,
/// Number of handles reported by the handle container.
num_handles: usize,
}
impl Display for StreamAnalyses {
    /// Render the snapshot as a multi-line report: totals first, then one entry per stream
    /// followed by the variables that stream still tracks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "\n==== Fusion Memory Report ====\n")?;
        write!(f, " - Handles: {}\n", self.num_handles)?;
        write!(f, " - Streams: {}\n", self.streams.len())?;

        for (stream_id, analysis) in &self.streams {
            write!(
                f,
                " - {} => operations: {} cursor: {}\n",
                stream_id, analysis.num_operations, analysis.cursor
            )?;

            for (tid, (origin, status)) in &analysis.variables {
                write!(f, " - {tid} => origin: {origin} status: {status:?}\n")?;
            }
        }

        write!(f, "==============================\n")
    }
}
/// Snapshot of a single stream.
#[derive(Default, Debug)]
struct Analysis {
/// Copy of the variables tracked by the stream's queue.
variables: HashMap<TensorId, (StreamId, TensorStatus)>,
/// Number of operations still queued on the stream.
num_operations: usize,
/// The stream's cursor when the snapshot was taken.
cursor: u64,
}
#[macro_export]
/// Export memory checks tests.
///
/// Expands to a `memory_checks` test module whose single test calls
/// `burn_fusion::stream::memory_checks::check_memory_leaks`.
macro_rules! memory_checks {
() => {
#[cfg(test)]
mod memory_checks {
#[test]
fn test_memory_leaks() {
burn_fusion::stream::memory_checks::check_memory_leaks();
}
}
};
}
/// Process-wide memory check instance, created the first time `MemoryChecks::default()` runs.
static INSTANCE: spin::Mutex<Option<MemoryChecks>> = spin::Mutex::new(None);
/// Performs memory checks and panics if a leak is discovered.
///
/// The report is polled every 100 ms until it is conclusive. If the checks never start after
/// 25 consecutive polls, the function returns without failing.
pub fn check_memory_leaks() {
    let poll_interval = Duration::from_millis(100);
    let max_not_started = 25;
    let mut not_started_in_a_row = 0;

    loop {
        match fetch_memory_report() {
            MemoryReport::Success => return,
            MemoryReport::Fail(msg) => panic!("{msg}"),
            MemoryReport::NotReady => not_started_in_a_row = 0,
            MemoryReport::NotStarted => {
                if not_started_in_a_row >= max_not_started {
                    // Nothing is running on the fusion runtime.
                    return;
                }
                not_started_in_a_row += 1;
            }
        }

        std::thread::sleep(poll_interval);
    }
}
/// Ask the background memory-check thread for a report.
///
/// Returns [`MemoryReport::NotStarted`] when no memory check instance exists yet.
///
/// # Panics
///
/// Panics if the request can't be sent or if the answer channel is closed before a report
/// is received.
fn fetch_memory_report() -> MemoryReport {
    // Clone the client and release the global spin lock right away: waiting for the answer can
    // block for several seconds, and any other thread touching `INSTANCE` would busy-spin for
    // that whole time if the guard were still held.
    let client = INSTANCE.lock().clone();
    let Some(client) = client else {
        return MemoryReport::NotStarted;
    };

    let (sender, receiver) = std::sync::mpsc::sync_channel(1);

    if let Err(err) = client.sender.send(Message::Check(sender)) {
        panic!("Channel closed can't send the check call: {err:?}")
    }

    match receiver.recv() {
        Ok(report) => report,
        Err(err) => panic!("Received an error from fetching check results: {err}"),
    }
}
impl Default for MemoryChecks {
    /// Return a handle on the global instance, spawning it on first use.
    fn default() -> Self {
        let mut guard = INSTANCE.lock();
        let checks = guard.get_or_insert_with(Self::spawn_new).clone();
        core::mem::drop(guard);

        checks
    }
}
impl MemoryChecks {
    /// Take a snapshot of the given streams and handles and send it to the background thread.
    ///
    /// # Panics
    ///
    /// Panics if the snapshot can't be sent to the background thread.
    pub(crate) fn check<R: FusionRuntime>(
        &mut self,
        streams: &HashMap<StreamId, Stream<R>>,
        handles: &HandleContainer<R::FusionHandle>,
    ) {
        let mut analyses = StreamAnalyses {
            num_handles: handles.num_handles(),
            streams: Default::default(),
        };

        for (id, s) in streams.iter() {
            let analysis = Analysis {
                variables: s.queue.variables.clone(),
                num_operations: s.queue.global.len(),
                cursor: s.cursor,
            };
            analyses.streams.insert(*id, analysis);
        }

        // Count the snapshot as queued before sending it; the background thread decrements
        // the counter when it receives the snapshot.
        self.num_queued.fetch_add(1, Ordering::Relaxed);

        match self.sender.send(Message::Register(analyses)) {
            Ok(..) => {}
            Err(err) => {
                panic!("Can't register memory checks analysis: {err:?}")
            }
        }
    }

    /// Spawn the background thread that stores the latest snapshot and answers check requests.
    fn spawn_new() -> Self {
        let (sender, rec) = std::sync::mpsc::sync_channel(100);
        let num_queued = Arc::new(AtomicU64::new(0));
        let num_queued_moved = num_queued.clone();

        let handle = std::thread::spawn(move || {
            let mut last_analyses = None;

            // `recv` only fails once every sender has been dropped, and that state is
            // permanent: stop the thread instead of spinning on the error forever.
            while let Ok(payload) = rec.recv() {
                match payload {
                    Message::Register(payload) => {
                        last_analyses = Some(payload);
                        num_queued_moved.fetch_sub(1, Ordering::Relaxed);
                    }
                    Message::Check(callback) => {
                        if num_queued_moved.load(Ordering::Relaxed) > 1 {
                            callback.send(MemoryReport::NotReady).unwrap();
                            continue;
                        }

                        // We assume that if the count is still at most 1 after waiting five
                        // seconds, nothing else is going to be registered: it's the end.
                        std::thread::sleep(Duration::from_secs(5));

                        if num_queued_moved.load(Ordering::Relaxed) <= 1 {
                            match last_analyses.take() {
                                Some(val) => {
                                    callback.send(Self::final_check(val)).unwrap();
                                }
                                None => {
                                    callback
                                        .send(MemoryReport::Fail("No analyses".into()))
                                        .unwrap();
                                }
                            }
                        } else {
                            callback.send(MemoryReport::NotReady).unwrap();
                        }
                    }
                }
            }
        });

        Self {
            sender,
            num_queued,
            _handle: Arc::new(handle),
        }
    }

    /// Fail with the formatted snapshot when a stream or a handle is still alive.
    fn final_check(analyses: StreamAnalyses) -> MemoryReport {
        if !analyses.streams.is_empty() || analyses.num_handles > 0 {
            return MemoryReport::Fail(format!("{analyses}"));
        }

        MemoryReport::Success
    }
}

View File

@@ -0,0 +1,33 @@
pub(crate) mod execution;
pub(crate) mod queue;
pub(crate) mod shared_tensors;
pub(crate) mod store;
#[cfg(feature = "memory-checks")]
/// Memory checks module.
pub mod memory_checks;
#[cfg(not(feature = "memory-checks"))]
#[macro_export]
/// Export memory checks tests.
///
/// Without the `memory-checks` feature, the generated test is an ignored placeholder.
macro_rules! memory_checks {
() => {
#[cfg(test)]
mod memory_checks {
#[ignore = "'memory-checks' disabled"]
#[test]
fn test_memory_leaks() {
// Intentionally empty: there is nothing to check when the feature is disabled.
}
}
};
}
mod base;
mod context;
mod multi;
pub use base::*;
pub use context::*;
pub use execution::*;
pub use multi::*;

View File

@@ -0,0 +1,472 @@
use std::sync::Arc;
use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus};
use hashbrown::{HashMap, HashSet};
use super::{
StreamId,
execution::{ExecutionMode, Operation, Processor, StreamSegment},
queue::OperationQueue,
shared_tensors::SharedTensors,
store::{ExecutionPlanId, ExecutionPlanStore},
};
use crate::{
DropOp, FusionRuntime,
stream::shared_tensors::{SharedTensorAnalysis, SharedTensorDropAction},
};
/// Keep track of multiple concurrent lazy streams of operations.
pub struct MultiStream<R: FusionRuntime> {
/// The streams currently alive, keyed by their id.
streams: HashMap<StreamId, Stream<R>>,
/// Execution plan store passed to the processor of every stream.
optimizations: ExecutionPlanStore<R::Optimization>,
/// Bookkeeping for the tensors that are shared between streams.
shared_tensors: SharedTensors,
/// Device cloned into every newly created stream.
device: R::FusionDevice,
/// Leak checker, only compiled with the `memory-checks` feature.
#[cfg(feature = "memory-checks")]
memory_checks: super::memory_checks::MemoryChecks,
}
/// What to do with a drop operation, depending on whether the tensor is shared.
#[derive(Debug)]
enum DropAction {
/// The drop operation must not be registered.
SkipSharedTensor,
/// The shared tensor is dropped now: it is first removed from the variables of the listed
/// streams.
ForceSharedTensor(Vec<StreamId>, TensorId),
/// The tensor is already read-write: the drop is registered as is.
ContinueDrop,
}
impl<R: FusionRuntime> MultiStream<R> {
/// Create an empty multi-stream for the given device.
pub(crate) fn new(device: R::FusionDevice) -> Self {
    let streams = HashMap::new();
    let optimizations = ExecutionPlanStore::new();
    let shared_tensors = SharedTensors::default();

    Self {
        streams,
        optimizations,
        shared_tensors,
        device,
        #[cfg(feature = "memory-checks")]
        memory_checks: super::memory_checks::MemoryChecks::default(),
    }
}
/// Register a new tensor operation.
///
/// The operation is enqueued on the current stream of `streams`. Drop operations get a
/// special treatment so that tensors shared between streams are handled consistently.
pub(crate) fn register(
&mut self,
streams: OperationStreams,
mut repr: OperationIr,
operation: Arc<dyn Operation<R>>,
handles: &mut HandleContainer<R::FusionHandle>,
) {
let id = self.resolve_streams(&streams, handles, &mut repr);
// Only drop operations need a decision about shared tensors.
let drop_action = match &mut repr {
OperationIr::Drop(tensor_ir) => Some(self.handle_drop_op(id, tensor_ir)),
_ => None,
};
// `sync` is true for every drop operation that is actually registered.
let sync = match drop_action {
Some(DropAction::SkipSharedTensor) => return,
Some(DropAction::ContinueDrop) => true,
Some(DropAction::ForceSharedTensor(stream_ids, tid)) => {
// Forget the tensor in the listed streams and remove the streams that no longer
// track any variable.
for stream_id in stream_ids {
if let Some(stream) = self.streams.get_mut(&stream_id) {
stream.queue.variables.remove(&tid);
if stream.queue.variables.is_empty() {
self.streams.remove(&stream_id);
}
}
}
true
}
None => false,
};
let num_executed = self.enqueue_operation(id, repr, &streams, operation, handles);
// Some operations were executed: update the shared tensors bookkeeping and drop the
// tensors that were released by it.
if num_executed > 0
&& let Some(stream) = self.streams.get_mut(&id)
{
let cleared = self.shared_tensors.on_executed_ops(id, stream);
self.clear_shared_tensors(&cleared, id);
let to_drop = self.shared_tensors.clear_tensors(cleared);
self.drop_shared_tensors(to_drop, handles, id);
}
// The stream may have been removed by the cleanup above.
let stream = match self.streams.get(&id) {
Some(val) => val,
None => {
#[cfg(feature = "memory-checks")]
self.memory_checks.check(&self.streams, handles);
return;
}
};
if !stream.queue.variables.is_empty() && sync {
// Not draining the queue can cause a memory leak when a stream is closing.
self.drain(handles, id);
}
#[cfg(feature = "memory-checks")]
self.memory_checks.check(&self.streams, handles);
}
/// Decide how a drop operation must be handled.
///
/// When a tensor is shared across multiple concurrent streams, dropping a tensor might cause a
/// problem when the same tensor is registered lazily on another stream, but not yet executed.
/// Tensors that are already read-write are dropped normally; for the others, the shared
/// tensors bookkeeping decides whether the drop is skipped or forced.
fn handle_drop_op(&mut self, id: StreamId, tensor_ir: &mut TensorIr) -> DropAction {
    if matches!(tensor_ir.status, TensorStatus::ReadWrite) {
        return DropAction::ContinueDrop;
    }

    let stream_is_missing = !self.streams.contains_key(&id);

    match self
        .shared_tensors
        .on_drop(id, tensor_ir.id, stream_is_missing)
    {
        SharedTensorDropAction::Skip => DropAction::SkipSharedTensor,
        SharedTensorDropAction::ForceDrop(streams) => {
            tensor_ir.status = TensorStatus::ReadWrite;
            DropAction::ForceSharedTensor(streams, tensor_ir.id)
        }
    }
}
/// Enqueue an operation on the queue of stream `id`, creating the stream when needed, and
/// process the queue lazily.
///
/// Returns the number of operations that left the queue during processing.
fn enqueue_operation(
    &mut self,
    id: StreamId,
    repr: OperationIr,
    streams: &OperationStreams,
    operation: Arc<dyn Operation<R>>,
    handles: &mut HandleContainer<R::FusionHandle>,
) -> usize {
    let stream = self
        .streams
        .entry(id)
        .or_insert_with(|| Stream::new(self.device.clone()));

    stream.queue.add(repr, operation, streams, id);

    let queued_before = stream.queue.global.len();
    let segment = Segment::new(&mut stream.queue, handles);
    stream
        .processor
        .process(segment, &mut self.optimizations, ExecutionMode::Lazy);
    let queued_after = stream.queue.global.len();

    let num_executed = queued_before - queued_after;
    stream.cursor += num_executed as u64;

    num_executed
}
/// Mark a tensor as read.
///
/// Only read-write tensors are concerned: the tensor is removed from the variables of stream
/// `id`, and the stream itself is removed once it no longer tracks any variable.
#[allow(unused_variables)]
pub fn mark_read(
    &mut self,
    id: StreamId,
    ir: &TensorIr,
    handles: &HandleContainer<R::FusionHandle>,
) {
    if !matches!(ir.status, TensorStatus::ReadWrite) {
        return;
    }

    let Some(stream) = self.streams.get_mut(&id) else {
        return;
    };

    let variables = &mut stream.queue.variables;
    variables.remove(&ir.id);

    if variables.is_empty() {
        self.streams.remove(&id);
    }

    #[cfg(feature = "memory-checks")]
    self.memory_checks.check(&self.streams, handles);
}
/// Drain a stream
///
/// Processes the queue of stream `id` in sync mode, then updates the shared tensors
/// bookkeeping. Does nothing when the stream doesn't exist.
pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
if let Some(stream) = self.streams.get_mut(&id) {
// Switch to the drained stream's id for the duration of the processing; the previous
// id is restored at the end.
// NOTE(review): the safety contract of `StreamId::swap` is not visible here — confirm
// and document it with a `SAFETY:` comment.
let old = unsafe { StreamId::swap(id) };
// The cursor is advanced by the whole queue length.
// NOTE(review): this assumes sync processing executes every queued operation — confirm.
let num_executed = stream.queue.global.len();
stream.processor.process(
Segment::new(&mut stream.queue, handles),
&mut self.optimizations,
ExecutionMode::Sync,
);
stream.cursor += num_executed as u64;
let cleared = self.shared_tensors.on_executed_ops(id, stream);
self.clear_shared_tensors(&cleared, id);
let to_drop = self.shared_tensors.clear_tensors(cleared);
self.drop_shared_tensors(to_drop, handles, id);
unsafe {
StreamId::swap(old);
};
}
}
/// When one of the provided streams is different from the current stream, we drain them.
///
/// The tensors of the operation are analysed first; the analysis drives both the
/// synchronization of the other streams and the tagging of shared tensors in `op`.
///
/// Returns the selected stream id.
fn resolve_streams(
&mut self,
streams: &OperationStreams,
handles: &mut HandleContainer<R::FusionHandle>,
op: &mut OperationIr,
) -> StreamId {
let current = streams.current;
let nodes = op.nodes();
let analysis = self.analyse_shared_tensors(&nodes, streams, current);
self.merge_streams_timelines(handles, &analysis, current, &nodes);
// `nodes` is no longer used from here, so `op` can be modified.
self.register_shared_tensors_drop(&analysis, op);
current
}
/// Drain stream `id`, but only when at least one tensor of `nodes` is tracked by the queue of
/// that stream. Unknown streams are ignored.
fn resolve_stream(
    &mut self,
    handles: &mut HandleContainer<R::FusionHandle>,
    id: StreamId,
    nodes: &[&TensorIr],
) {
    let Some(stream) = self.streams.get(&id) else {
        return;
    };

    let uses_any_node = nodes
        .iter()
        .any(|node| stream.queue.variables.contains_key(&node.id));

    if uses_any_node {
        self.drain(handles, id);
    }
}
/// Analyse every tensor of the operation and group the shared ones by kind.
fn analyse_shared_tensors(
    &mut self,
    nodes: &[&TensorIr],
    streams: &OperationStreams,
    current: StreamId,
) -> MultiSharedTensorAnalysis {
    let mut summary = MultiSharedTensorAnalysis::default();

    for node in nodes {
        match self
            .shared_tensors
            .analyse(current, node, streams, &self.streams)
        {
            SharedTensorAnalysis::NotShared => {}
            SharedTensorAnalysis::SharedFromCurrentStream => {
                summary.current.push(node.id);
            }
            SharedTensorAnalysis::SharedFromNewStream { stream_id } => {
                summary.new.push((node.id, stream_id));
            }
            SharedTensorAnalysis::SharedFromExistingStream {
                stream_id,
                original_cursor,
            } => {
                summary.existing.push((node.id, stream_id, original_cursor));
            }
        }
    }

    summary
}
/// Synchronize the streams that share tensors with the current operation.
///
/// Streams of newly shared tensors are always candidates. Streams of already shared tensors
/// are candidates only when they are not the current stream and their cursor has not moved
/// past the original cursor.
fn merge_streams_timelines(
    &mut self,
    handles: &mut HandleContainer<R::FusionHandle>,
    analysis: &MultiSharedTensorAnalysis,
    current: StreamId,
    nodes: &[&TensorIr],
) {
    // If we only have current tensors that are shared, we're safe to not sync the timelines.
    let shares_with_other_streams = !(analysis.new.is_empty() && analysis.existing.is_empty());
    if !shares_with_other_streams {
        return;
    }

    let mut pending = HashSet::new();

    for &(_, stream_id) in &analysis.new {
        pending.insert(stream_id);
    }

    for &(_, stream_id, original_cursor) in &analysis.existing {
        // We only have to sync a stream when the stream isn't up to date with
        // the original cursor of the current operation.
        let is_behind = self
            .streams
            .get(&stream_id)
            .is_some_and(|stream| stream.cursor <= original_cursor);

        if is_behind && stream_id != current {
            pending.insert(stream_id);
        }
    }

    for id in pending.drain() {
        log::trace!("Drain stream {id} for use in current {current}");
        self.resolve_stream(handles, id, nodes);
    }
}
fn register_shared_tensors_drop(
&mut self,
analysis: &MultiSharedTensorAnalysis,
op: &mut OperationIr,
) {
let mut readonly_tensors = Vec::new();
for (tensor_id, _stream_id) in analysis.new.iter() {
readonly_tensors.push(*tensor_id);
}
for (tensor_id, _stream_id, _cursor) in analysis.existing.iter() {
readonly_tensors.push(*tensor_id);
}
for tensor_id in analysis.current.iter() {
readonly_tensors.push(*tensor_id);
}
self.shared_tensors
.tag_manual_drop(op.mark_read_only(&readonly_tensors));
}
/// Drop the given shared tensors.
///
/// Each tensor is first removed from the variables of the streams whose recorded stream id
/// for it differs from their own; a drop operation is then registered on the current stream.
fn drop_shared_tensors(
    &mut self,
    tensors: Vec<TensorIr>,
    handles: &mut HandleContainer<R::FusionHandle>,
    current: StreamId,
) {
    for (stream_id, stream) in self.streams.iter_mut() {
        let variables = &mut stream.queue.variables;

        for tensor in &tensors {
            let recorded_for_other_stream = variables
                .get(&tensor.id)
                .is_some_and(|(origin, _)| origin != stream_id);

            if recorded_for_other_stream {
                variables.remove(&tensor.id);
            }
        }
    }

    for tensor in tensors {
        let op = Arc::new(DropOp { id: tensor.id });
        let streams = OperationStreams {
            streams: HashMap::new(),
            current,
        };

        self.register(streams, OperationIr::Drop(tensor), op, handles);
    }
}
/// Remove the given tensors from the variables of every stream, then remove the streams left
/// without any variable. The current stream is always kept.
fn clear_shared_tensors(&mut self, tensors: &[TensorId], current: StreamId) {
    self.streams.retain(|stream_id, stream| {
        let variables = &mut stream.queue.variables;

        for tensor in tensors {
            variables.remove(tensor);
        }

        !variables.is_empty() || current == *stream_id
    });
}
}
/// A single lazy stream of operations.
pub(crate) struct Stream<R: FusionRuntime> {
/// The operations waiting to be executed.
pub(crate) queue: OperationQueue<R>,
/// The processor that decides when and how the queued operations are executed.
processor: Processor<R::Optimization>,
/// Counter advanced by the number of operations executed on the stream.
pub(crate) cursor: u64,
}
/// Gives the processor access to the queue of a stream and to the handle container.
#[derive(new)]
struct Segment<'a, R: FusionRuntime> {
/// The queue whose operations are processed.
queue: &'a mut OperationQueue<R>,
/// The handles passed to the queue when a plan is executed.
handles: &'a mut HandleContainer<R::FusionHandle>,
}
impl<R: FusionRuntime> StreamSegment<R::Optimization> for Segment<'_, R> {
    /// The relative descriptions of the queued operations.
    fn operations(&self) -> &[OperationIr] {
        self.queue.relative.as_slice()
    }

    /// Execute the plan `id` on the queue.
    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<R::Optimization>) {
        self.queue.execute(id, self.handles, store);
    }
}
impl<R: FusionRuntime> Stream<R> {
    /// Create an empty stream whose processor uses the fusers of the given device.
    fn new(device: R::FusionDevice) -> Self {
        let processor = Processor::new(R::fusers(device));
        let queue = OperationQueue::new();

        Self {
            queue,
            processor,
            cursor: 0,
        }
    }
}
#[derive(Debug)]
/// Manage the streams used for the current [operation](OperationIr).
pub struct OperationStreams {
/// The stream of each registered input tensor.
pub(crate) streams: HashMap<TensorId, StreamId>,
/// The stream on which the operation is registered.
pub(crate) current: StreamId,
}
impl Default for OperationStreams {
    /// No registered tensor, with `StreamId::current()` as the current stream.
    fn default() -> Self {
        let current = StreamId::current();

        Self {
            streams: HashMap::new(),
            current,
        }
    }
}
impl OperationStreams {
    /// Register a tensor in the list of streams used for the current [operation](OperationIr).
    ///
    /// You only need to register input tensors, not the outputs.
    /// So init tensor operations should have no streams registered.
    pub fn tensor<R: FusionRuntime>(&mut self, tensor: &crate::FusionTensor<R>) {
        let (tensor_id, stream_id) = (tensor.id, tensor.stream);
        self.streams.insert(tensor_id, stream_id);
    }

    /// The stream registered for the given tensor, if any.
    pub(crate) fn get(&self, id: TensorId) -> Option<StreamId> {
        self.streams.get(&id).copied()
    }

    /// Create new operation streams with the given inputs.
    ///
    /// The inputs are automatically registered.
    pub fn with_inputs<'a, R: FusionRuntime + 'a, I>(tensors: I) -> Self
    where
        I: IntoIterator<Item = &'a crate::FusionTensor<R>>,
    {
        let mut registered = OperationStreams::default();
        tensors
            .into_iter()
            .for_each(|tensor| registered.tensor(tensor));
        registered
    }
}
/// Shared tensors grouped by the stream they come from.
#[derive(Default, Debug)]
struct MultiSharedTensorAnalysis {
/// Tensors that are shared with other streams, while we are currently executing on the
/// stream where the tensor was originally created.
current: Vec<TensorId>,
/// Tensors that are shared with new streams.
new: Vec<(TensorId, StreamId)>,
/// Tensors that are shared with existing streams.
///
/// NOTE(review): the `u64` is presumably the cursor at which the tensor was created on the
/// existing stream — confirm against the code that fills this list.
existing: Vec<(TensorId, StreamId, u64)>,
}

View File

@@ -0,0 +1,95 @@
use std::sync::Arc;
use crate::FusionRuntime;
use crate::stream::{OperationConverter, OperationStreams, RelativeOps, execution::Operation};
use burn_backend::StreamId;
use burn_ir::{OperationIr, TensorId, TensorStatus};
use hashbrown::HashMap;
/// A growing list of [tensor operation descriptions](OperationIr).
pub struct OperationQueue<R: FusionRuntime> {
/// List of operation descriptions. These contain the exact tensor IDs
/// and shapes so that kernels can be run correctly.
///
/// The length of this list is the same as the length of the `operations` list.
pub(crate) global: Vec<OperationIr>,
/// List of operation descriptions. The tensor IDs and shapes are relative
/// because we don't need to know the exact values, but they are sufficient to
/// determine which operations can be fused.
pub(crate) relative: Vec<OperationIr>,
/// Converter used to turn a global operation into its relative form.
pub(crate) converter: OperationConverter,
/// The executable operations, pushed together with their global description.
pub(crate) operations: Vec<Arc<dyn Operation<R>>>,
/// For every tensor referenced by a queued operation, the stream recorded for it and the
/// status it had in the latest operation that referenced it.
pub(crate) variables: HashMap<TensorId, (StreamId, TensorStatus)>,
}
impl<R: FusionRuntime> Default for OperationQueue<R> {
/// Same as [`OperationQueue::new`]: an empty queue.
fn default() -> Self {
Self::new()
}
}
impl<R: FusionRuntime> OperationQueue<R> {
    /// Create a new empty queue.
    pub fn new() -> Self {
        Self {
            global: Vec::new(),
            relative: Vec::new(),
            converter: OperationConverter::default(),
            operations: Vec::new(),
            variables: HashMap::new(),
        }
    }

    /// Add a new tensor operation to the queue.
    ///
    /// The [operation intermediate representation](OperationIr) is also converted to a relative
    /// representation, so that the same optimization can be reused whenever the same pattern
    /// emerges in a different but similar scenario.
    ///
    /// Each tensor of the operation is recorded with the stream registered in `streams`,
    /// falling back to `current` when none was registered.
    pub fn add(
        &mut self,
        global: OperationIr,
        operation: Arc<dyn Operation<R>>,
        streams: &OperationStreams,
        current: StreamId,
    ) {
        for node in global.nodes() {
            let stream_id = streams.get(node.id).unwrap_or(current);
            self.variables.insert(node.id, (stream_id, node.status));
        }

        // The relative form must be computed before `global` is moved into the queue.
        let relative = global.to_relative(&mut self.converter);

        self.relative.push(relative);
        self.global.push(global);
        self.operations.push(operation);
    }
}
// Checks how `StreamId::current()` behaves across threads.
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
/// Two calls on the same thread must agree, while the main thread and each spawned thread
/// must all observe different stream ids.
#[test]
fn stream_id_from_different_threads() {
let current = StreamId::current();
// The first thread reads its id twice to verify it is stable within a thread.
let thread1 = std::thread::spawn(|| (StreamId::current(), StreamId::current()));
let thread2 = std::thread::spawn(StreamId::current);
let (stream_1, stream_11) = thread1.join().unwrap();
let stream_2 = thread2.join().unwrap();
assert_ne!(current, stream_1, "Should be different from thread 1");
assert_ne!(current, stream_2, "Should be different from thread 2");
assert_ne!(
stream_1, stream_2,
"Should be different from different threads"
);
assert_eq!(
stream_1, stream_11,
"Should be the same, since same thread."
);
}
}

View File

@@ -0,0 +1,153 @@
use std::sync::Arc;
use burn_ir::{HandleContainer, TensorStatus};
use crate::{
FusionRuntime,
search::BlockOptimization,
stream::{
Context, Operation, OperationConverter, OrderedExecution, RelativeOps,
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},
},
};
use super::OperationQueue;
impl<R: FusionRuntime> OperationQueue<R> {
    /// Execute the queue partially following the execution strategy from the plan.
    pub(crate) fn execute(
        &mut self,
        id: ExecutionPlanId,
        handles: &mut HandleContainer<R::FusionHandle>,
        store: &mut ExecutionPlanStore<R::Optimization>,
    ) {
        let optimization = &mut store.get_mut_unchecked(id).optimization;
        self.execute_block_optimization(optimization, handles);
    }

    /// Run one block optimization over the queued operations, then update the queue.
    fn execute_block_optimization(
        &mut self,
        step: &mut BlockOptimization<R::Optimization>,
        handles: &mut HandleContainer<R::FusionHandle>,
    ) {
        // The operations are moved out for the duration of the run and the ones
        // handed back are stored again afterwards.
        let pending = core::mem::take(&mut self.operations);
        let (remaining, num_drained) =
            QueueExecution::run(step, &mut self.converter, handles, pending);
        self.operations = remaining;

        self.drain_queue(num_drained, handles);
    }

    /// Bookkeeping after executing `num_drained` operations from the queue.
    fn drain_queue(&mut self, num_drained: usize, handles: &mut HandleContainer<R::FusionHandle>) {
        let executed = &self.global[..num_drained];

        for tensor in executed.iter().flat_map(|op| op.nodes()) {
            if tensor.status == TensorStatus::ReadWrite {
                self.variables.remove(&tensor.id);
            }
            handles.free(tensor);
        }

        self.global.drain(..num_drained);
        self.reset_relative();
    }

    /// Rebuild the relative operations from scratch out of the remaining global ones.
    fn reset_relative(&mut self) {
        self.relative.clear();
        self.converter.clear();

        let converter = &mut self.converter;
        self.relative
            .extend(self.global.iter().map(|op| op.to_relative(converter)));
    }
}
/// A queue execution has the responsibility to run the provided
/// [optimization](FusionRuntime::Optimization) without holes.
enum QueueExecution<'a, R: FusionRuntime> {
/// Used for every strategy that is not composed: a context is only created when an
/// optimization is executed.
Single {
handles: &'a mut HandleContainer<R::FusionHandle>,
converter: &'a mut OperationConverter,
execution: OrderedExecution<R>,
},
/// Used for composed strategies: one context is created up front and reused by every
/// nested strategy.
Multiple {
context: &'a mut Context<'a, R::FusionHandle>,
execution: OrderedExecution<R>,
},
}
impl<'a, R: FusionRuntime> QueueExecution<'a, R> {
/// Executes the strategy of `optimization` over `operations`.
///
/// Returns whatever `OrderedExecution::finish` yields: a list of operations and a count.
/// The caller in this file stores the list back as the queue's operations and treats the
/// count as the number of operations drained.
fn run(
optimization: &mut BlockOptimization<R::Optimization>,
converter: &'a mut OperationConverter,
handles: &'a mut HandleContainer<R::FusionHandle>,
operations: Vec<Arc<dyn Operation<R>>>,
) -> (Vec<Arc<dyn Operation<R>>>, usize) {
let execution = OrderedExecution::new(operations);
// A composed strategy shares a single context between all of its nested strategies.
if matches!(&optimization.strategy, ExecutionStrategy::Composed(..)) {
let mut context = converter.context(handles);
let mut this = QueueExecution::Multiple {
context: &mut context,
execution,
};
this = this.execute_strategy(&mut optimization.strategy);
match this {
QueueExecution::Multiple { execution, .. } => execution.finish(),
// `execute_strategy` hands back the value it received, so the variant cannot change.
_ => unreachable!(),
}
} else {
let mut this = QueueExecution::Single {
handles,
converter,
execution,
};
this = this.execute_strategy(&mut optimization.strategy);
match this {
QueueExecution::Single { execution, .. } => execution.finish(),
// `execute_strategy` hands back the value it received, so the variant cannot change.
_ => unreachable!(),
}
}
}
/// Executes `strategy` and returns `self`, so that nested strategies can be chained.
fn execute_strategy(mut self, strategy: &mut ExecutionStrategy<R::Optimization>) -> Self {
match &mut self {
QueueExecution::Single {
handles,
converter,
execution,
} => match strategy {
ExecutionStrategy::Optimization { ordering, opt } => {
// The context only lives for this single optimization.
let mut context = converter.context(handles);
execution.execute_optimization(opt, &mut context, ordering.clone())
}
ExecutionStrategy::Operations { ordering } => {
execution.execute_operations(handles, ordering)
}
// `run` only builds `Single` when the strategy is not composed.
ExecutionStrategy::Composed(_) => unreachable!(),
},
QueueExecution::Multiple { context, execution } => match strategy {
ExecutionStrategy::Optimization { opt, ordering } => {
execution.execute_optimization(opt, context, ordering.clone());
}
ExecutionStrategy::Operations { ordering } => {
execution.execute_operations(context.handles, ordering);
}
ExecutionStrategy::Composed(items) => {
// Each nested strategy is executed in order, threading `self` through.
for item in items.iter_mut() {
self = self.execute_strategy(item);
}
}
},
};
self
}
}

View File

@@ -0,0 +1,4 @@
mod base;
mod execution;
pub use base::*;

View File

@@ -0,0 +1,306 @@
use burn_backend::StreamId;
use burn_ir::{TensorId, TensorIr};
use hashbrown::HashMap;
use super::{OperationStreams, Stream};
use crate::FusionRuntime;
#[derive(Default)]
/// Manages tensors that are shared between multiple streams.
pub struct SharedTensors {
/// Sharing state of every tensor currently tracked as shared.
shared_tensors: HashMap<TensorId, SharedTensor>,
/// Tensors tagged through `tag_manual_drop`, handed back to the caller once they are no
/// longer tracked in `shared_tensors`.
shared_tensors_manual_drop: HashMap<TensorId, TensorIr>,
}
#[derive(Default, Debug)]
/// A tensor that is shared between multiple streams.
struct SharedTensor {
/// Per-stream state for every stream the tensor is registered on.
streams: HashMap<StreamId, SharedTensorState>,
}
/// Cursor positions recorded for a shared tensor on one stream.
#[derive(Debug)]
struct SharedTensorState {
/// Cursor computed the last time the tensor was registered on the stream.
cursor_current: u64,
/// Cursor computed the first time the tensor was registered on the stream.
cursor_origin: u64,
}
#[derive(Debug)]
/// What to do when a tensor is dropped.
pub enum SharedTensorDropAction {
/// Performs the drop and removes the shared tensor from the provided list of
/// stream ids.
ForceDrop(Vec<StreamId>),
/// Skip the drop.
Skip,
}
#[derive(Debug)]
/// Information about a shared tensor.
pub enum SharedTensorAnalysis {
/// The tensor is not shared.
NotShared,
/// The tensor is shared, but its original stream is the current one.
SharedFromCurrentStream,
/// The tensor is shared, and its original stream is an existing stream, meaning the tensor
/// was already registered on that stream.
SharedFromExistingStream {
/// The stream id of the existing stream.
stream_id: StreamId,
/// The position of execution in the existing stream where the tensor was created.
original_cursor: u64,
},
/// The tensor is shared, and its original stream is a new one without any operation
/// executed.
SharedFromNewStream {
/// The stream id of the new stream.
stream_id: StreamId,
},
}
impl SharedTensors {
    /// Function to call when a drop operation is registered on the given stream and tensor.
    ///
    /// The drop is performed when the tensor is not tracked as shared, or when the stream is
    /// completed and it was the last stream sharing the tensor. Otherwise it is skipped.
    pub fn on_drop(
        &mut self,
        stream_id: StreamId,
        tensor_id: TensorId,
        stream_completed: bool,
    ) -> SharedTensorDropAction {
        let execute_still = match self.shared_tensors.get_mut(&tensor_id) {
            Some(shared) if stream_completed => {
                shared.drop(stream_id);
                shared.streams.is_empty()
            }
            Some(_) => false,
            None => true,
        };

        if !execute_still {
            return SharedTensorDropAction::Skip;
        }

        let removed = self.shared_tensors.remove(&tensor_id);
        self.shared_tensors_manual_drop.remove(&tensor_id);

        let streams: Vec<StreamId> = removed
            .map(|shared| shared.streams.keys().copied().collect())
            .unwrap_or_default();

        SharedTensorDropAction::ForceDrop(streams)
    }

    /// Function to call when one or many operations were executed on the stream.
    ///
    /// Returns the tensor ids that can be cleared with [Self::clear_tensors].
    pub fn on_executed_ops<R: FusionRuntime>(
        &mut self,
        id: StreamId,
        stream: &mut Stream<R>,
    ) -> Vec<TensorId> {
        let mut cleared = Vec::new();

        for (tensor_id, shared) in self.shared_tensors.iter_mut() {
            let clearable = match shared.update(id, stream) {
                SharedTensorUpdate::RemovedFromStream(no_more_stream) => {
                    stream.queue.variables.remove(tensor_id);
                    no_more_stream
                }
                SharedTensorUpdate::ReadyForCleanup => true,
                SharedTensorUpdate::NoChange => false,
            };

            if clearable {
                cleared.push(*tensor_id);
            }
        }

        cleared
    }

    /// Clear the provided tensors and return the list of tensors that can be manually dropped.
    pub fn clear_tensors(&mut self, tensors: Vec<TensorId>) -> Vec<TensorIr> {
        let to_drop: Vec<TensorIr> = tensors
            .into_iter()
            .filter_map(|id| {
                self.shared_tensors.remove(&id);
                self.shared_tensors_manual_drop.remove(&id)
            })
            .collect();

        self.register_manual_drop(to_drop)
    }

    /// Analyses the current tensor and updates its state.
    pub fn analyse<R: FusionRuntime>(
        &mut self,
        id: StreamId,
        node: &TensorIr,
        streams_op: &OperationStreams,
        streams: &HashMap<StreamId, Stream<R>>,
    ) -> SharedTensorAnalysis {
        // Without a registered stream, or when it is the current one, the tensor is only
        // shared if it was already tracked.
        let origin = match streams_op.streams.get(&node.id) {
            Some(origin) if *origin != id => *origin,
            _ => {
                return if self.shared_tensors.contains_key(&node.id) {
                    SharedTensorAnalysis::SharedFromCurrentStream
                } else {
                    SharedTensorAnalysis::NotShared
                };
            }
        };

        // Here the node is tagged as newly shared.
        let stream_current = streams.get(&id);
        let stream_origin = streams.get(&origin);
        let shared = self.shared_tensors.entry(node.id).or_default();

        shared.register_new_stream(id, stream_current);

        match shared.register_new_stream(origin, stream_origin) {
            Some(original_cursor) => SharedTensorAnalysis::SharedFromExistingStream {
                stream_id: origin,
                original_cursor,
            },
            None => SharedTensorAnalysis::SharedFromNewStream { stream_id: origin },
        }
    }

    /// Tag the provided tensors as manually dropped.
    pub fn tag_manual_drop(&mut self, dropped: Vec<TensorIr>) {
        self.shared_tensors_manual_drop
            .extend(dropped.into_iter().map(|tensor| (tensor.id, tensor)));
    }

    /// Append to `tensors` every manually dropped tensor that is no longer tracked as shared.
    fn register_manual_drop(&mut self, mut tensors: Vec<TensorIr>) -> Vec<TensorIr> {
        if self.shared_tensors_manual_drop.is_empty() {
            return tensors;
        }

        // The ids are collected first since the map cannot be mutated while iterating its keys.
        let untracked: Vec<TensorId> = self
            .shared_tensors_manual_drop
            .keys()
            .filter(|id| !self.shared_tensors.contains_key(*id))
            .copied()
            .collect();

        tensors.extend(
            untracked
                .into_iter()
                .filter_map(|id| self.shared_tensors_manual_drop.remove(&id)),
        );

        tensors
    }
}
/// The result from a [SharedTensor::update].
pub enum SharedTensorUpdate {
/// The tensor is removed from the current stream.
///
/// The flag is `true` when the tensor is no longer registered on any stream.
RemovedFromStream(bool),
/// The tensor was not registered on the stream and is shared across zero streams.
ReadyForCleanup,
/// Nothing was changed by the update.
NoChange,
}
impl SharedTensor {
    /// Register the tensor as also part of the given stream.
    ///
    /// The stream might not exist yet when the current tensor is part of the first operation in
    /// the newly created stream.
    ///
    /// Returns the original cursor when the tensor was already registered on the stream, and
    /// `None` when this is the first registration.
    fn register_new_stream<R: FusionRuntime>(
        &mut self,
        id: StreamId,
        stream: Option<&Stream<R>>,
    ) -> Option<u64> {
        let cursor_current = stream.map_or(1, |s| s.cursor + s.queue.global.len() as u64);

        if let Some(state) = self.streams.get_mut(&id) {
            state.cursor_current = cursor_current;
            return Some(state.cursor_origin);
        }

        self.streams.insert(
            id,
            SharedTensorState {
                cursor_current,
                cursor_origin: cursor_current,
            },
        );
        None
    }

    /// Update the current shared tensor state on the given stream.
    ///
    /// If the shared tensor is no longer needed on the stream, we will remove it from the list of
    /// shared streams.
    fn update<R: FusionRuntime>(&mut self, id: StreamId, stream: &Stream<R>) -> SharedTensorUpdate {
        match self.streams.remove(&id) {
            // The tensor is not registered on this stream.
            None if self.streams.is_empty() => SharedTensorUpdate::ReadyForCleanup,
            None => SharedTensorUpdate::NoChange,
            // We can only free the shared tensor if the latest cursor is executed.
            Some(entry) if entry.cursor_current <= stream.cursor => {
                SharedTensorUpdate::RemovedFromStream(self.streams.is_empty())
            }
            // Still needed on this stream: put the entry back untouched.
            Some(entry) => {
                self.streams.insert(id, entry);
                SharedTensorUpdate::NoChange
            }
        }
    }

    /// Stop tracking the tensor on the given stream.
    fn drop(&mut self, id: StreamId) {
        self.streams.remove(&id);
    }
}
impl core::fmt::Debug for SharedTensors {
    /// Writes one line per shared tensor with its per-stream cursors, followed by one line per
    /// tensor tagged for a manual drop.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("\n==== Shared Tensors ====\n")?;

        for (tensor_id, shared) in &self.shared_tensors {
            write!(f, " - Shared {}", tensor_id)?;
            for (stream_id, state) in &shared.streams {
                write!(
                    f,
                    " [{}, cursor={}..{}] ",
                    stream_id, state.cursor_origin, state.cursor_current
                )?;
            }
            f.write_str("\n")?;
        }

        for tensor_id in self.shared_tensors_manual_drop.keys() {
            writeln!(f, " - Manual Drop {}", tensor_id)?;
        }

        f.write_str("========================\n")
    }
}

View File

@@ -0,0 +1,94 @@
use std::sync::Arc;
use crate::search::BlockOptimization;
use super::{ExecutionPlanIndex, InsertQuery, SearchQuery};
use burn_ir::OperationIr;
use serde::{Deserialize, Serialize};
/// The store that contains all explorations done on a device.
#[derive(Default)]
pub(crate) struct ExecutionPlanStore<O> {
/// Every stored plan; the id of a plan is its position in this list.
plans: Vec<ExecutionPlan<O>>,
/// Index used to search the stored plans.
index: ExecutionPlanIndex,
}
/// How a list of operations should be executed.
#[derive(PartialEq, Debug, Clone)]
pub(crate) enum ExecutionStrategy<O> {
/// An optimization was found, and therefore should be executed.
///
/// NOTE(review): `ordering` presumably lists positions of the queued operations covered by
/// the optimization — confirm.
Optimization { opt: O, ordering: Arc<Vec<usize>> },
/// No optimization was found, each operation should be executed individually.
Operations { ordering: Arc<Vec<usize>> },
/// A composition of multiple execution strategies.
Composed(Vec<Box<Self>>),
}
/// The trigger that indicates when to stop exploring.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub(crate) enum ExecutionTrigger {
/// NOTE(review): presumably triggers when the given operations are the next ones — confirm.
OnOperations(Vec<OperationIr>),
/// NOTE(review): presumably triggers when a synchronization is requested — confirm.
OnSync,
/// NOTE(review): presumably triggers unconditionally — confirm.
Always,
}
/// The unique identifier for an exploration that was executed.
///
/// It is the position of the plan inside the [store](ExecutionPlanStore).
pub(crate) type ExecutionPlanId = usize;
/// The outcome of an exploration that can be stored.
#[derive(Debug)]
pub(crate) struct ExecutionPlan<O> {
/// The operations to which the exploration relates.
pub(crate) operations: Vec<OperationIr>,
/// The criteria that signal when this plan should be executed. Only one trigger is necessary.
pub(crate) triggers: Vec<ExecutionTrigger>,
/// The optimization that should be used when executing this plan.
pub(crate) optimization: BlockOptimization<O>,
}
impl<O: core::fmt::Debug> ExecutionPlanStore<O> {
    /// Create an empty store.
    pub fn new() -> Self {
        let plans = Vec::new();
        let index = ExecutionPlanIndex::default();

        Self { plans, index }
    }

    /// Search the ids of the stored plans matching the query.
    pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
        self.index.find(query)
    }

    /// Store a new plan and return its id.
    ///
    /// # Panics
    ///
    /// If the plan contains no operation.
    pub fn add(&mut self, exploration: ExecutionPlan<O>) -> ExecutionPlanId {
        assert!(
            !exploration.operations.is_empty(),
            "Can't add an empty optimization."
        );

        // The id is the position the plan will take once pushed.
        let id = self.plans.len();
        let query = InsertQuery::NewPlan {
            operations: &exploration.operations,
            id,
        };

        self.index.insert(query);
        self.plans.push(exploration);

        id
    }

    /// Mutable access to a plan; panics when `id` is not a stored plan.
    pub fn get_mut_unchecked(&mut self, id: ExecutionPlanId) -> &mut ExecutionPlan<O> {
        &mut self.plans[id]
    }

    /// Shared access to a plan; panics when `id` is not a stored plan.
    pub fn get_unchecked(&self, id: ExecutionPlanId) -> &ExecutionPlan<O> {
        &self.plans[id]
    }

    /// Add a new end condition for an optimization.
    ///
    /// A trigger equal to one already registered is ignored.
    pub fn add_trigger(&mut self, id: ExecutionPlanId, trigger: ExecutionTrigger) {
        let triggers = &mut self.plans[id].triggers;
        let is_new = triggers.iter().all(|existing| existing != &trigger);

        if is_new {
            triggers.push(trigger);
        }
    }
}

View File

@@ -0,0 +1,293 @@
use crate::stream::store::ExecutionPlanId;
use burn_ir::OperationIr;
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, hash_map::DefaultHasher},
hash::{Hash, Hasher},
};
/// Index used to search optimizations.
#[derive(Default, Serialize, Deserialize, Clone)]
pub struct ExecutionPlanIndex {
/// We can't use `HashMap<OperationIr, Vec<ExecutionPlanId>>` since `OperationIr`
/// doesn't implement [`Eq`](core::cmp::Eq).
///
/// `OperationIr` can't implement `Eq` since float types don't implement it.
///
/// We rely instead on [`PartialEq`](core::cmp::PartialEq) to manually handle hash collisions.
/// This is OK because we use `relative` operations where any scalar values are set to zeros,
/// see [`RelativeStreamConverter`](crate::stream::RelativeStreamConverter).
///
/// Map from the hash of the `OperationIr` to a list of `(OperationIr, index)` pairs,
/// where `index` is the position in the `starters` list of the entry holding all the
/// execution plans that start with that `OperationIr`.
mapping: HashMap<u64, Vec<(OperationIr, usize)>>,
/// Each entry lists the ids of the execution plans sharing the same first operation.
starters: Vec<Vec<ExecutionPlanId>>,
}
/// Query used to search execution plans in the [index](ExecutionPlanIndex).
pub enum SearchQuery<'a> {
/// Search the plans whose first operation is equal to the given one.
PlansStartingWith(&'a OperationIr),
}
/// Query used to register execution plans in the [index](ExecutionPlanIndex).
pub enum InsertQuery<'a> {
/// Register a new plan; only the first operation is used by the index.
NewPlan {
/// The operations of the plan.
operations: &'a [OperationIr],
/// The id of the plan.
id: ExecutionPlanId,
},
}
impl ExecutionPlanIndex {
    /// Search optimizations with the given [query](SearchQuery).
    pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
        match query {
            SearchQuery::PlansStartingWith(first) => self.find_starting_with(first),
        }
    }

    /// Register a new optimization with the given [query](InsertQuery).
    ///
    /// A plan without any operation is ignored.
    pub fn insert(&mut self, query: InsertQuery<'_>) {
        match query {
            InsertQuery::NewPlan { operations, id } => {
                if let Some(first) = operations.first() {
                    self.insert_new_operation(first, id);
                }
            }
        }
    }

    /// Find execution plans starting with the `OperationIr`.
    ///
    /// Returns an empty list when no plan starts with the operation.
    fn find_starting_with(&self, operation: &OperationIr) -> Vec<ExecutionPlanId> {
        let key = self.operation_key(operation);

        self.mapping
            .get(&key)
            // Several operations may share the same hash: compare the operations themselves.
            .and_then(|candidates| candidates.iter().find(|(op, _)| op == operation))
            .and_then(|(_, index)| self.starters.get(*index))
            .cloned()
            .unwrap_or_default()
    }

    /// Update the index for an execution plan starting with operation `ops`.
    fn insert_new_operation(&mut self, ops: &OperationIr, new_id: ExecutionPlanId) {
        let key = self.operation_key(ops);
        // The bucket is created empty when the hash was never seen before.
        let bucket = self.mapping.entry(key).or_default();

        match bucket.iter().find(|(op, _)| op == ops) {
            // New optimization for an existing starter.
            Some((_, index)) => self
                .starters
                .get_mut(*index)
                .expect("Should exist")
                .push(new_id),
            // New starter ops, possibly with a hash collision.
            None => {
                let index = self.starters.len();
                self.starters.push(vec![new_id]);
                bucket.push((ops.clone(), index));
            }
        }
    }

    /// Hash the value of the first operation in a list.
    fn operation_key(&self, ops: &OperationIr) -> u64 {
        let mut state = DefaultHasher::new();
        ops.hash(&mut state);
        state.finish()
    }
}
#[cfg(test)]
mod tests {
    use burn_backend::{DType, Shape};
    use burn_ir::{
        BinaryOpIr, NumericOperationIr, ScalarIr, ScalarOpIr, TensorId, TensorIr, TensorStatus,
    };

    use super::*;

    #[test]
    fn should_find_optimization_id_based_on_tensor_ops() {
        let mut index = ExecutionPlanIndex::default();
        let stream_1 = [ops_1()];
        let optimization_id_1 = 0;

        register(&mut index, &stream_1, optimization_id_1);

        let found = plans_starting_with(&index, &stream_1[0]);
        assert_eq!(found, vec![optimization_id_1]);
    }

    #[test]
    fn should_support_multiple_optimization_ids_with_same_starting_ops() {
        let mut index = ExecutionPlanIndex::default();
        let stream_1 = [ops_1(), ops_2(), ops_1()];
        let stream_2 = [ops_1(), ops_1(), ops_2()];
        let optimization_id_1 = 0;
        let optimization_id_2 = 1;

        register(&mut index, &stream_1, optimization_id_1);
        register(&mut index, &stream_2, optimization_id_2);

        let found = plans_starting_with(&index, &stream_1[0]);
        assert_eq!(found, vec![optimization_id_1, optimization_id_2]);
    }

    #[test]
    fn should_only_find_optimization_with_correct_starting_ops() {
        let mut index = ExecutionPlanIndex::default();
        let stream_1 = [ops_1(), ops_1()];
        let stream_2 = [ops_2(), ops_1()];
        let optimization_id_1 = 0;
        let optimization_id_2 = 1;

        register(&mut index, &stream_1, optimization_id_1);
        register(&mut index, &stream_2, optimization_id_2);

        let found = plans_starting_with(&index, &stream_1[0]);
        assert_eq!(found, vec![optimization_id_1]);
    }

    #[test]
    fn should_handle_hash_collisions() {
        let mut index = ExecutionPlanIndex::default();
        let stream_1 = [ops_1(), ops_1()];
        let stream_2 = [ops_3(), ops_1()];
        let optimization_id_1 = 0;
        let optimization_id_2 = 1;

        let stream_1_key = index.operation_key(&stream_1[0]);
        let stream_2_key = index.operation_key(&stream_2[0]);
        assert_ne!(
            stream_1_key, stream_2_key,
            "Ops 1 and Ops 3 should not have the same hash"
        ); // ops 1 and 3 have different variants, so the hash differs
        assert_ne!(stream_1[0], stream_2[0], "Ops 1 and Ops 3 are different.");

        register(&mut index, &stream_1, optimization_id_1);
        register(&mut index, &stream_2, optimization_id_2);

        let found = plans_starting_with(&index, &stream_1[0]);
        assert_eq!(found, vec![optimization_id_1]);
    }

    /// Register a plan made of `operations` under `id`.
    fn register(index: &mut ExecutionPlanIndex, operations: &[OperationIr], id: ExecutionPlanId) {
        index.insert(InsertQuery::NewPlan { operations, id });
    }

    /// Ids of the plans whose first operation is `operation`.
    fn plans_starting_with(
        index: &ExecutionPlanIndex,
        operation: &OperationIr,
    ) -> Vec<ExecutionPlanId> {
        index.find(SearchQuery::PlansStartingWith(operation))
    }

    /// A `[32, 32]` float tensor description with the given id and status.
    fn tensor(id: TensorId, status: TensorStatus) -> TensorIr {
        TensorIr {
            id,
            shape: Shape::new([32, 32]),
            status,
            dtype: DType::F32,
        }
    }

    /// Binary operands: tensors 0 and 1 as read-only inputs, tensor 2 as uninitialized output.
    fn binary_operands() -> BinaryOpIr {
        BinaryOpIr {
            lhs: tensor(TensorId::new(0), TensorStatus::ReadOnly),
            rhs: tensor(TensorId::new(1), TensorStatus::ReadOnly),
            out: tensor(TensorId::new(2), TensorStatus::NotInit),
        }
    }

    /// Float addition of two tensors.
    fn ops_1() -> OperationIr {
        OperationIr::NumericFloat(DType::F32, NumericOperationIr::Add(binary_operands()))
    }

    /// Float addition of a tensor and a scalar.
    fn ops_2() -> OperationIr {
        OperationIr::NumericFloat(
            DType::F32,
            NumericOperationIr::AddScalar(ScalarOpIr {
                lhs: tensor(TensorId::new(0), TensorStatus::ReadOnly),
                rhs: ScalarIr::Float(5.0),
                out: tensor(TensorId::new(2), TensorStatus::NotInit),
            }),
        )
    }

    /// Float subtraction of two tensors.
    fn ops_3() -> OperationIr {
        OperationIr::NumericFloat(DType::F32, NumericOperationIr::Sub(binary_operands()))
    }
}

View File

@@ -0,0 +1,5 @@
mod base;
mod index;
pub(crate) use base::*;
pub(super) use index::*;

View File

@@ -0,0 +1,232 @@
use crate::{
Client, FusionBackend, FusionRuntime,
stream::{Operation, OperationStreams, StreamId},
};
use burn_backend::{
DType, ExecutionError, QTensorPrimitive, Shape, TensorData, TensorMetadata,
quantization::QuantScheme,
};
use burn_ir::{OperationIr, TensorId, TensorIr, TensorStatus};
use std::sync::{
Arc,
atomic::{AtomicU32, Ordering},
};
/// Tensor primitive for the [fusion backend](crate::FusionBackend), used for tensors of all kinds.
pub struct FusionTensor<R: FusionRuntime> {
/// Tensor id.
pub id: TensorId,
/// The shape of the tensor.
pub shape: Shape,
/// The fusion client.
pub client: Client<R>,
/// The datatype of the tensor.
pub dtype: DType,
/// The current stream id this tensor is on.
pub stream: StreamId,
/// Counter shared between all clones of the tensor: it starts at one, is incremented when
/// cloning and decremented when dropping. The tensor status is derived from it.
pub(crate) count: Arc<AtomicU32>,
}
impl<R: FusionRuntime> Clone for FusionTensor<R> {
/// Creates another handle on the same tensor id, sharing the same counter.
fn clone(&self) -> Self {
// Account for the new handle before it is created.
self.count.fetch_add(1, Ordering::Acquire);
Self {
id: self.id,
shape: self.shape.clone(),
client: self.client.clone(),
dtype: self.dtype,
stream: self.stream,
count: self.count.clone(),
}
}
}
impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
    /// Prints the id, the shape and the device of the tensor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device = self.client.device();

        write!(
            f,
            "{{ id: {:?}, shape: {:?}, device: {:?} }}",
            self.id, self.shape, device,
        )
    }
}
impl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {
/// The datatype stored on the tensor.
fn dtype(&self) -> DType {
self.dtype
}
/// A copy of the shape stored on the tensor.
fn shape(&self) -> Shape {
self.shape.clone()
}
/// The number of dimensions of the shape.
fn rank(&self) -> usize {
self.shape.num_dims()
}
}
impl<R: FusionRuntime> FusionTensor<R> {
    /// Create the first handle on a tensor: the shared counter starts at one.
    pub(crate) fn new(
        id: TensorId,
        shape: Shape,
        dtype: DType,
        client: Client<R>,
        stream: StreamId,
    ) -> Self {
        let count = Arc::new(AtomicU32::new(1));

        Self {
            id,
            shape,
            client,
            dtype,
            stream,
            count,
        }
    }

    /// The status implied by `count`: read-write for at most one handle, read-only otherwise.
    fn status(&self, count: u32) -> TensorStatus {
        match count {
            0 | 1 => TensorStatus::ReadWrite,
            _ => TensorStatus::ReadOnly,
        }
    }

    /// Intermediate representation to be used when using an uninitialized tensor as output.
    pub fn to_ir_out(&self) -> TensorIr {
        TensorIr {
            id: self.id,
            shape: self.shape.clone(),
            status: TensorStatus::NotInit,
            dtype: self.dtype,
        }
    }

    /// Intermediate representation to be used when using an initialized tensor used as input.
    pub fn into_ir(mut self) -> TensorIr {
        let status = self.status(self.count.load(Ordering::Acquire));

        // Take the shape out without cloning it; an empty one is left behind.
        let shape = core::mem::replace(&mut self.shape, Shape::from(Vec::<usize>::new()));

        if matches!(status, TensorStatus::ReadWrite) {
            // Avoids an unwanted drop on the same thread.
            //
            // Since `drop` is called after `into_ir`, we must not register a drop if the tensor
            // was consumed with a `ReadWrite` status.
            self.count.fetch_add(1, Ordering::Acquire);
        }

        TensorIr {
            id: self.id,
            shape,
            status,
            dtype: self.dtype,
        }
    }

    /// Consume the tensor and read it through the client as a float tensor.
    pub(crate) async fn into_data<B>(self) -> Result<TensorData, ExecutionError>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let stream = self.stream;
        let client = self.client.clone();
        let ir = self.into_ir();

        client.read_tensor_float::<B>(ir, stream).await
    }

    /// Consume the tensor and read it through the client as a quantized tensor.
    ///
    /// # Panics
    ///
    /// If the datatype of the tensor is not a quantized float.
    pub(crate) async fn q_into_data<B>(self) -> Result<TensorData, ExecutionError>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        if !matches!(self.dtype, DType::QFloat(_)) {
            panic!("Expected quantized float dtype, got {:?}", self.dtype)
        }

        let stream = self.stream;
        let client = self.client.clone();
        let ir = self.into_ir();

        client.read_tensor_quantized::<B>(ir, stream).await
    }

    /// Consume the tensor and read it through the client as an int tensor.
    pub(crate) async fn int_into_data<B>(self) -> Result<TensorData, ExecutionError>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let stream = self.stream;
        let client = self.client.clone();
        let ir = self.into_ir();

        client.read_tensor_int::<B>(ir, stream).await
    }

    /// Consume the tensor and read it through the client as a bool tensor.
    pub(crate) async fn bool_into_data<B>(self) -> Result<TensorData, ExecutionError>
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let stream = self.stream;
        let client = self.client.clone();
        let ir = self.into_ir();

        client.read_tensor_bool::<B>(ir, stream).await
    }
}
/// Operation registered when a tensor is dropped.
#[derive(new, Debug)]
pub(crate) struct DropOp {
/// The id of the tensor whose handle must be removed.
pub(crate) id: TensorId,
}
impl<RO: FusionRuntime> Operation<RO> for DropOp {
/// Removes the handle of the dropped tensor from the container.
fn execute(&self, handles: &mut burn_ir::HandleContainer<RO::FusionHandle>) {
handles.remove_handle(self.id);
}
}
impl<R: FusionRuntime> Drop for FusionTensor<R> {
/// Releases this handle and registers a drop operation when it was the last one.
fn drop(&mut self) {
// `fetch_sub` returns the previous value: `count` still includes this handle.
let count = self.count.fetch_sub(1, Ordering::Acquire);
// Workaround to prevent segfaults when an operation panics
if std::thread::panicking() {
return;
}
match self.status(count) {
// This was the last handle: the tensor can be dropped for real.
TensorStatus::ReadWrite => {
// Move the shape out without cloning it; an empty one is left behind.
let mut shape = Shape::from(Vec::<usize>::new());
core::mem::swap(&mut shape, &mut self.shape);
let ir = TensorIr {
id: self.id,
shape,
status: TensorStatus::ReadWrite,
dtype: self.dtype,
};
let mut streams = OperationStreams::default();
streams.tensor(self);
self.client
.register(streams, OperationIr::Drop(ir), DropOp { id: self.id });
}
// Other handles are still alive: nothing to do.
TensorStatus::ReadOnly => {}
TensorStatus::NotInit => {}
}
}
}
impl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {
    /// The quantization scheme carried by the datatype of the tensor.
    ///
    /// # Panics
    ///
    /// If the datatype of the tensor is not a quantized float.
    fn scheme(&self) -> &QuantScheme {
        match &self.dtype {
            DType::QFloat(scheme) => scheme,
            _ => panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            ),
        }
    }
}