feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Kernel fusion backend decorator for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-fusion"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-fusion"
|
||||
documentation = "https://docs.rs/burn-fusion"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std", "tracing"]
|
||||
std = ["serde/std", "tracing?/std"]
|
||||
doc = ["default"]
|
||||
memory-checks = ["std"]
|
||||
|
||||
tracing = [
|
||||
"dep:tracing",
|
||||
"burn-backend/tracing",
|
||||
"burn-ir/tracing",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2" }
|
||||
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2" }
|
||||
tracing = { workspace = true, optional = true, features = ["attributes"] }
|
||||
|
||||
hashbrown = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-fusion/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-fusion/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,3 @@
|
||||
# Burn Fusion
|
||||
|
||||
A kernel fusion backend decorator for Burn.
|
||||
@@ -0,0 +1,240 @@
|
||||
use crate::{
|
||||
FusionTensor,
|
||||
client::GlobalFusionClient,
|
||||
stream::{Context, OrderedExecution},
|
||||
};
|
||||
use burn_backend::{
|
||||
Backend, DType, DeviceOps, Element, ExecutionError,
|
||||
tensor::{BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
use burn_ir::{BackendIr, OperationIr, TensorHandle};
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Get the client for the given device.
|
||||
pub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {
|
||||
GlobalFusionClient::load(device)
|
||||
}
|
||||
|
||||
/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Fusion<B: FusionBackend> {
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Backend for Fusion<B> {
|
||||
type Device = B::Device;
|
||||
|
||||
type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;
|
||||
|
||||
type FloatElem = B::FloatElem;
|
||||
|
||||
type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;
|
||||
|
||||
type IntElem = B::IntElem;
|
||||
|
||||
type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;
|
||||
|
||||
type BoolElem = B::BoolElem;
|
||||
|
||||
type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
format!("fusion<{}>", B::name(device))
|
||||
}
|
||||
|
||||
fn seed(device: &B::Device, seed: u64) {
|
||||
let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
|
||||
client.drain();
|
||||
B::seed(device, seed);
|
||||
}
|
||||
|
||||
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
|
||||
let client = GlobalFusionClient::<B::FusionRuntime>::load(device);
|
||||
client.drain();
|
||||
B::sync(device)
|
||||
}
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
|
||||
device: &Self::Device,
|
||||
input: Input,
|
||||
func: Func,
|
||||
) -> Output {
|
||||
B::memory_persistent_allocations(device, input, func)
|
||||
}
|
||||
|
||||
fn memory_cleanup(device: &Self::Device) {
|
||||
B::memory_cleanup(device)
|
||||
}
|
||||
|
||||
fn staging<'a, Iter>(data: Iter, device: &Self::Device)
|
||||
where
|
||||
Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
|
||||
{
|
||||
B::staging(data, device);
|
||||
}
|
||||
|
||||
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
|
||||
B::supports_dtype(device, dtype)
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
|
||||
B::dtype_usage(device, dtype)
|
||||
}
|
||||
}
|
||||
|
||||
/// The status of a [fuser](OperationFuser).
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
|
||||
pub enum FuserStatus {
|
||||
/// No more operations can be fused.
|
||||
Closed,
|
||||
/// More operations can be fused.
|
||||
Open,
|
||||
}
|
||||
|
||||
/// The properties of a [fuser](OperationFuser).
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct FuserProperties {
|
||||
/// The score of the optimization, higher is better.
|
||||
pub score: u64,
|
||||
/// If the operation is ready to be executed.
|
||||
pub ready: bool,
|
||||
}
|
||||
|
||||
/// The fusion operation abstraction allows implementations to fuse many
|
||||
/// [tensor operations](OperationIr) into one, improving the performance of the backend.
|
||||
///
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The implementations are free to execute the registered operations the way they want to improve
|
||||
/// the speed and efficiency of the computational graph. It doesn't mean that all registered
|
||||
/// operations should be fused, but that another way of executing them is more efficient.
|
||||
///
|
||||
/// Also, it is important to return (FuserStatus::Closed) when no more registered operation can
|
||||
/// improve the performance.
|
||||
pub trait OperationFuser<O>: Send {
|
||||
/// Register a new [tensor operation](OperationIr).
|
||||
fn fuse(&mut self, operation: &OperationIr);
|
||||
/// Finish the optimization and create a fusion operation.
|
||||
fn finish(&mut self) -> O;
|
||||
/// Reset the state.
|
||||
fn reset(&mut self);
|
||||
/// Return the builder [status](FuserStatus).
|
||||
fn status(&self) -> FuserStatus;
|
||||
/// Return the builder [properties](FuserProperties).
|
||||
fn properties(&self) -> FuserProperties;
|
||||
/// The number of operation fused.
|
||||
fn len(&self) -> usize;
|
||||
/// If no operations are fused.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
/// Clone the optimization builder.
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<O>>;
|
||||
}
|
||||
|
||||
/// The number of operations contained in the data structure.
|
||||
pub trait NumOperations: core::fmt::Debug {
|
||||
/// The number of registered operations.
|
||||
fn len(&self) -> usize;
|
||||
/// If the current optimization is empty.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
/// The optimization created from a [fuser](OperationFuser).
|
||||
pub trait Optimization<R: FusionRuntime>: Send + NumOperations {
|
||||
/// Execute the optimization.
|
||||
fn execute(
|
||||
&mut self,
|
||||
context: &mut Context<'_, R::FusionHandle>,
|
||||
execution: &OrderedExecution<R>,
|
||||
);
|
||||
|
||||
/// Returns the state that can be serialized.
|
||||
fn to_state(&self) -> R::OptimizationState;
|
||||
/// Create the optimization from the state.
|
||||
fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;
|
||||
}
|
||||
|
||||
/// Type alias for `<R as FusionRuntime>::FusionDevice`.
|
||||
pub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;
|
||||
/// Type alias for `<R as FusionRuntime>::FusionHandle`.
|
||||
pub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;
|
||||
/// Client alias.
|
||||
pub type Client<R> = GlobalFusionClient<R>;
|
||||
|
||||
/// Trait that defines a runtime that will benefits from fused operations.
|
||||
pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static {
|
||||
/// The state that can be serialized for an optimization.
|
||||
type OptimizationState: Serialize + DeserializeOwned;
|
||||
/// Optimization type for the backend.
|
||||
type Optimization: Optimization<Self>;
|
||||
/// Handle used to store tensor dynamically.
|
||||
type FusionHandle: Clone + Send;
|
||||
/// Device used by the runtime.
|
||||
type FusionDevice: DeviceOps;
|
||||
/// The type that represents booleans on the backend.
|
||||
type BoolRepr: Element;
|
||||
|
||||
/// The list of fusers that will be used to optimize the computational graph.
|
||||
fn fusers(device: Self::FusionDevice) -> Vec<Box<dyn OperationFuser<Self::Optimization>>>;
|
||||
}
|
||||
|
||||
/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
|
||||
/// [operation fuser](crate::OperationFuser).
|
||||
pub trait FusionBackend:
|
||||
BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>
|
||||
{
|
||||
/// The runtime used for this backend.
|
||||
type FusionRuntime: FusionRuntime;
|
||||
|
||||
/// Cast a float tensor and returns the resulting handle.
|
||||
fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle;
|
||||
|
||||
/// Pointer to the full precision fusion backend.
|
||||
type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
|
||||
}
|
||||
|
||||
// Fusion implements `BackendIr` to enable router backend usage.
|
||||
impl<B: FusionBackend> BackendIr for Fusion<B> {
|
||||
type Handle = FusionTensor<B::FusionRuntime>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
use crate::{
|
||||
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,
|
||||
stream::{OperationStreams, StreamId, execution::Operation},
|
||||
};
|
||||
use burn_backend::{Device, DeviceContext, DeviceId, DeviceState};
|
||||
use burn_backend::{TensorData, backend::ExecutionError};
|
||||
use burn_ir::{OperationIr, TensorId, TensorIr};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Use a mutex to communicate with the fusion server.
|
||||
pub struct GlobalFusionClient<R: FusionRuntime> {
|
||||
server: DeviceContext<FusionServer<R>>,
|
||||
device: FusionDevice<R>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> DeviceState for FusionServer<R> {
|
||||
fn init(device_id: DeviceId) -> Self {
|
||||
let device = FusionDevice::<R>::from_id(device_id);
|
||||
FusionServer::new(device)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Clone for GlobalFusionClient<R>
|
||||
where
|
||||
R: FusionRuntime,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
server: self.server.clone(),
|
||||
device: self.device.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<R> GlobalFusionClient<R>
|
||||
where
|
||||
R: FusionRuntime + 'static,
|
||||
{
|
||||
/// Loads the client from the given device.
|
||||
pub fn load(device: &FusionDevice<R>) -> Self {
|
||||
Self {
|
||||
device: device.clone(),
|
||||
server: DeviceContext::locate(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> GlobalFusionClient<R>
|
||||
where
|
||||
R: FusionRuntime + 'static,
|
||||
{
|
||||
/// Create a new client for the given [device](FusionRuntime::FusionDevice).
|
||||
pub fn new(device: FusionDevice<R>) -> Self {
|
||||
Self {
|
||||
device: device.clone(),
|
||||
server: DeviceContext::locate(&device),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new [tensor operation intermediate representation](OperationIr).
|
||||
///
|
||||
/// Returns the new (uninitialized) output tensor(s) generated by the registered operation.
|
||||
pub fn register<O>(
|
||||
&self,
|
||||
streams: OperationStreams,
|
||||
repr: OperationIr,
|
||||
operation: O,
|
||||
) -> Vec<FusionTensor<R>>
|
||||
where
|
||||
O: Operation<R> + 'static,
|
||||
{
|
||||
// Create output tensors returned by this operation
|
||||
let outputs = repr
|
||||
.outputs()
|
||||
.map(|output| {
|
||||
FusionTensor::new(
|
||||
output.id,
|
||||
output.shape.clone(),
|
||||
output.dtype,
|
||||
self.clone(),
|
||||
StreamId::current(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.server
|
||||
.lock()
|
||||
.register(streams, repr, Arc::new(operation));
|
||||
|
||||
outputs
|
||||
}
|
||||
|
||||
/// Register all lazy computation.
|
||||
pub fn drain(&self) {
|
||||
let id = StreamId::current();
|
||||
self.server.lock().drain_stream(id);
|
||||
}
|
||||
|
||||
/// Create a new (uninitialized) empty tensor handle and returns its corresponding [tensor id](TensorId).
|
||||
pub fn create_empty_handle(&self) -> TensorId {
|
||||
self.server.lock().create_empty_handle()
|
||||
}
|
||||
|
||||
/// Get the current device used by all operations handled by this client.
|
||||
pub fn device(&self) -> &FusionDevice<R> {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Create a tensor with the given handle and returns its corresponding [tensor id](TensorId).
|
||||
pub fn register_tensor_handle(&self, handle: FusionHandle<R>) -> TensorId {
|
||||
let mut server = self.server.lock();
|
||||
let id = server.create_empty_handle();
|
||||
server.handles.register_handle(id, handle);
|
||||
core::mem::drop(server);
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
/// Read the values contained by a float tensor.
|
||||
pub fn read_tensor_float<B>(
|
||||
self,
|
||||
tensor: TensorIr,
|
||||
stream: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.server.lock().read_float::<B>(tensor, stream)
|
||||
}
|
||||
|
||||
/// Read the values contained by an int tensor.
|
||||
pub fn read_tensor_int<B>(
|
||||
self,
|
||||
tensor: TensorIr,
|
||||
id: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.server.lock().read_int::<B>(tensor, id)
|
||||
}
|
||||
|
||||
/// Read the values contained by a bool tensor.
|
||||
pub fn read_tensor_bool<B>(
|
||||
self,
|
||||
tensor: TensorIr,
|
||||
stream: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.server.lock().read_bool::<B>(tensor, stream)
|
||||
}
|
||||
|
||||
/// Read the values contained by a quantized tensor.
|
||||
pub fn read_tensor_quantized<B>(
|
||||
self,
|
||||
tensor: TensorIr,
|
||||
stream: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.server.lock().read_quantized::<B>(tensor, stream)
|
||||
}
|
||||
|
||||
/// Change the client of the given float tensor.
|
||||
pub fn change_client_float<B>(
|
||||
&self,
|
||||
tensor: TensorIr,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let guard = self.server.lock_device_kind();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let id = server_current.change_server_float::<B>(
|
||||
&tensor,
|
||||
stream,
|
||||
&client.device,
|
||||
&mut server_other,
|
||||
);
|
||||
|
||||
core::mem::drop(server_current);
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(guard);
|
||||
|
||||
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
|
||||
}
|
||||
|
||||
/// Change the client of the given int tensor.
|
||||
pub fn change_client_int<B>(
|
||||
&self,
|
||||
tensor: TensorIr,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let guard = self.server.lock_device_kind();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let id = server_current.change_server_int::<B>(
|
||||
&tensor,
|
||||
stream,
|
||||
&client.device,
|
||||
&mut server_other,
|
||||
);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
core::mem::drop(guard);
|
||||
|
||||
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
|
||||
}
|
||||
|
||||
/// Change the client of the given bool tensor.
|
||||
pub fn change_client_bool<B>(
|
||||
&self,
|
||||
tensor: TensorIr,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let guard = self.server.lock_device_kind();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let id = server_current.change_server_bool::<B>(
|
||||
&tensor,
|
||||
stream,
|
||||
&client.device,
|
||||
&mut server_other,
|
||||
);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
core::mem::drop(guard);
|
||||
|
||||
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
|
||||
}
|
||||
|
||||
/// Change the client of the given quantized tensor.
|
||||
pub fn change_client_quantized<B>(
|
||||
&self,
|
||||
tensor: TensorIr,
|
||||
client: Self,
|
||||
stream: StreamId,
|
||||
) -> FusionTensor<R>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let guard = self.server.lock_device_kind();
|
||||
let mut server_current = self.server.lock();
|
||||
server_current.drain_stream(stream);
|
||||
|
||||
let mut server_other = client.server.lock();
|
||||
let id =
|
||||
server_current.change_server_quantized::<B>(&tensor, &client.device, &mut server_other);
|
||||
|
||||
core::mem::drop(server_other);
|
||||
core::mem::drop(server_current);
|
||||
core::mem::drop(guard);
|
||||
|
||||
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
|
||||
}
|
||||
|
||||
/// Resolve the given float tensor to a primitive tensor.
|
||||
pub fn resolve_tensor_float<B>(&self, tensor: FusionTensor<R>) -> B::FloatTensorPrimitive
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let mut server = self.server.lock();
|
||||
server.drain_stream(tensor.stream);
|
||||
server.resolve_server_float::<B>(&tensor.into_ir())
|
||||
}
|
||||
|
||||
/// Resolve the given int tensor to a primitive tensor.
|
||||
pub fn resolve_tensor_int<B>(&self, tensor: FusionTensor<R>) -> B::IntTensorPrimitive
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let mut server = self.server.lock();
|
||||
server.drain_stream(tensor.stream);
|
||||
server.resolve_server_int::<B>(&tensor.into_ir())
|
||||
}
|
||||
|
||||
/// Resolve the given bool tensor to a primitive tensor.
|
||||
pub fn resolve_tensor_bool<B>(&self, tensor: FusionTensor<R>) -> B::BoolTensorPrimitive
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let mut server = self.server.lock();
|
||||
server.drain_stream(tensor.stream);
|
||||
server.resolve_server_bool::<B>(&tensor.into_ir())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! # Burn Fusion
|
||||
//!
|
||||
//! This library is a part of the Burn project. It is a standalone crate that
|
||||
//! can be used to perform automatic operation fusion on backends that support it.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// Client module exposing types to communicate with the fusion server.
|
||||
pub mod client;
|
||||
/// Stream module exposing all tensor operations that can be optimized.
|
||||
pub mod stream;
|
||||
|
||||
/// Search module for stream optimizations.
|
||||
pub(crate) mod search;
|
||||
|
||||
mod backend;
|
||||
mod ops;
|
||||
mod server;
|
||||
mod tensor;
|
||||
|
||||
pub(crate) use server::*;
|
||||
|
||||
pub use backend::*;
|
||||
pub use ops::NoOp;
|
||||
pub use tensor::*;
|
||||
@@ -0,0 +1,4 @@
|
||||
use crate::{Fusion, FusionBackend};
|
||||
use burn_backend::ops::ActivationOps;
|
||||
|
||||
impl<B: FusionBackend> ActivationOps<Self> for Fusion<B> {}
|
||||
@@ -0,0 +1,15 @@
|
||||
use crate::{FusionBackend, stream::Operation};
|
||||
use burn_ir::HandleContainer;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A no-operation placeholder for the fusion backend.
|
||||
///
|
||||
/// `NoOp` is an implementation of [`Operation`] that doesn't execute anything.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct NoOp<B: FusionBackend> {
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for NoOp<B> {
|
||||
fn execute(&self, _handles: &mut HandleContainer<B::Handle>) {}
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! binary_float_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> $name<B> {
|
||||
fn new(desc: BinaryOpIr) -> Self {
|
||||
Self {
|
||||
desc,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! binary_float_cmp_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! binary_int_cmp_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> $name<B> {
|
||||
fn new(desc: BinaryOpIr) -> Self {
|
||||
Self {
|
||||
desc,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! binary_int_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
|
||||
let output = $ops(lhs, rhs);
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,856 @@
|
||||
use crate::{
|
||||
Fusion, FusionBackend, get_client,
|
||||
stream::{OperationStreams, execution::Operation},
|
||||
};
|
||||
use burn_backend::{
|
||||
Element, ExecutionError, Scalar, Shape, Slice, TensorData,
|
||||
ops::BoolTensorOps,
|
||||
tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
|
||||
};
|
||||
use burn_ir::{
|
||||
BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
|
||||
GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr,
|
||||
OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr,
|
||||
SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use super::NoOp;
|
||||
|
||||
impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
||||
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct EmptyOps<B: FusionBackend> {
|
||||
desc: TensorIr,
|
||||
device: Device<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let output = B::bool_empty(self.desc.shape.clone(), &self.device);
|
||||
handles.register_bool_tensor::<B>(&self.desc.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let client = get_client::<B>(device);
|
||||
let desc =
|
||||
CreationOpIr::create(shape, B::BoolElem::dtype(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
OperationStreams::default(),
|
||||
OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
|
||||
EmptyOps::<B>::new(desc.out, device.clone()),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct ZerosOps<B: FusionBackend> {
|
||||
desc: TensorIr,
|
||||
device: Device<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let output = B::bool_zeros(self.desc.shape.clone(), &self.device);
|
||||
handles.register_bool_tensor::<B>(&self.desc.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let client = get_client::<B>(device);
|
||||
let desc =
|
||||
CreationOpIr::create(shape, B::BoolElem::dtype(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
OperationStreams::default(),
|
||||
OperationIr::BaseBool(BaseOperationIr::Zeros(desc.clone())),
|
||||
ZerosOps::<B>::new(desc.out, device.clone()),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct OnesOps<B: FusionBackend> {
|
||||
desc: TensorIr,
|
||||
device: Device<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let output = B::bool_ones(self.desc.shape.clone(), &self.device);
|
||||
handles.register_bool_tensor::<B>(&self.desc.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let client = get_client::<B>(device);
|
||||
let desc =
|
||||
CreationOpIr::create(shape, B::BoolElem::dtype(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
OperationStreams::default(),
|
||||
OperationIr::BaseBool(BaseOperationIr::Ones(desc.clone())),
|
||||
OnesOps::<B>::new(desc.out, device.clone()),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
tensor.bool_into_data::<B>().await
|
||||
}
|
||||
|
||||
fn bool_from_data(data: burn_backend::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let client = get_client::<B>(device);
|
||||
let tensor = B::bool_from_data(data, device);
|
||||
let shape = burn_backend::TensorMetadata::shape(&tensor);
|
||||
|
||||
let handle = B::bool_tensor_handle(tensor);
|
||||
let desc = InitOperationIr::create(shape, B::BoolElem::dtype(), || {
|
||||
client.register_tensor_handle(handle)
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
OperationStreams::default(),
|
||||
OperationIr::Init(desc),
|
||||
NoOp::<B>::new(),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct IntoIntOps<B: FusionBackend> {
|
||||
desc: CastOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_into_int(input);
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = CastOpIr::create(tensor.into_ir(), B::IntElem::dtype(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
|
||||
IntoIntOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct IntoFloatOps<B: FusionBackend> {
|
||||
desc: CastOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_into_float(input);
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = CastOpIr::create(tensor.into_ir(), B::FloatElem::dtype(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
|
||||
IntoFloatOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
|
||||
tensor.client.device().clone()
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let device_original: &B::Device = tensor.client.device();
|
||||
|
||||
if device_original == device {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
let id = tensor.stream;
|
||||
let client_target = get_client::<B>(device);
|
||||
let client_original = tensor.client.clone();
|
||||
|
||||
client_original
|
||||
.clone()
|
||||
.change_client_bool::<B>(tensor.into_ir(), client_target, id)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
if tensor.shape == shape {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct ReshapeDimsOps<B: FusionBackend> {
|
||||
desc: ShapeOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_reshape(input, self.desc.out.shape.clone());
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
|
||||
ReshapeDimsOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct SliceOps<B: FusionBackend> {
|
||||
desc: SliceOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
|
||||
let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
|
||||
SliceOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct SliceAssignOps<B: FusionBackend> {
|
||||
desc: SliceAssignOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
let value = handles.get_bool_tensor::<B>(&self.desc.value);
|
||||
|
||||
let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor, &value]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc =
|
||||
SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
|
||||
SliceAssignOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct CatOps<B: FusionBackend> {
|
||||
desc: CatOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensors = self
|
||||
.desc
|
||||
.tensors
|
||||
.iter()
|
||||
.map(|tensor| handles.get_bool_tensor::<B>(tensor))
|
||||
.collect();
|
||||
|
||||
let output = B::bool_cat(tensors, self.desc.dim);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs(&tensors);
|
||||
|
||||
let client = tensors.first().unwrap().client.clone();
|
||||
let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
|
||||
let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
|
||||
CatOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct EqualOps<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
|
||||
let output = B::bool_equal(lhs, rhs);
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&lhs, &rhs]);
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
|
||||
EqualOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct NotOps<B: FusionBackend> {
|
||||
desc: UnaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_not(input);
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
|
||||
NotOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct AndOps<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
|
||||
let output = B::bool_and(lhs, rhs);
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&lhs, &rhs]);
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Bool(BoolOperationIr::And(desc.clone())),
|
||||
AndOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct OrOps<B: FusionBackend> {
|
||||
desc: BinaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
|
||||
let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
|
||||
let output = B::bool_or(lhs, rhs);
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&lhs, &rhs]);
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
|
||||
OrOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct SwapDimsOps<B: FusionBackend> {
|
||||
desc: SwapDimsOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
|
||||
SwapDimsOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct PermuteDimsOps<B: FusionBackend> {
|
||||
desc: PermuteOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_permute(input, self.desc.axes.as_slice());
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
|
||||
PermuteDimsOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct ExpandOps<B: FusionBackend> {
|
||||
desc: ShapeOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_expand(input, self.desc.out.shape.clone());
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
|
||||
ExpandOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct FlipOps<B: FusionBackend> {
|
||||
desc: FlipOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_flip(input, self.desc.axes.as_slice());
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
|
||||
FlipOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct RepeatDimOps<B: FusionBackend> {
|
||||
desc: RepeatDimOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
|
||||
let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
|
||||
RepeatDimOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct UnfoldOps<B: FusionBackend> {
|
||||
desc: UnfoldOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_bool_tensor::<B>(&self.desc.input);
|
||||
let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
|
||||
UnfoldOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct MaskWhereOps<B: FusionBackend> {
|
||||
desc: MaskWhereOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
let value = handles.get_bool_tensor::<B>(&self.desc.value);
|
||||
let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
|
||||
|
||||
let output = B::bool_mask_where(tensor, mask, value);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc.clone())),
|
||||
MaskWhereOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct MaskFillOps<B: FusionBackend> {
|
||||
desc: MaskFillOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
|
||||
|
||||
let output = B::bool_mask_fill(tensor, mask, self.desc.value.into());
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor, &mask]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let value = value.into();
|
||||
let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::MaskFill(desc.clone())),
|
||||
MaskFillOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct GatherOps<B: FusionBackend> {
|
||||
desc: GatherOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor::<B>(&self.desc.indices);
|
||||
|
||||
let output = B::bool_gather(self.desc.dim, tensor, indices);
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor, &indices]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Gather(desc.clone())),
|
||||
GatherOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct ScatterOps<B: FusionBackend> {
|
||||
desc: ScatterOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor::<B>(&self.desc.indices);
|
||||
let value = handles.get_bool_tensor::<B>(&self.desc.value);
|
||||
|
||||
let output = B::bool_scatter_or(self.desc.dim, tensor, indices, value);
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = ScatterOpIr::create(
|
||||
tensor.into_ir(),
|
||||
dim,
|
||||
indices.into_ir(),
|
||||
value.into_ir(),
|
||||
IndexingUpdateOp::Add,
|
||||
|| client.create_empty_handle(),
|
||||
);
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::Scatter(desc.clone())),
|
||||
ScatterOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct EqualElemOps<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualElemOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
|
||||
let output = B::bool_equal_elem(lhs, self.desc.rhs.into());
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&lhs]);
|
||||
|
||||
let client = lhs.client.clone();
|
||||
let rhs = rhs.into();
|
||||
let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, B::BoolElem::dtype(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseBool(BaseOperationIr::EqualElem(desc.clone())),
|
||||
EqualElemOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
||||
mod activation;
|
||||
mod binary;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
mod unary;
|
||||
|
||||
mod base;
|
||||
pub use base::NoOp;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,501 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_backend::{
|
||||
DType, Element, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,
|
||||
ops::QTensorOps,
|
||||
quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},
|
||||
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
use burn_ir::{
|
||||
BaseOperationIr, DequantizeOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer,
|
||||
InitOperationIr, MatmulOpIr, OperationIr, OperationOutput, PermuteOpIr,
|
||||
QuantizationParametersIr, QuantizeOpIr, SelectOpIr, ShapeOpIr, SliceOpIr, SwapDimsOpIr,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Fusion, FusionBackend, get_client,
|
||||
stream::{OperationStreams, execution::Operation},
|
||||
};
|
||||
|
||||
use super::NoOp;
|
||||
|
||||
impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
||||
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
let client = get_client::<B>(device);
|
||||
let dtype = data.dtype;
|
||||
let tensor = B::q_from_data(data, device);
|
||||
let shape = burn_backend::TensorMetadata::shape(&tensor);
|
||||
|
||||
let handle = B::quantized_tensor_handle(tensor);
|
||||
let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
|
||||
|
||||
client
|
||||
.register(
|
||||
OperationStreams::default(),
|
||||
OperationIr::Init(desc),
|
||||
NoOp::<B>::new(),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
tensor: FloatTensor<Self>,
|
||||
scheme: &QuantScheme,
|
||||
qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct QuantizeOp<B: FusionBackend> {
|
||||
desc: QuantizeOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
|
||||
let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);
|
||||
|
||||
let qparams = QuantizationParametersPrimitive { scales };
|
||||
let output = B::quantize(tensor, &self.desc.scheme, qparams);
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor, &qparams.scales]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let qparams = QuantizationParametersIr {
|
||||
scales: qparams.scales.into_ir(),
|
||||
};
|
||||
let desc = QuantizeOpIr::create(tensor.into_ir(), qparams, *scheme, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Float(desc.tensor.dtype, FloatOperationIr::Quantize(desc.clone())),
|
||||
QuantizeOp::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct DequantizeOp<B: FusionBackend> {
|
||||
desc: DequantizeOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_quantized_tensor::<B>(&self.desc.input);
|
||||
|
||||
let output = B::dequantize(tensor);
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let dtype = B::FloatElem::dtype();
|
||||
let desc = DequantizeOpIr::create(tensor.into_ir(), dtype, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())),
|
||||
DequantizeOp::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
|
||||
tensor.client.device().clone()
|
||||
}
|
||||
|
||||
fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
let device_original: &B::Device = tensor.client.device();
|
||||
let device_target: B::Device = device.clone();
|
||||
|
||||
if device_original == &device_target {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
let id = tensor.stream;
|
||||
let client_target = get_client::<B>(&device_target);
|
||||
let client_original = tensor.client.clone();
|
||||
|
||||
client_original.change_client_quantized::<B>(tensor.into_ir(), client_target, id)
|
||||
}
|
||||
|
||||
fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
if tensor.shape == shape {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct ReshapeDimsOps<B: FusionBackend> {
|
||||
desc: ShapeOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_quantized_tensor::<B>(&self.desc.input);
|
||||
let output = B::q_reshape(input, self.desc.out.shape.clone());
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
|
||||
ReshapeDimsOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
tensor.q_into_data::<B>().await
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct SwapDimsOps<B: FusionBackend> {
|
||||
desc: SwapDimsOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_quantized_tensor::<B>(&self.desc.input);
|
||||
let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
|
||||
SwapDimsOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct PermuteDimsOps<B: FusionBackend> {
|
||||
desc: PermuteOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_quantized_tensor::<B>(&self.desc.input);
|
||||
let output = B::q_permute(input, self.desc.axes.as_slice());
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Permute(desc.clone())),
|
||||
PermuteDimsOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct FlipOps<B: FusionBackend> {
|
||||
desc: FlipOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_quantized_tensor::<B>(&self.desc.input);
|
||||
let output = B::q_flip(input, &self.desc.axes);
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Flip(desc.clone())),
|
||||
FlipOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_gather(
|
||||
dim: usize,
|
||||
tensor: QuantizedTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct GatherOps<B: FusionBackend> {
|
||||
desc: GatherOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor::<B>(&self.desc.indices);
|
||||
|
||||
let output = B::q_gather(self.desc.dim, tensor, indices);
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),
|
||||
GatherOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
tensor: QuantizedTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct SelectOps<B: FusionBackend> {
|
||||
desc: SelectOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor::<B>(&self.desc.indices);
|
||||
|
||||
let output = B::q_select(tensor, self.desc.dim, indices);
|
||||
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),
|
||||
SelectOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct SliceOps<B: FusionBackend> {
|
||||
desc: SliceOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
|
||||
|
||||
let output = B::q_slice(tensor, self.desc.ranges.as_slice());
|
||||
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
|
||||
client.create_empty_handle()
|
||||
});
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
|
||||
SliceOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct ExpandOps<B: FusionBackend> {
|
||||
desc: ShapeOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_quantized_tensor::<B>(&self.desc.input);
|
||||
let output = B::q_expand(input, self.desc.out.shape.clone());
|
||||
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&tensor]);
|
||||
|
||||
let client = tensor.client.clone();
|
||||
let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
|
||||
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
|
||||
ExpandOps::<B>::new(desc),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
|
||||
#[derive(new, Debug)]
|
||||
struct MatmulOps<B: FusionBackend> {
|
||||
desc: MatmulOpIr,
|
||||
lhs_quantized: bool,
|
||||
rhs_quantized: bool,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = match self.lhs_quantized {
|
||||
true => {
|
||||
TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))
|
||||
}
|
||||
false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),
|
||||
};
|
||||
let rhs = match self.rhs_quantized {
|
||||
true => {
|
||||
TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))
|
||||
}
|
||||
false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),
|
||||
};
|
||||
let output = B::q_matmul(lhs, rhs);
|
||||
match output {
|
||||
TensorPrimitive::Float(output) => {
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
TensorPrimitive::QFloat(output) => {
|
||||
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut propagation = QuantPropagation::Inhibit;
|
||||
let mut scheme = QuantScheme::default();
|
||||
let mut streams = OperationStreams::default();
|
||||
let mut lhs_quantized = false;
|
||||
let mut rhs_quantized = false;
|
||||
match &lhs {
|
||||
TensorPrimitive::QFloat(lhs) => {
|
||||
propagation = lhs.propagation();
|
||||
scheme = *lhs.scheme();
|
||||
lhs_quantized = true;
|
||||
streams.tensor(lhs);
|
||||
}
|
||||
TensorPrimitive::Float(lhs) => {
|
||||
streams.tensor(lhs);
|
||||
}
|
||||
}
|
||||
match &rhs {
|
||||
TensorPrimitive::QFloat(rhs) => {
|
||||
propagation = rhs.propagation();
|
||||
scheme = *rhs.scheme();
|
||||
rhs_quantized = true;
|
||||
streams.tensor(rhs);
|
||||
}
|
||||
TensorPrimitive::Float(rhs) => {
|
||||
streams.tensor(rhs);
|
||||
}
|
||||
}
|
||||
|
||||
let dtype = match propagation {
|
||||
QuantPropagation::Propagate => DType::QFloat(scheme),
|
||||
QuantPropagation::Inhibit => B::FloatElem::dtype(),
|
||||
};
|
||||
|
||||
let client = match &lhs {
|
||||
TensorPrimitive::Float(lhs) => lhs.client.clone(),
|
||||
TensorPrimitive::QFloat(lhs) => lhs.client.clone(),
|
||||
};
|
||||
|
||||
let lhs = match lhs {
|
||||
TensorPrimitive::Float(lhs) => lhs.into_ir(),
|
||||
TensorPrimitive::QFloat(lhs) => lhs.into_ir(),
|
||||
};
|
||||
let rhs = match rhs {
|
||||
TensorPrimitive::Float(rhs) => rhs.into_ir(),
|
||||
TensorPrimitive::QFloat(rhs) => rhs.into_ir(),
|
||||
};
|
||||
|
||||
let desc = MatmulOpIr::create_mixed(lhs, rhs, dtype, || client.create_empty_handle());
|
||||
|
||||
let out = client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
|
||||
MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),
|
||||
)
|
||||
.output();
|
||||
|
||||
match propagation {
|
||||
QuantPropagation::Propagate => TensorPrimitive::QFloat(out),
|
||||
QuantPropagation::Inhibit => TensorPrimitive::Float(out),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,36 @@
|
||||
use burn_backend::{
|
||||
backend::ExecutionError,
|
||||
ops::{TransactionOps, TransactionPrimitive},
|
||||
};
|
||||
|
||||
use crate::{Fusion, FusionBackend};
|
||||
|
||||
impl<B: FusionBackend> TransactionOps<Fusion<B>> for Fusion<B> {
|
||||
async fn tr_execute(
|
||||
transaction: TransactionPrimitive<Self>,
|
||||
) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
|
||||
B::tr_execute(TransactionPrimitive::new(
|
||||
transaction
|
||||
.read_floats
|
||||
.into_iter()
|
||||
.map(|t| t.client.clone().resolve_tensor_float::<B>(t))
|
||||
.collect(),
|
||||
transaction
|
||||
.read_qfloats
|
||||
.into_iter()
|
||||
.map(|_t| todo!("Quantization not supported yet"))
|
||||
.collect(),
|
||||
transaction
|
||||
.read_ints
|
||||
.into_iter()
|
||||
.map(|t| t.client.clone().resolve_tensor_int::<B>(t))
|
||||
.collect(),
|
||||
transaction
|
||||
.read_bools
|
||||
.into_iter()
|
||||
.map(|t| t.client.clone().resolve_tensor_bool::<B>(t))
|
||||
.collect(),
|
||||
))
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! scalar_float_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.into());
|
||||
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr,
|
||||
noconvert
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs);
|
||||
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! reduce_float_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ReduceDimOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_float_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input, self.desc.axis);
|
||||
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! reduce_float2int_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ReduceDimOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_float_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input, self.desc.axis);
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! reduce_int_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ReduceDimOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_int_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input, self.desc.axis);
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! scalar_float2int_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr,
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.clone());
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! unary_float_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: UnaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_float_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr,
|
||||
reduce
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: UnaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_float_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
||||
handles.register_float_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! unary_int_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: UnaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_int_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr,
|
||||
reduce
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: UnaryOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let input = handles.get_int_tensor::<B>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! scalar_float_cmp_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.into());
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! scalar_int_cmp_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.into());
|
||||
|
||||
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! scalar_int_ops {
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.into());
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$name:ident,
|
||||
$ops:expr,
|
||||
noconvert
|
||||
) => {
|
||||
#[derive(new, Debug)]
|
||||
struct $name<B: FusionBackend> {
|
||||
desc: ScalarOpIr,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
|
||||
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
|
||||
let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs);
|
||||
|
||||
handles.register_int_tensor::<B>(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
use crate::{FuserStatus, NumOperations, OperationFuser, stream::store::ExecutionStrategy};
|
||||
use burn_ir::{OperationIr, TensorId, TensorIr};
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
/// A block represents a list of operations, not necessarily in the same order as the execution
|
||||
/// stream.
|
||||
///
|
||||
/// The start and end position of the relative execution stream are tracked in the block alongside
|
||||
/// the ordering.
|
||||
pub struct Block<O> {
|
||||
builders: Vec<Box<dyn OperationFuser<O>>>,
|
||||
operations: Vec<OperationIr>,
|
||||
ids: HashSet<TensorId>,
|
||||
ordering: Vec<usize>,
|
||||
/// The start position in the relative execution stream.
|
||||
pub start_pos: usize,
|
||||
/// The end position in the relative execution stream.
|
||||
pub end_pos: usize,
|
||||
}
|
||||
|
||||
/// The result of [registering](Block::register) an [operation](OperationIr).
|
||||
pub enum RegistrationResult {
|
||||
/// If the [operation](OperationIr) is correctly registered.
|
||||
Accepted,
|
||||
/// If the [operation](OperationIr) isn't part of the graph.
|
||||
///
|
||||
/// In this case the operation isn't registered.
|
||||
NotPartOfTheGraph,
|
||||
}
|
||||
|
||||
/// The optimization found for a [block](Block).
|
||||
#[derive(Debug, new)]
|
||||
pub struct BlockOptimization<O> {
|
||||
/// The [execution strategy](ExecutionStrategy) to be used to execute the [block](Block).
|
||||
pub strategy: ExecutionStrategy<O>,
|
||||
/// The ordering of each operation in the relative execution stream.
|
||||
pub ordering: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<O: NumOperations> Block<O> {
|
||||
/// Create a new block that will be optimized with the provided [optimization builders](OptimizationBuilder).
|
||||
pub fn new(builders: &[Box<dyn OperationFuser<O>>]) -> Self {
|
||||
Self {
|
||||
builders: builders.iter().map(|o| o.clone_dyn()).collect(),
|
||||
operations: Vec::new(),
|
||||
ids: HashSet::new(),
|
||||
ordering: Vec::new(),
|
||||
start_pos: usize::MAX,
|
||||
end_pos: usize::MIN,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort the [blocks](Block) based on the start position.
|
||||
pub fn sort(blocks: &mut [Self]) {
|
||||
blocks.sort_by(|a, b| a.start_pos.cmp(&b.start_pos));
|
||||
}
|
||||
|
||||
/// Optimize the block.
|
||||
pub fn optimize(mut self) -> BlockOptimization<O> {
|
||||
match find_best_optimization_index(&mut self.builders) {
|
||||
Some(index) => {
|
||||
let opt = self.builders[index].finish();
|
||||
let opt_len = opt.len();
|
||||
if opt_len < self.operations.len() {
|
||||
self.ordering.drain(opt_len..);
|
||||
}
|
||||
|
||||
let strategy = ExecutionStrategy::Optimization {
|
||||
ordering: Arc::new(self.ordering.clone()),
|
||||
opt,
|
||||
};
|
||||
BlockOptimization::new(strategy, self.ordering)
|
||||
}
|
||||
None => {
|
||||
let strategy = ExecutionStrategy::Operations {
|
||||
ordering: Arc::new(self.ordering.clone()),
|
||||
};
|
||||
BlockOptimization::new(strategy, self.ordering)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns if the block contains any of the provided [tensors](TensorIr).
|
||||
pub fn contains_tensors(&self, tensors: &[&TensorIr]) -> bool {
|
||||
for node in tensors {
|
||||
if self.ids.contains(&node.id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Merge the current block with the other one and returns if the operation is successful.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This will modify the current block even if the other block isn't correctly merged.
|
||||
pub fn merge(&mut self, other: &Block<O>) -> bool {
|
||||
for (op, pos) in other.operations.iter().zip(&other.ordering) {
|
||||
self.register(op, *pos, true);
|
||||
}
|
||||
|
||||
// The operation is successful if the current block can still be optimized.
|
||||
self.still_optimizing()
|
||||
}
|
||||
|
||||
/// Register an [operation](OperationIr) in the current block.
|
||||
///
|
||||
/// You need to provide the order of the operation as well as a force flag.
|
||||
///
|
||||
/// When the force flag is true, the builder will always accept the operation, otherwise it
|
||||
/// might refuse it if the operation [isn't part of the graph](RegistrationResult::NotPartOfTheGraph).
|
||||
///
|
||||
/// Forcing is useful to fuse operations that are part of different graphs, but included
|
||||
/// in the same optimization.
|
||||
pub fn register(
|
||||
&mut self,
|
||||
operation: &OperationIr,
|
||||
order: usize,
|
||||
force: bool,
|
||||
) -> RegistrationResult {
|
||||
if self.ids.is_empty() {
|
||||
self.register_op(operation, order);
|
||||
return RegistrationResult::Accepted;
|
||||
}
|
||||
let mut contains = false;
|
||||
for node in operation.nodes() {
|
||||
contains = self.ids.contains(&node.id);
|
||||
|
||||
if contains {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !contains && !force {
|
||||
return RegistrationResult::NotPartOfTheGraph;
|
||||
}
|
||||
|
||||
self.register_op(operation, order);
|
||||
RegistrationResult::Accepted
|
||||
}
|
||||
|
||||
/// If the block can still be optimized further.
|
||||
pub fn still_optimizing(&self) -> bool {
|
||||
let mut num_stopped = 0;
|
||||
|
||||
for optimization in self.builders.iter() {
|
||||
if let FuserStatus::Closed = optimization.status() {
|
||||
num_stopped += 1
|
||||
}
|
||||
}
|
||||
|
||||
num_stopped < self.builders.len()
|
||||
}
|
||||
|
||||
fn register_op(&mut self, operation: &OperationIr, pos: usize) {
|
||||
self.operations.push(operation.clone());
|
||||
self.ordering.push(pos);
|
||||
|
||||
if pos < self.start_pos {
|
||||
self.start_pos = pos;
|
||||
}
|
||||
if pos + 1 > self.end_pos {
|
||||
self.end_pos = pos + 1;
|
||||
}
|
||||
|
||||
for builder in self.builders.iter_mut() {
|
||||
builder.fuse(operation);
|
||||
}
|
||||
|
||||
for node in operation.nodes() {
|
||||
self.ids.insert(node.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> BlockOptimization<O> {
|
||||
/// Maps the ordering of the current block optimization using the given mapping.
|
||||
pub fn map_ordering(&mut self, mapping: &[usize]) {
|
||||
for i in self.ordering.iter_mut() {
|
||||
*i = mapping[*i];
|
||||
}
|
||||
self.strategy.map_ordering(mapping);
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> ExecutionStrategy<O> {
|
||||
/// Maps the ordering of the current execution strategy using the given mapping.
|
||||
pub fn map_ordering(&mut self, mapping: &[usize]) {
|
||||
match self {
|
||||
ExecutionStrategy::Optimization { ordering, .. } => {
|
||||
let mut ordering_mapped = ordering.to_vec();
|
||||
|
||||
for o in ordering_mapped.iter_mut() {
|
||||
*o = mapping[*o];
|
||||
}
|
||||
*ordering = Arc::new(ordering_mapped);
|
||||
}
|
||||
ExecutionStrategy::Operations { ordering } => {
|
||||
let mut ordering_mapped = ordering.to_vec();
|
||||
|
||||
for o in ordering_mapped.iter_mut() {
|
||||
*o = mapping[*o];
|
||||
}
|
||||
|
||||
*ordering = Arc::new(ordering_mapped);
|
||||
}
|
||||
ExecutionStrategy::Composed(items) => {
|
||||
for item in items.iter_mut() {
|
||||
item.map_ordering(mapping);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_best_optimization_index<O>(
|
||||
optimizations: &mut [Box<dyn OperationFuser<O>>],
|
||||
) -> Option<usize> {
|
||||
let mut best_index = None;
|
||||
let mut best_score = 0;
|
||||
|
||||
for (i, optimization) in optimizations.iter().enumerate() {
|
||||
let properties = optimization.properties();
|
||||
|
||||
if properties.ready && properties.score >= best_score {
|
||||
best_index = Some(i);
|
||||
best_score = properties.score;
|
||||
}
|
||||
}
|
||||
|
||||
best_index
|
||||
}
|
||||
|
||||
impl<O> PartialEq for Block<O> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
// Since the ordering can be seen as operation ids, we can use it to compare
|
||||
// blocks.
|
||||
let mut sorted_a = self.ordering.clone();
|
||||
let mut sorted_b = other.ordering.clone();
|
||||
sorted_a.sort();
|
||||
sorted_b.sort();
|
||||
|
||||
sorted_a == sorted_b
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> core::fmt::Debug for Block<O> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"Block {{ pos: [{:?}, {:?}; {:?}] }}",
|
||||
self.start_pos,
|
||||
self.end_pos,
|
||||
self.ordering.len(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> Clone for Block<O> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
builders: self.builders.iter().map(|b| b.clone_dyn()).collect(),
|
||||
operations: self.operations.clone(),
|
||||
ids: self.ids.clone(),
|
||||
ordering: self.ordering.clone(),
|
||||
start_pos: self.start_pos,
|
||||
end_pos: self.end_pos,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,441 @@
|
||||
use super::Block;
|
||||
use crate::NumOperations;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
/// The result of [merging](merge_blocks) [blocks](Block).
|
||||
pub enum MergeBlocksResult<O> {
|
||||
/// All [blocks](Block) merged into one.
|
||||
Full(Block<O>),
|
||||
/// Some [blocks](Block) merged and some failed.
|
||||
Partial {
|
||||
merged: Vec<Block<O>>,
|
||||
failed: Vec<Block<O>>,
|
||||
},
|
||||
/// All [blocks](Block) failed to merge.
|
||||
Fail,
|
||||
}
|
||||
|
||||
/// Merge multiple [block](Block) together.
|
||||
///
|
||||
/// The resulting [blocks](Block) might be sorted if the flag is true, otherwise the order isn't
|
||||
/// guarantee. This is mostly useful for testing.
|
||||
///
|
||||
/// # Strategy
|
||||
///
|
||||
/// The merging strategy is in two steps:
|
||||
///
|
||||
/// 1. The first step is to recursively try to merge adjacent blocks. This has the advantage of
|
||||
/// trying multiple blocks ordering, therefore trying multiple permutation of the blocks.
|
||||
/// However, it has the downside of not trying to merge blocks that are further away in the list
|
||||
/// of blocks. Since trying all combinations possible is exponential, therefore not possible, we
|
||||
/// fallback on the second strategy.
|
||||
/// 2. The second step is to reduce blocks by setting an accumulator block, then sequentially
|
||||
/// trying to merge the remaining blocks. We try some permutations based on the result from
|
||||
/// step1.
|
||||
pub fn merge_blocks<O: NumOperations>(blocks: &[&Block<O>], sorted: bool) -> MergeBlocksResult<O> {
|
||||
if blocks.is_empty() {
|
||||
return MergeBlocksResult::Fail;
|
||||
}
|
||||
|
||||
if blocks.len() == 1 {
|
||||
return MergeBlocksResult::Full(blocks[0].clone());
|
||||
}
|
||||
|
||||
if blocks.len() == 2 {
|
||||
let block0 = blocks[0];
|
||||
let block1 = blocks[1];
|
||||
|
||||
return match merge_two(block0, block1) {
|
||||
Some(result) => MergeBlocksResult::Full(result),
|
||||
None => MergeBlocksResult::Fail,
|
||||
};
|
||||
}
|
||||
|
||||
let mut step1 = merge_blocks_step1(blocks);
|
||||
|
||||
if step1.full.len() == 1 && step1.failed.is_empty() && step1.partial.is_empty() {
|
||||
MergeBlocksResult::Full(step1.full.remove(0))
|
||||
} else if step1.partial.len() == 1 && step1.failed.is_empty() && step1.full.is_empty() {
|
||||
MergeBlocksResult::Full(step1.partial.remove(0))
|
||||
} else {
|
||||
let result = merge_blocks_step2(step1);
|
||||
|
||||
if !sorted {
|
||||
return result;
|
||||
}
|
||||
|
||||
match result {
|
||||
MergeBlocksResult::Full(block) => MergeBlocksResult::Full(block),
|
||||
MergeBlocksResult::Partial {
|
||||
mut merged,
|
||||
mut failed,
|
||||
} => {
|
||||
Block::sort(&mut merged);
|
||||
Block::sort(&mut failed);
|
||||
|
||||
MergeBlocksResult::Partial { merged, failed }
|
||||
}
|
||||
MergeBlocksResult::Fail => MergeBlocksResult::Fail,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MergeBlockStep1<O> {
|
||||
full: Vec<Block<O>>,
|
||||
partial: Vec<Block<O>>,
|
||||
failed: Vec<Block<O>>,
|
||||
}
|
||||
|
||||
impl<O> Default for MergeBlockStep1<O> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
full: Default::default(),
|
||||
partial: Default::default(),
|
||||
failed: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_blocks_step1<O: NumOperations>(blocks: &[&Block<O>]) -> MergeBlockStep1<O> {
|
||||
let step_size = blocks.len() / 2;
|
||||
let num_steps = f32::ceil(blocks.len() as f32 / step_size as f32) as usize;
|
||||
|
||||
let mut result = MergeBlockStep1::default();
|
||||
|
||||
for i in 0..num_steps {
|
||||
let start = i * step_size;
|
||||
let end = usize::min(start + step_size, blocks.len());
|
||||
|
||||
match merge_blocks(&blocks[start..end], false) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
result.full.push(block);
|
||||
}
|
||||
MergeBlocksResult::Partial {
|
||||
mut merged,
|
||||
mut failed,
|
||||
} => {
|
||||
result.partial.append(&mut merged);
|
||||
result.failed.append(&mut failed);
|
||||
}
|
||||
MergeBlocksResult::Fail => {
|
||||
for b in &blocks[start..end] {
|
||||
result.failed.push((*b).clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn merge_blocks_step2<O: NumOperations>(mut step1: MergeBlockStep1<O>) -> MergeBlocksResult<O> {
|
||||
// First let's try to merge partial graphs.
|
||||
if step1.partial.len() > 1 {
|
||||
match merge_accumulator(&step1.partial[0], &step1.partial[1..]) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
step1.partial = vec![block];
|
||||
}
|
||||
MergeBlocksResult::Partial { merged, mut failed } => {
|
||||
step1.partial = merged;
|
||||
step1.failed.append(&mut failed);
|
||||
}
|
||||
MergeBlocksResult::Fail => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Then let's try to merge partial graphs with failed merges.
|
||||
if !step1.failed.is_empty() {
|
||||
step1.partial.append(&mut step1.failed);
|
||||
match merge_accumulator(&step1.partial[0], &step1.partial[1..]) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
step1.partial = vec![block];
|
||||
}
|
||||
MergeBlocksResult::Partial { merged, mut failed } => {
|
||||
step1.partial = merged;
|
||||
step1.failed.append(&mut failed);
|
||||
}
|
||||
MergeBlocksResult::Fail => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Then let's try to merge full graphs.
|
||||
if step1.full.len() > 1 {
|
||||
match merge_accumulator(&step1.full[0], &step1.full[1..]) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
step1.full = vec![block];
|
||||
}
|
||||
MergeBlocksResult::Partial { merged, mut failed } => {
|
||||
step1.full = merged;
|
||||
step1.failed.append(&mut failed);
|
||||
}
|
||||
MergeBlocksResult::Fail => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Then let's try to merge full graphs with failed graphs.
|
||||
if !step1.full.is_empty() {
|
||||
step1.full.append(&mut step1.failed);
|
||||
match merge_accumulator(&step1.full[0], &step1.full[1..]) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
step1.full = vec![block];
|
||||
}
|
||||
MergeBlocksResult::Partial { merged, mut failed } => {
|
||||
step1.full = merged;
|
||||
step1.failed.append(&mut failed);
|
||||
}
|
||||
MergeBlocksResult::Fail => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Then let's try to merge full graphs with partial graphs.
|
||||
if !step1.full.is_empty() || !step1.partial.is_empty() {
|
||||
step1.full.append(&mut step1.partial);
|
||||
match merge_accumulator(&step1.full[0], &step1.full[1..]) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
step1.full = vec![block];
|
||||
}
|
||||
MergeBlocksResult::Partial { merged, mut failed } => {
|
||||
step1.full = merged;
|
||||
step1.failed.append(&mut failed);
|
||||
}
|
||||
MergeBlocksResult::Fail => {
|
||||
// We do nothing.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if step1.full.is_empty() {
|
||||
MergeBlocksResult::Fail
|
||||
} else if step1.failed.is_empty() {
|
||||
if step1.full.len() == 1 {
|
||||
MergeBlocksResult::Full(step1.full.remove(0))
|
||||
} else {
|
||||
MergeBlocksResult::Partial {
|
||||
merged: step1.full,
|
||||
failed: vec![],
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MergeBlocksResult::Partial {
|
||||
merged: step1.full,
|
||||
failed: step1.failed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_accumulator<O: NumOperations>(
|
||||
base: &Block<O>,
|
||||
blocks: &[Block<O>],
|
||||
) -> MergeBlocksResult<O> {
|
||||
let mut base = base.clone();
|
||||
let mut merged_failed = Vec::<Block<O>>::new();
|
||||
let mut merged_success = false;
|
||||
|
||||
for block in blocks {
|
||||
let mut base_current = base.clone();
|
||||
match base_current.merge(block) {
|
||||
false => {
|
||||
merged_failed.push((*block).clone());
|
||||
}
|
||||
true => {
|
||||
merged_success = true;
|
||||
base = base_current;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if merged_success {
|
||||
if merged_failed.is_empty() {
|
||||
MergeBlocksResult::Full(base)
|
||||
} else {
|
||||
MergeBlocksResult::Partial {
|
||||
merged: vec![base],
|
||||
failed: merged_failed,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MergeBlocksResult::Fail
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_two<O: NumOperations>(a: &Block<O>, b: &Block<O>) -> Option<Block<O>> {
|
||||
let mut base = a.clone();
|
||||
|
||||
if base.merge(b) {
|
||||
return Some(base);
|
||||
}
|
||||
|
||||
let mut base = b.clone();
|
||||
|
||||
match base.merge(a) {
|
||||
true => Some(base),
|
||||
false => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
pub use crate::stream::execution::tests::{TestOptimization, TestOptimizationBuilder};
|
||||
use crate::{
|
||||
OperationFuser,
|
||||
stream::tests::{operation_1, operation_2, operation_3},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_no_block() {
|
||||
let actual = merge_blocks::<TestOptimization>(&[], true);
|
||||
|
||||
assert_eq!(actual, MergeBlocksResult::Fail);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_single() {
|
||||
let builders = builders();
|
||||
let block = Block::new(&builders);
|
||||
let actual = merge_blocks::<TestOptimization>(&[&block], true);
|
||||
|
||||
assert_eq!(actual, MergeBlocksResult::Full(block));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_two_blocks() {
|
||||
let builders = builders();
|
||||
let mut block1 = Block::new(&builders);
|
||||
let mut block2 = Block::new(&builders);
|
||||
block1.register(&operation_1(), 0, false);
|
||||
block1.register(&operation_1(), 1, false);
|
||||
block2.register(&operation_1(), 2, false);
|
||||
block2.register(&operation_1(), 3, false);
|
||||
|
||||
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2], true);
|
||||
|
||||
let mut expected = Block::new(&builders);
|
||||
expected.register(&operation_1(), 0, false);
|
||||
expected.register(&operation_1(), 1, false);
|
||||
expected.register(&operation_1(), 2, false);
|
||||
expected.register(&operation_1(), 3, false);
|
||||
|
||||
assert_eq!(actual, MergeBlocksResult::Full(expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_three_blocks() {
|
||||
let builders = builders();
|
||||
let mut block1 = Block::new(&builders);
|
||||
let mut block2 = Block::new(&builders);
|
||||
let mut block3 = Block::new(&builders);
|
||||
block1.register(&operation_1(), 0, false);
|
||||
block2.register(&operation_1(), 1, false);
|
||||
block3.register(&operation_1(), 2, false);
|
||||
|
||||
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3], true);
|
||||
|
||||
let mut expected = Block::new(&builders);
|
||||
expected.register(&operation_1(), 0, false);
|
||||
expected.register(&operation_1(), 1, false);
|
||||
expected.register(&operation_1(), 2, false);
|
||||
|
||||
assert_eq!(actual, MergeBlocksResult::Full(expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_three_blocks_partial() {
|
||||
let builders = builders();
|
||||
let mut block1 = Block::new(&builders);
|
||||
let mut block2 = Block::new(&builders);
|
||||
let mut block3 = Block::new(&builders);
|
||||
block1.register(&operation_1(), 0, false);
|
||||
block2.register(&operation_2(), 1, false);
|
||||
block3.register(&operation_1(), 2, false);
|
||||
|
||||
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3], true);
|
||||
|
||||
let mut expected1 = Block::new(&builders);
|
||||
let mut expected2 = Block::new(&builders);
|
||||
expected1.register(&operation_1(), 0, false);
|
||||
expected1.register(&operation_1(), 2, false);
|
||||
expected2.register(&operation_2(), 1, false);
|
||||
|
||||
assert_eq!(
|
||||
actual,
|
||||
MergeBlocksResult::Partial {
|
||||
merged: vec![expected1, expected2],
|
||||
failed: vec![]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_four_blocks_partial_with_failure() {
|
||||
let builders = builders();
|
||||
let mut block1 = Block::new(&builders);
|
||||
let mut block2 = Block::new(&builders);
|
||||
let mut block3 = Block::new(&builders);
|
||||
let mut block4 = Block::new(&builders);
|
||||
block1.register(&operation_1(), 0, false);
|
||||
block2.register(&operation_2(), 1, false);
|
||||
block3.register(&operation_1(), 2, false);
|
||||
block4.register(&operation_3(), 3, false);
|
||||
|
||||
let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3, &block4], true);
|
||||
|
||||
let mut expected1 = Block::new(&builders);
|
||||
let mut expected2 = Block::new(&builders);
|
||||
let mut failed = Block::new(&builders);
|
||||
expected1.register(&operation_1(), 0, false);
|
||||
expected1.register(&operation_1(), 2, false);
|
||||
expected2.register(&operation_2(), 1, false);
|
||||
failed.register(&operation_3(), 3, false);
|
||||
|
||||
assert_eq!(
|
||||
actual,
|
||||
MergeBlocksResult::Partial {
|
||||
merged: vec![expected1],
|
||||
failed: vec![expected2, failed]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_blocks_five_blocks_partial_with_failure() {
|
||||
let builders = builders();
|
||||
let mut block1 = Block::new(&builders);
|
||||
let mut block2 = Block::new(&builders);
|
||||
let mut block3 = Block::new(&builders);
|
||||
let mut block4 = Block::new(&builders);
|
||||
let mut block5 = Block::new(&builders);
|
||||
block1.register(&operation_1(), 0, false);
|
||||
block2.register(&operation_2(), 1, false);
|
||||
block3.register(&operation_1(), 2, false);
|
||||
block4.register(&operation_3(), 3, false);
|
||||
block5.register(&operation_2(), 4, false);
|
||||
|
||||
let actual =
|
||||
merge_blocks::<TestOptimization>(&[&block1, &block2, &block3, &block4, &block5], true);
|
||||
|
||||
let mut expected1 = Block::new(&builders);
|
||||
let mut expected2 = Block::new(&builders);
|
||||
let mut failed = Block::new(&builders);
|
||||
expected1.register(&operation_1(), 0, false);
|
||||
expected1.register(&operation_1(), 2, false);
|
||||
expected2.register(&operation_2(), 1, false);
|
||||
expected2.register(&operation_2(), 4, false);
|
||||
failed.register(&operation_3(), 3, false);
|
||||
|
||||
assert_eq!(
|
||||
actual,
|
||||
MergeBlocksResult::Partial {
|
||||
merged: vec![expected1, expected2],
|
||||
failed: vec![failed]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
fn builders() -> Vec<Box<dyn OperationFuser<TestOptimization>>> {
|
||||
let builder_1 = TestOptimizationBuilder::new(0, vec![operation_1(); 10]);
|
||||
let builder_2 = TestOptimizationBuilder::new(1, vec![operation_2(); 10]);
|
||||
|
||||
vec![Box::new(builder_1), Box::new(builder_2)]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod block;
|
||||
mod optimization;
|
||||
|
||||
pub(super) mod merging;
|
||||
pub(super) use block::*;
|
||||
|
||||
pub use optimization::*;
|
||||
@@ -0,0 +1,232 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
NumOperations,
|
||||
search::{
|
||||
Block, BlockOptimization,
|
||||
merging::{MergeBlocksResult, merge_blocks},
|
||||
},
|
||||
stream::store::ExecutionStrategy,
|
||||
};
|
||||
|
||||
/// Try to optimize a list of [blocks](Block) into a [block optimization](BlockOptimization).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// What we know here is that every block is independent at that time and can be executed
|
||||
/// in any order.
|
||||
///
|
||||
/// The contract is that the length of operations executed must include all operations. If we don't
|
||||
/// find an optimization that can be executed with that constraint, we return a
|
||||
/// [BlocksOptimizerResult::WithHoles].
|
||||
pub struct BlocksOptimizer<O> {
|
||||
blocks: Vec<Block<O>>,
|
||||
resolved: Vec<bool>,
|
||||
last_checked: usize,
|
||||
}
|
||||
|
||||
/// When we can't find a proper optimization for the provided list of [blocks](Block).
|
||||
pub enum BlocksOptimizerResult<O> {
|
||||
/// When an optimization fill the hole stream.
|
||||
Full(BlockOptimization<O>),
|
||||
/// The optimization found with the holes indices.
|
||||
WithHoles {
|
||||
strategies: Vec<Box<ExecutionStrategy<O>>>,
|
||||
ordering: Vec<usize>,
|
||||
holes: Vec<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
enum BlockOptimizationStep<O> {
|
||||
Contiguous {
|
||||
strategy: ExecutionStrategy<O>,
|
||||
},
|
||||
/// Only happen when we fallback on executing a single operation.
|
||||
Operation {
|
||||
strategy: ExecutionStrategy<O>,
|
||||
},
|
||||
WithHoles {
|
||||
strategy: ExecutionStrategy<O>,
|
||||
holes: Vec<usize>,
|
||||
},
|
||||
Stop,
|
||||
}
|
||||
|
||||
impl<O: NumOperations> BlocksOptimizer<O> {
|
||||
/// Create a new optimizer with the given blocks.
|
||||
pub fn new(blocks: Vec<Block<O>>) -> Self {
|
||||
let num_ops: usize = blocks.iter().map(|g| g.end_pos).max().unwrap();
|
||||
|
||||
Self {
|
||||
blocks,
|
||||
resolved: vec![false; num_ops],
|
||||
last_checked: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Optimizes the blocks.
|
||||
///
|
||||
/// The strategy is quite simple. We try to merge as much [blocks](Block) together as we can,
|
||||
/// then we iterate over them in order composing optimizations with the remaining blocks, all
|
||||
/// while minimizing fallbacks operations to avoid having holes in the optimization stream.
|
||||
pub fn optimize(mut self) -> BlocksOptimizerResult<O> {
|
||||
self = self.merging_pass();
|
||||
|
||||
let mut strategies = Vec::with_capacity(self.blocks.len());
|
||||
let mut ordering = Vec::new();
|
||||
let mut blocks = Vec::new();
|
||||
core::mem::swap(&mut blocks, &mut self.blocks);
|
||||
|
||||
for block in blocks {
|
||||
match self.optimize_block(block, &mut ordering) {
|
||||
BlockOptimizationStep::Contiguous { strategy } => {
|
||||
strategies.push(Box::new(strategy));
|
||||
}
|
||||
BlockOptimizationStep::Operation { strategy } => {
|
||||
strategies.push(Box::new(strategy));
|
||||
break;
|
||||
}
|
||||
BlockOptimizationStep::WithHoles { strategy, holes } => {
|
||||
strategies.push(Box::new(strategy));
|
||||
|
||||
return BlocksOptimizerResult::WithHoles {
|
||||
strategies,
|
||||
ordering,
|
||||
holes,
|
||||
};
|
||||
}
|
||||
BlockOptimizationStep::Stop => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let optimization = match strategies.len() > 1 {
|
||||
true => BlockOptimization {
|
||||
strategy: ExecutionStrategy::Composed(strategies),
|
||||
ordering,
|
||||
},
|
||||
false => BlockOptimization {
|
||||
strategy: *strategies.remove(0),
|
||||
ordering,
|
||||
},
|
||||
};
|
||||
|
||||
BlocksOptimizerResult::Full(optimization)
|
||||
}
|
||||
|
||||
/// Optimize a single block.
|
||||
fn optimize_block(
|
||||
&mut self,
|
||||
block: Block<O>,
|
||||
ordering: &mut Vec<usize>,
|
||||
) -> BlockOptimizationStep<O> {
|
||||
let last_index = block.end_pos;
|
||||
let mut block_optimization = block.optimize();
|
||||
let opt_size = block_optimization.ordering.len();
|
||||
|
||||
for pos in block_optimization.ordering.iter() {
|
||||
self.update_check(*pos);
|
||||
}
|
||||
|
||||
if self.last_checked != ordering.len() + opt_size {
|
||||
if !ordering.is_empty() {
|
||||
// Don't include that block and need further exploring.
|
||||
return BlockOptimizationStep::Stop;
|
||||
}
|
||||
|
||||
return self.optimize_holes(block_optimization, last_index, ordering);
|
||||
}
|
||||
|
||||
ordering.append(&mut block_optimization.ordering);
|
||||
BlockOptimizationStep::Contiguous {
|
||||
strategy: block_optimization.strategy,
|
||||
}
|
||||
}
|
||||
|
||||
/// The provided optimization has holes.
|
||||
fn optimize_holes(
|
||||
&mut self,
|
||||
mut optimization: BlockOptimization<O>,
|
||||
last_index: usize,
|
||||
ordering_global: &mut Vec<usize>,
|
||||
) -> BlockOptimizationStep<O> {
|
||||
match optimization.strategy {
|
||||
ExecutionStrategy::Optimization { opt, ordering } => {
|
||||
ordering_global.append(&mut optimization.ordering);
|
||||
let holes = self.find_holes(last_index);
|
||||
|
||||
if holes.is_empty() {
|
||||
let strategy = ExecutionStrategy::Optimization { opt, ordering };
|
||||
BlockOptimizationStep::Contiguous { strategy }
|
||||
} else {
|
||||
let strategy = ExecutionStrategy::Optimization { opt, ordering };
|
||||
BlockOptimizationStep::WithHoles { strategy, holes }
|
||||
}
|
||||
}
|
||||
ExecutionStrategy::Operations { ordering } => {
|
||||
let min = ordering.iter().min().unwrap();
|
||||
ordering_global.push(*min);
|
||||
|
||||
let strategy = ExecutionStrategy::Operations {
|
||||
ordering: Arc::new(vec![*min]),
|
||||
};
|
||||
BlockOptimizationStep::Operation { strategy }
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_check(&mut self, pos: usize) {
|
||||
self.resolved[pos] = true;
|
||||
|
||||
for i in self.last_checked..self.resolved.len() {
|
||||
if self.resolved[i] {
|
||||
self.last_checked += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_holes(&mut self, last: usize) -> Vec<usize> {
|
||||
let mut fallbacks = Vec::new();
|
||||
|
||||
for i in self.last_checked..last {
|
||||
if !self.resolved[i] {
|
||||
fallbacks.push(i);
|
||||
self.resolved[i] = true;
|
||||
}
|
||||
self.last_checked += 1;
|
||||
}
|
||||
|
||||
fallbacks
|
||||
}
|
||||
|
||||
/// Try to merge blocks together.
|
||||
fn merging_pass(mut self) -> Self {
|
||||
if self.blocks.len() == 1 {
|
||||
return self;
|
||||
}
|
||||
|
||||
Block::sort(&mut self.blocks);
|
||||
let blocks = self.blocks.iter().collect::<Vec<_>>();
|
||||
|
||||
match merge_blocks(&blocks, false) {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
self.blocks = vec![block];
|
||||
}
|
||||
MergeBlocksResult::Partial {
|
||||
mut merged,
|
||||
mut failed,
|
||||
} => {
|
||||
merged.append(&mut failed);
|
||||
self.blocks = merged;
|
||||
Block::sort(&mut self.blocks);
|
||||
}
|
||||
MergeBlocksResult::Fail => {}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
mod blocks;
|
||||
mod stream;
|
||||
|
||||
pub use stream::*;
|
||||
@@ -0,0 +1,277 @@
|
||||
use super::blocks::BlocksOptimizer;
|
||||
use crate::{
|
||||
NumOperations, OperationFuser,
|
||||
search::{
|
||||
Block, BlockOptimization, RegistrationResult,
|
||||
merging::{MergeBlocksResult, merge_blocks},
|
||||
optimization::blocks::BlocksOptimizerResult,
|
||||
},
|
||||
stream::store::ExecutionStrategy,
|
||||
};
|
||||
use burn_ir::OperationIr;
|
||||
|
||||
/// Optimize a stream of [operations](OperationIr) using a list of [builders](OptimizationBuilder).
|
||||
pub struct StreamOptimizer<O> {
|
||||
builders: Vec<Box<dyn OperationFuser<O>>>,
|
||||
blocks: Vec<Block<O>>,
|
||||
length: usize,
|
||||
stopped: bool,
|
||||
max_blocks: Option<usize>,
|
||||
}
|
||||
|
||||
impl<O: NumOperations> StreamOptimizer<O> {
|
||||
/// Create a new stream optimizer.
|
||||
pub fn new(builders: Vec<Box<dyn OperationFuser<O>>>) -> Self {
|
||||
Self {
|
||||
builders,
|
||||
blocks: Vec::new(),
|
||||
length: 0,
|
||||
stopped: false,
|
||||
// Too high and it may breaks the fusion cache always retriggering explorations.
|
||||
max_blocks: Some(5),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new [operation](OperationIr) in the optimizer.
|
||||
///
|
||||
/// You can use the function [Self::still_optimizing] to know if the operations are actually
|
||||
/// being registered.
|
||||
pub fn register(&mut self, operation: &OperationIr) {
|
||||
if self.stopped {
|
||||
return;
|
||||
}
|
||||
|
||||
if self.blocks.is_empty() {
|
||||
self.on_new_block(operation);
|
||||
self.length += 1;
|
||||
return;
|
||||
}
|
||||
|
||||
match self.merge_blocks(operation, false) {
|
||||
MergeBlockStep::Full | MergeBlockStep::NoNeed => {}
|
||||
MergeBlockStep::Fail | MergeBlockStep::Partial => {
|
||||
// With the given operation, blocks are no longer independent.
|
||||
self.stopped = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_blocks) = self.max_blocks {
|
||||
if self.register_max_block(operation, max_blocks) {
|
||||
self.length += 1;
|
||||
} else {
|
||||
self.stopped = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let added_count = self.register_inner(operation, false);
|
||||
if added_count == 0 {
|
||||
self.on_new_block(operation);
|
||||
}
|
||||
|
||||
self.length += 1;
|
||||
}
|
||||
|
||||
/// Optimize the current stream on the given [operations](OperationIr).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The operations provided are the same as the ones used in the [register](Self::register)
|
||||
/// method, this simply remove the need for the current type to also keep track of the list of
|
||||
/// operations.
|
||||
pub fn optimize(&self, operations: &[OperationIr]) -> BlockOptimization<O> {
|
||||
let result = BlocksOptimizer::new(self.blocks.clone()).optimize();
|
||||
|
||||
match result {
|
||||
BlocksOptimizerResult::Full(block_optimization) => block_optimization,
|
||||
BlocksOptimizerResult::WithHoles {
|
||||
mut strategies,
|
||||
mut ordering,
|
||||
mut holes,
|
||||
} => {
|
||||
loop {
|
||||
let mut search = self.new_empty_search();
|
||||
|
||||
let mut operations_holes = Vec::with_capacity(holes.len());
|
||||
|
||||
for index in holes.iter() {
|
||||
let op = &operations[*index];
|
||||
operations_holes.push(op.clone());
|
||||
search.register(op);
|
||||
}
|
||||
|
||||
let mut optimization_of_holes = search.optimize(&operations_holes);
|
||||
|
||||
optimization_of_holes.map_ordering(&holes);
|
||||
|
||||
strategies.push(Box::new(optimization_of_holes.strategy));
|
||||
holes.drain(0..optimization_of_holes.ordering.len());
|
||||
ordering.append(&mut optimization_of_holes.ordering);
|
||||
|
||||
if holes.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
BlockOptimization::new(ExecutionStrategy::Composed(strategies), ordering)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the state of the optimizer.
|
||||
pub fn reset(&mut self) {
|
||||
self.builders.iter_mut().for_each(|b| b.reset());
|
||||
self.length = 0;
|
||||
self.blocks.clear();
|
||||
self.stopped = false;
|
||||
}
|
||||
|
||||
/// Returns if some optimizations are still possible within the stream.
|
||||
pub fn still_optimizing(&self) -> bool {
|
||||
if self.stopped {
|
||||
return false;
|
||||
}
|
||||
if self.blocks.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut num_stopped = 0;
|
||||
|
||||
for block in self.blocks.iter() {
|
||||
if !block.still_optimizing() {
|
||||
num_stopped += 1
|
||||
}
|
||||
}
|
||||
|
||||
num_stopped < self.blocks.len()
|
||||
}
|
||||
|
||||
fn register_max_block(&mut self, operation: &OperationIr, max_blocks: usize) -> bool {
|
||||
if max_blocks == 1 {
|
||||
// Register in the single block with a force.
|
||||
self.register_inner(operation, true);
|
||||
return true;
|
||||
}
|
||||
let added_count = self.register_inner(operation, false);
|
||||
|
||||
if added_count > 0 {
|
||||
return true;
|
||||
}
|
||||
|
||||
if added_count == 0 && self.blocks.len() < max_blocks {
|
||||
self.on_new_block(operation);
|
||||
return true;
|
||||
}
|
||||
|
||||
self.merge_blocks(operation, true);
|
||||
|
||||
if self.blocks.len() >= max_blocks {
|
||||
self.stopped = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
let added_count = self.register_inner(operation, false);
|
||||
|
||||
if added_count == 0 {
|
||||
self.on_new_block(operation);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn register_inner(&mut self, operation: &OperationIr, force: bool) -> usize {
|
||||
let mut added_count = 0;
|
||||
for block in self.blocks.iter_mut() {
|
||||
match block.register(operation, self.length, force) {
|
||||
RegistrationResult::Accepted => {
|
||||
added_count += 1;
|
||||
}
|
||||
RegistrationResult::NotPartOfTheGraph => {}
|
||||
}
|
||||
}
|
||||
added_count
|
||||
}
|
||||
|
||||
fn new_empty_search(&self) -> Self {
|
||||
Self::new(
|
||||
self.builders
|
||||
.iter()
|
||||
.map(|b| {
|
||||
let mut b = b.clone_dyn();
|
||||
b.reset();
|
||||
b
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn merge_blocks(&mut self, operation: &OperationIr, all: bool) -> MergeBlockStep {
|
||||
let nodes = operation.nodes();
|
||||
let mut block_merges = Vec::new();
|
||||
|
||||
for (i, block) in self.blocks.iter().enumerate() {
|
||||
if all || block.contains_tensors(&nodes) {
|
||||
block_merges.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
if block_merges.len() <= 1 {
|
||||
return MergeBlockStep::NoNeed;
|
||||
}
|
||||
|
||||
let blocks_to_merge = self
|
||||
.blocks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, g)| match block_merges.contains(&i) {
|
||||
true => Some(g),
|
||||
false => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let merged = merge_blocks(&blocks_to_merge, false);
|
||||
|
||||
let mut clear_blocks = || {
|
||||
let mut indices = block_merges.to_vec();
|
||||
indices.sort();
|
||||
|
||||
for g in indices.into_iter().rev() {
|
||||
self.blocks.remove(g);
|
||||
}
|
||||
};
|
||||
|
||||
match merged {
|
||||
MergeBlocksResult::Full(block) => {
|
||||
clear_blocks();
|
||||
self.blocks.push(block);
|
||||
Block::sort(&mut self.blocks);
|
||||
MergeBlockStep::Full
|
||||
}
|
||||
MergeBlocksResult::Partial {
|
||||
mut merged,
|
||||
mut failed,
|
||||
} => {
|
||||
clear_blocks();
|
||||
self.blocks.append(&mut merged);
|
||||
self.blocks.append(&mut failed);
|
||||
Block::sort(&mut self.blocks);
|
||||
MergeBlockStep::Partial
|
||||
}
|
||||
MergeBlocksResult::Fail => MergeBlockStep::Fail,
|
||||
}
|
||||
}
|
||||
|
||||
fn on_new_block(&mut self, operation: &OperationIr) {
|
||||
let mut block = Block::new(&self.builders);
|
||||
block.register(operation, self.length, true);
|
||||
self.blocks.push(block);
|
||||
}
|
||||
}
|
||||
|
||||
enum MergeBlockStep {
|
||||
Full,
|
||||
Partial,
|
||||
Fail,
|
||||
NoNeed,
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
FusionBackend, FusionRuntime,
|
||||
stream::{MultiStream, OperationStreams, StreamId, execution::Operation},
|
||||
};
|
||||
use burn_backend::{TensorData, backend::ExecutionError};
|
||||
use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr};
|
||||
|
||||
pub struct FusionServer<R: FusionRuntime> {
|
||||
streams: MultiStream<R>,
|
||||
pub(crate) handles: HandleContainer<R::FusionHandle>,
|
||||
}
|
||||
|
||||
impl<R> FusionServer<R>
|
||||
where
|
||||
R: FusionRuntime,
|
||||
{
|
||||
pub fn new(device: R::FusionDevice) -> Self {
|
||||
Self {
|
||||
streams: MultiStream::new(device.clone()),
|
||||
handles: HandleContainer::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(
|
||||
&mut self,
|
||||
streams: OperationStreams,
|
||||
repr: OperationIr,
|
||||
operation: Arc<dyn Operation<R>>,
|
||||
) {
|
||||
self.streams
|
||||
.register(streams, repr, operation, &mut self.handles)
|
||||
}
|
||||
|
||||
pub fn drain_stream(&mut self, id: StreamId) {
|
||||
self.streams.drain(&mut self.handles, id)
|
||||
}
|
||||
|
||||
pub fn create_empty_handle(&mut self) -> TensorId {
|
||||
self.handles.create_tensor_uninit()
|
||||
}
|
||||
|
||||
pub fn read_float<B>(
|
||||
&mut self,
|
||||
tensor: TensorIr,
|
||||
id: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
// Make sure all registered operations are executed.
|
||||
// The underlying backend can still be async.
|
||||
self.drain_stream(id);
|
||||
let tensor_float = self.handles.get_float_tensor::<B>(&tensor);
|
||||
self.streams.mark_read(id, &tensor, &self.handles);
|
||||
B::float_into_data(tensor_float)
|
||||
}
|
||||
|
||||
pub fn read_int<B>(
|
||||
&mut self,
|
||||
tensor: TensorIr,
|
||||
id: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
// Make sure all registered operations are executed.
|
||||
// The underlying backend can still be async.
|
||||
self.drain_stream(id);
|
||||
let tensor_int = self.handles.get_int_tensor::<B>(&tensor);
|
||||
self.streams.mark_read(id, &tensor, &self.handles);
|
||||
B::int_into_data(tensor_int)
|
||||
}
|
||||
|
||||
pub fn read_bool<B>(
|
||||
&mut self,
|
||||
tensor: TensorIr,
|
||||
id: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
// Make sure all registered operations are executed.
|
||||
// The underlying backend can still be async.
|
||||
self.drain_stream(id);
|
||||
let tensor_bool = self.handles.get_bool_tensor::<B>(&tensor);
|
||||
self.streams.mark_read(id, &tensor, &self.handles);
|
||||
B::bool_into_data(tensor_bool)
|
||||
}
|
||||
|
||||
pub fn read_quantized<B>(
|
||||
&mut self,
|
||||
tensor: TensorIr,
|
||||
id: StreamId,
|
||||
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
// Make sure all registered operations are executed.
|
||||
// The underlying backend can still be async.
|
||||
self.drain_stream(id);
|
||||
let tensor_q = self.handles.get_quantized_tensor::<B>(&tensor);
|
||||
self.streams.mark_read(id, &tensor, &self.handles);
|
||||
B::q_into_data(tensor_q)
|
||||
}
|
||||
|
||||
pub fn change_server_float<B>(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
stream_tensor: StreamId,
|
||||
device: &R::FusionDevice,
|
||||
server_device: &mut Self,
|
||||
) -> TensorId
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let tensor_float = self.handles.get_float_tensor::<B>(tensor);
|
||||
self.streams.mark_read(stream_tensor, tensor, &self.handles);
|
||||
|
||||
let tensor = B::float_to_device(tensor_float, device);
|
||||
let id = server_device.create_empty_handle();
|
||||
|
||||
server_device
|
||||
.handles
|
||||
.register_float_tensor::<B>(&id, tensor.clone());
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn resolve_server_float<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.handles.get_float_tensor::<B>(tensor)
|
||||
}
|
||||
|
||||
pub fn resolve_server_int<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.handles.get_int_tensor::<B>(tensor)
|
||||
}
|
||||
|
||||
pub fn resolve_server_bool<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
self.handles.get_bool_tensor::<B>(tensor)
|
||||
}
|
||||
|
||||
pub fn change_server_int<B>(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
stream_tensor: StreamId,
|
||||
device: &R::FusionDevice,
|
||||
server_device: &mut Self,
|
||||
) -> TensorId
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let tensor_int = self.handles.get_int_tensor::<B>(tensor);
|
||||
self.streams.mark_read(stream_tensor, tensor, &self.handles);
|
||||
let tensor = B::int_to_device(tensor_int, device);
|
||||
let id = server_device.create_empty_handle();
|
||||
|
||||
server_device
|
||||
.handles
|
||||
.register_int_tensor::<B>(&id, tensor.clone());
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn change_server_bool<B>(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
stream_tensor: StreamId,
|
||||
device: &R::FusionDevice,
|
||||
server_device: &mut Self,
|
||||
) -> TensorId
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let tensor_bool = self.handles.get_bool_tensor::<B>(tensor);
|
||||
self.streams.mark_read(stream_tensor, tensor, &self.handles);
|
||||
let tensor = B::bool_to_device(tensor_bool, device);
|
||||
let id = server_device.create_empty_handle();
|
||||
|
||||
server_device
|
||||
.handles
|
||||
.register_bool_tensor::<B>(&id, tensor.clone());
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn change_server_quantized<B>(
|
||||
&mut self,
|
||||
tensor: &TensorIr,
|
||||
device: &R::FusionDevice,
|
||||
server_device: &mut Self,
|
||||
) -> TensorId
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let tensor = self.handles.get_quantized_tensor::<B>(tensor);
|
||||
let tensor = B::q_to_device(tensor, device);
|
||||
let id = server_device.create_empty_handle();
|
||||
|
||||
server_device
|
||||
.handles
|
||||
.register_quantized_tensor::<B>(&id, tensor);
|
||||
|
||||
id
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
pub use burn_backend::StreamId;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
use burn_ir::HandleContainer;
|
||||
|
||||
use crate::FusionRuntime;
|
||||
|
||||
/// The mode in which the execution is done.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) enum ExecutionMode {
|
||||
Lazy,
|
||||
Sync,
|
||||
}
|
||||
|
||||
/// General trait to abstract how a single operation is executed.
|
||||
pub trait Operation<R: FusionRuntime>: Send + Sync + core::fmt::Debug {
|
||||
/// Execute the operation.
|
||||
fn execute(&self, handles: &mut HandleContainer<R::FusionHandle>);
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
use burn_ir::OperationIr;
|
||||
|
||||
use super::ExecutionMode;
|
||||
use crate::{
|
||||
NumOperations, OperationFuser,
|
||||
search::{BlockOptimization, StreamOptimizer},
|
||||
};
|
||||
|
||||
/// Explore and create new optimization.
|
||||
pub struct Explorer<O> {
|
||||
optimizer: StreamOptimizer<O>,
|
||||
num_deferred: usize,
|
||||
num_explored: usize,
|
||||
is_still_optimizing: bool,
|
||||
}
|
||||
|
||||
/// The result of an exploration done by the [explorer](Explorer).
|
||||
pub enum ExplorationAction<O> {
|
||||
/// Found a new optimization.
|
||||
Completed(BlockOptimization<O>),
|
||||
/// We should continue exploring before arriving at a conclusion.
|
||||
Continue,
|
||||
}
|
||||
|
||||
impl<O: NumOperations> Explorer<O> {
|
||||
/// Create a new explorer.
|
||||
pub(crate) fn new(optimizations: Vec<Box<dyn OperationFuser<O>>>) -> Self {
|
||||
Self {
|
||||
optimizer: StreamOptimizer::new(optimizations),
|
||||
num_deferred: 0,
|
||||
num_explored: 0,
|
||||
is_still_optimizing: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicate that a new operation is added.
|
||||
pub(crate) fn on_new_operation(&mut self) {
|
||||
self.num_deferred += 1;
|
||||
}
|
||||
|
||||
/// If the explorer is up to date.
|
||||
pub(crate) fn is_up_to_date(&self) -> bool {
|
||||
self.num_deferred == 0
|
||||
}
|
||||
|
||||
/// Explore the provided operations.
|
||||
pub(crate) fn explore(
|
||||
&mut self,
|
||||
operations: &[OperationIr],
|
||||
mode: ExecutionMode,
|
||||
) -> ExplorationAction<O> {
|
||||
self.update(operations);
|
||||
|
||||
// Can only continue exploration when not sync.
|
||||
if let ExecutionMode::Lazy = mode
|
||||
&& self.is_still_optimizing
|
||||
{
|
||||
return ExplorationAction::Continue;
|
||||
}
|
||||
|
||||
let optimization = self.optimizer.optimize(operations);
|
||||
|
||||
ExplorationAction::Completed(optimization)
|
||||
}
|
||||
|
||||
/// Reset the state of the explorer to the provided list of operations.
|
||||
pub(crate) fn reset(&mut self, operations: &[OperationIr]) {
|
||||
self.optimizer.reset();
|
||||
self.num_explored = 0;
|
||||
self.num_deferred = operations.len();
|
||||
self.is_still_optimizing = true;
|
||||
}
|
||||
|
||||
/// Register any operations that we had deferred
|
||||
fn update(&mut self, operations: &[OperationIr]) {
|
||||
for i in (0..self.num_deferred).rev() {
|
||||
if !self.is_still_optimizing {
|
||||
break;
|
||||
}
|
||||
let index = operations.len() - 1 - i;
|
||||
let relative = &operations[index];
|
||||
|
||||
self.optimizer.register(relative);
|
||||
self.num_explored += 1;
|
||||
|
||||
self.is_still_optimizing = self.optimizer.still_optimizing();
|
||||
}
|
||||
|
||||
self.num_deferred = 0;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
pub(crate) mod validator;
|
||||
|
||||
mod base;
|
||||
mod explorer;
|
||||
mod ordering;
|
||||
mod policy;
|
||||
mod processor;
|
||||
|
||||
pub use base::*;
|
||||
pub use ordering::*;
|
||||
|
||||
pub(crate) use explorer::*;
|
||||
pub(crate) use policy::*;
|
||||
pub(crate) use processor::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests;
|
||||
@@ -0,0 +1,71 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use burn_ir::HandleContainer;
|
||||
|
||||
use crate::{FusionRuntime, NumOperations, Optimization, stream::Context};
|
||||
|
||||
use super::Operation;
|
||||
|
||||
/// Manage the execution of potentially multiple optimizations and operations out of order.
|
||||
pub struct OrderedExecution<R: FusionRuntime> {
|
||||
operations: Vec<Arc<dyn Operation<R>>>,
|
||||
num_executed: usize,
|
||||
ordering: Option<Arc<Vec<usize>>>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> OrderedExecution<R> {
|
||||
/// Returns the operation that can be executed without impacting the state of the execution.
|
||||
///
|
||||
/// This is useful to implement fallback for optimizations.
|
||||
#[allow(clippy::borrowed_box)]
|
||||
pub fn operation_within_optimization(&self, index: usize) -> Arc<dyn Operation<R>> {
|
||||
match &self.ordering {
|
||||
Some(val) => {
|
||||
let index = val[index];
|
||||
self.operations[index].clone()
|
||||
}
|
||||
None => panic!("No ordering provided"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(operations: Vec<Arc<dyn Operation<R>>>) -> Self {
|
||||
Self {
|
||||
operations,
|
||||
num_executed: 0,
|
||||
ordering: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn finish(mut self) -> (Vec<Arc<dyn Operation<R>>>, usize) {
|
||||
self.operations.drain(0..self.num_executed);
|
||||
(self.operations, self.num_executed)
|
||||
}
|
||||
|
||||
pub(crate) fn execute_optimization(
|
||||
&mut self,
|
||||
optimization: &mut R::Optimization,
|
||||
context: &mut Context<'_, R::FusionHandle>,
|
||||
ordering: Arc<Vec<usize>>,
|
||||
) {
|
||||
if ordering.len() > self.operations.len() {
|
||||
panic!("Ordering is bigger than operations");
|
||||
}
|
||||
self.ordering = Some(ordering);
|
||||
let num_drained = optimization.len();
|
||||
optimization.execute(context, self);
|
||||
self.num_executed += num_drained;
|
||||
}
|
||||
|
||||
pub(crate) fn execute_operations(
|
||||
&mut self,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
ordering: &[usize],
|
||||
) {
|
||||
self.num_executed += ordering.len();
|
||||
|
||||
for id in ordering {
|
||||
let op = &self.operations[*id];
|
||||
op.execute(handles);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,572 @@
|
||||
use burn_ir::OperationIr;
|
||||
|
||||
use super::ExecutionMode;
|
||||
use super::validator::{
|
||||
ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator,
|
||||
ValidatorState,
|
||||
};
|
||||
use crate::stream::execution::validator::OperationsValidator;
|
||||
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// The policy keeps track of all possible execution plans for the current operations.
|
||||
///
|
||||
/// # Details
|
||||
///
|
||||
/// We keep track of each new operation added and invalidate potential execution plans
|
||||
/// when we see a different operation is added.
|
||||
///
|
||||
/// Therefore, the overhead is very minimal, since the time-complexity of checking for existing
|
||||
/// execution plans scales with the number of concurrent potential plans for the current operations,
|
||||
/// which isn't supposed to be big at any time.
|
||||
pub(crate) struct Policy<O> {
|
||||
/// List of potential execution plans that are compatible with current stream segment
|
||||
candidates: Vec<OperationsValidator<ExecutionPlanId>>,
|
||||
/// List of candidate execution plans that have been found; we can still keep searching
|
||||
/// to potentially find a better one.
|
||||
availables: Vec<AvailableItem>,
|
||||
/// The found execution plan that should be executed, along with the number of operations
|
||||
/// in the plan.
|
||||
found: Option<(ExecutionPlanId, usize)>,
|
||||
/// The number of operations that have been analyzed
|
||||
num_operations: usize,
|
||||
_item_type: PhantomData<O>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct AvailableItem {
|
||||
id: ExecutionPlanId,
|
||||
size: usize,
|
||||
triggers: Vec<TriggerValidator>,
|
||||
}
|
||||
|
||||
/// Action to be made depending on the stream.
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub enum Action {
|
||||
/// Continue exploring using the [builder](crate::OptimizationBuilder).
|
||||
Explore,
|
||||
/// The current policy indicates that an exploration may be possible in the future, so the
|
||||
/// best action is to defer any execution.
|
||||
///
|
||||
/// Sometimes, it can be a false positive and a new exploration should be built from scratch.
|
||||
/// Therefore it's important to keep the previous operations to rebuild the state if it
|
||||
/// happens.
|
||||
Defer,
|
||||
/// An exploration has been found, and the best action is to execute it!
|
||||
Execute(ExecutionPlanId),
|
||||
}
|
||||
|
||||
impl<O: core::fmt::Debug> Policy<O> {
|
||||
/// Create a new policy.
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
candidates: Vec::new(),
|
||||
availables: Vec::new(),
|
||||
found: None,
|
||||
num_operations: 0,
|
||||
_item_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [action](Action) that should be taken given the state of the policy.
|
||||
pub fn action(
|
||||
&self,
|
||||
store: &ExecutionPlanStore<O>,
|
||||
operations: &[OperationIr],
|
||||
mode: ExecutionMode,
|
||||
) -> Action {
|
||||
if self.num_operations < operations.len() {
|
||||
panic!(
|
||||
"Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed."
|
||||
);
|
||||
}
|
||||
|
||||
if let Some((id, _length)) = self.found {
|
||||
return Action::Execute(id);
|
||||
}
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Lazy => self.action_lazy(operations),
|
||||
ExecutionMode::Sync => self.action_sync(operations, store),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the policy state.
|
||||
pub fn update(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {
|
||||
// reset the candidates to contain all execution plans starting with the operation.
|
||||
if self.num_operations == 0 {
|
||||
self.candidates = store
|
||||
.find(SearchQuery::PlansStartingWith(operation))
|
||||
.into_iter()
|
||||
.map(OperationsValidator::new)
|
||||
.collect();
|
||||
}
|
||||
|
||||
self.update_candidates(store, operation);
|
||||
self.check_candidates(store);
|
||||
|
||||
self.update_availables(store, operation);
|
||||
self.check_availables();
|
||||
self.num_operations += 1;
|
||||
}
|
||||
|
||||
// Reset the state of the policy.
|
||||
pub fn reset(&mut self) {
|
||||
self.candidates.clear();
|
||||
self.availables.clear();
|
||||
|
||||
self.num_operations = 0;
|
||||
self.found = None;
|
||||
}
|
||||
|
||||
/// Check which candidates can be removed, and which one can go from
|
||||
/// 'candidate' to 'available'
|
||||
fn check_candidates(&mut self, store: &ExecutionPlanStore<O>) {
|
||||
let mut candidates_to_remove = Vec::new();
|
||||
|
||||
for candidate in self.candidates.iter() {
|
||||
match candidate.state {
|
||||
ValidatorState::Found { size } => {
|
||||
let item = store.get_unchecked(candidate.id);
|
||||
let mut triggers = Vec::with_capacity(item.triggers.len());
|
||||
|
||||
for (index, trigger) in item.triggers.iter().enumerate() {
|
||||
triggers.push(match trigger {
|
||||
ExecutionTrigger::OnOperations(_) => TriggerValidator::OnOperations {
|
||||
matching: OperationsValidator::new(index),
|
||||
progress: TriggerProgress::NotInit,
|
||||
},
|
||||
ExecutionTrigger::OnSync => TriggerValidator::OnSync,
|
||||
ExecutionTrigger::Always => TriggerValidator::Always,
|
||||
});
|
||||
}
|
||||
|
||||
self.availables
|
||||
.push(AvailableItem::new(candidate.id, size, triggers));
|
||||
candidates_to_remove.push(candidate.id);
|
||||
}
|
||||
ValidatorState::Invalidated => {
|
||||
candidates_to_remove.push(candidate.id);
|
||||
}
|
||||
ValidatorState::Validating => {}
|
||||
};
|
||||
}
|
||||
|
||||
let mut updated_candidates = Vec::new();
|
||||
core::mem::swap(&mut updated_candidates, &mut self.candidates);
|
||||
|
||||
self.candidates = updated_candidates
|
||||
.into_iter()
|
||||
.filter(|candidate| !candidates_to_remove.iter().any(|id| id == &candidate.id))
|
||||
.collect();
|
||||
}
|
||||
|
||||
fn check_availables(&mut self) {
|
||||
for available in self.availables.iter() {
|
||||
for trigger in available.triggers.iter() {
|
||||
match trigger {
|
||||
TriggerValidator::OnOperations {
|
||||
matching,
|
||||
progress: _,
|
||||
} => {
|
||||
if let ValidatorState::Found {
|
||||
size: _size_of_trigger,
|
||||
} = matching.state
|
||||
{
|
||||
self.found = Some((available.id, available.size));
|
||||
return;
|
||||
}
|
||||
}
|
||||
TriggerValidator::Always => {
|
||||
self.found = Some((available.id, available.size));
|
||||
return;
|
||||
}
|
||||
TriggerValidator::OnSync => {
|
||||
// Does nothing during an update.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_candidates(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {
|
||||
let main_store = ExecutionPlanOperationsStore::new(store);
|
||||
|
||||
self.candidates
|
||||
.iter_mut()
|
||||
.for_each(|candidate| candidate.update(operation, self.num_operations, &main_store));
|
||||
}
|
||||
|
||||
fn update_availables(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {
|
||||
self.availables.iter_mut().for_each(|available| {
|
||||
let store_trigger = TriggerOperationsStore::new(available.id, store);
|
||||
|
||||
available.triggers.iter_mut().for_each(|trigger| {
|
||||
if let TriggerValidator::OnOperations { matching, progress } = trigger {
|
||||
match progress {
|
||||
TriggerProgress::NotInit => {
|
||||
*progress = TriggerProgress::NumChecked(0);
|
||||
}
|
||||
TriggerProgress::NumChecked(num_check) => {
|
||||
matching.update(operation, *num_check, &store_trigger);
|
||||
*num_check += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn action_lazy(&self, operations: &[OperationIr]) -> Action {
|
||||
if !self.candidates.is_empty() {
|
||||
return Action::Defer;
|
||||
}
|
||||
|
||||
for available in self.availables.iter() {
|
||||
if available.size == operations.len() {
|
||||
return Action::Defer;
|
||||
}
|
||||
|
||||
for trigger in available.triggers.iter() {
|
||||
if let TriggerValidator::OnOperations {
|
||||
matching,
|
||||
progress: _,
|
||||
} = trigger
|
||||
&& let ValidatorState::Validating = matching.state
|
||||
{
|
||||
return Action::Defer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Action::Explore
|
||||
}
|
||||
|
||||
fn action_sync(&self, operations: &[OperationIr], store: &ExecutionPlanStore<O>) -> Action {
|
||||
for available in self.availables.iter() {
|
||||
if available.size == operations.len() {
|
||||
return Action::Execute(available.id);
|
||||
}
|
||||
}
|
||||
|
||||
for candidate in self.candidates.iter() {
|
||||
let item = store.get_unchecked(candidate.id);
|
||||
|
||||
if item.operations.len() == operations.len() {
|
||||
return Action::Execute(candidate.id);
|
||||
}
|
||||
}
|
||||
|
||||
Action::Explore
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn_backend::{DType, Shape};
|
||||
use burn_ir::{FloatOperationIr, TensorId, TensorIr, TensorStatus, UnaryOpIr};
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
search::BlockOptimization,
|
||||
stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},
|
||||
};
|
||||
use std::ops::Range;
|
||||
|
||||
#[test]
|
||||
fn given_no_optimization_should_explore() {
|
||||
let store = ExecutionPlanStore::default();
|
||||
let mut policy = Policy::new();
|
||||
let stream = TestStream::new(3);
|
||||
|
||||
stream.assert_updates(
|
||||
&store,
|
||||
&mut policy,
|
||||
AssertUpdatesOptions::OperationsIndex(0..3),
|
||||
Action::Explore,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_existing_optimizations_when_sync_should_execute_one_when_available() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy = Policy::new();
|
||||
let stream = TestStream::new(3);
|
||||
|
||||
let id_1 = store.add(ExecutionPlan {
|
||||
operations: stream.operations[0..2].to_vec(),
|
||||
triggers: Vec::new(),
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
|
||||
});
|
||||
let _id_2 = store.add(ExecutionPlan {
|
||||
operations: stream.operations[0..3].to_vec(),
|
||||
triggers: Vec::new(),
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
|
||||
});
|
||||
|
||||
stream.assert_updates(
|
||||
&store,
|
||||
&mut policy,
|
||||
AssertUpdatesOptions::OperationsIndex(0..2),
|
||||
Action::Defer,
|
||||
);
|
||||
|
||||
let action = policy.action(&store, &stream.operations[0..2], ExecutionMode::Sync);
|
||||
assert_eq!(action, Action::Execute(id_1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_existing_plan_when_found_trigger_should_execute_plan() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy = Policy::new();
|
||||
|
||||
let stream = TestStream::new(3);
|
||||
let id = store.add(ExecutionPlan {
|
||||
operations: stream.operations[0..2].to_vec(),
|
||||
triggers: stream.operations[2..3]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
|
||||
.collect(),
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
|
||||
});
|
||||
|
||||
stream.assert_updates(
|
||||
&store,
|
||||
&mut policy,
|
||||
AssertUpdatesOptions::OperationsIndex(0..2),
|
||||
Action::Defer,
|
||||
);
|
||||
stream.assert_updates(
|
||||
&store,
|
||||
&mut policy,
|
||||
AssertUpdatesOptions::OperationsIndex(2..3),
|
||||
Action::Execute(id),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_multiple_triggers() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy_1 = Policy::new();
|
||||
let mut policy_2 = Policy::new();
|
||||
|
||||
let mut stream_1 = TestStream::new(2);
|
||||
let mut stream_2 = TestStream::new(2);
|
||||
|
||||
// Create different end operation for each stream.
|
||||
let trigger_id_1 = 5;
|
||||
let trigger_id_2 = 6;
|
||||
stream_1.new_ops(trigger_id_1);
|
||||
stream_2.new_ops(trigger_id_2);
|
||||
|
||||
let id = store.add(ExecutionPlan {
|
||||
operations: stream_1.operations[0..2].to_vec(),
|
||||
triggers: vec![
|
||||
ExecutionTrigger::OnOperations(vec![stream_1.operations[2].clone()]),
|
||||
ExecutionTrigger::OnOperations(vec![stream_2.operations[2].clone()]),
|
||||
],
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
|
||||
});
|
||||
|
||||
stream_1.assert_updates(
|
||||
&store,
|
||||
&mut policy_1,
|
||||
AssertUpdatesOptions::OperationsIndex(0..2),
|
||||
Action::Defer,
|
||||
);
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy_2,
|
||||
AssertUpdatesOptions::OperationsIndex(0..2),
|
||||
Action::Defer,
|
||||
);
|
||||
|
||||
stream_1.assert_updates(
|
||||
&store,
|
||||
&mut policy_1,
|
||||
AssertUpdatesOptions::OperationsIndex(2..3), // First trigger.
|
||||
Action::Execute(id),
|
||||
);
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy_2,
|
||||
AssertUpdatesOptions::OperationsIndex(2..3), // Second trigger.
|
||||
Action::Execute(id),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_right_optimization() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy_1 = Policy::new();
|
||||
let mut policy_2 = Policy::new();
|
||||
|
||||
let mut stream_1 = TestStream::new(2);
|
||||
let mut stream_2 = TestStream::new(2);
|
||||
|
||||
// Create different streams after op 2.
|
||||
stream_1.new_ops(4);
|
||||
stream_1.new_ops(5);
|
||||
|
||||
stream_2.new_ops(5);
|
||||
stream_2.new_ops(6);
|
||||
|
||||
let optimization_stream_1 = store.add(ExecutionPlan {
|
||||
operations: stream_1.operations[0..3].to_vec(),
|
||||
triggers: stream_1.operations[3..4]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
|
||||
.collect(),
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
|
||||
});
|
||||
let optimization_stream_2 = store.add(ExecutionPlan {
|
||||
operations: stream_2.operations[0..3].to_vec(),
|
||||
triggers: stream_2.operations[3..4]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
|
||||
.collect(),
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
|
||||
});
|
||||
assert_ne!(optimization_stream_1, optimization_stream_2);
|
||||
|
||||
stream_1.assert_updates(
|
||||
&store,
|
||||
&mut policy_1,
|
||||
AssertUpdatesOptions::OperationsIndex(0..3),
|
||||
Action::Defer,
|
||||
);
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy_2,
|
||||
AssertUpdatesOptions::OperationsIndex(0..3),
|
||||
Action::Defer,
|
||||
);
|
||||
|
||||
stream_1.assert_updates(
|
||||
&store,
|
||||
&mut policy_1,
|
||||
AssertUpdatesOptions::OperationsIndex(3..4),
|
||||
Action::Execute(optimization_stream_1),
|
||||
);
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy_2,
|
||||
AssertUpdatesOptions::OperationsIndex(3..4),
|
||||
Action::Execute(optimization_stream_2),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_invalidate_wrong_optimizations() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let stream_1 = TestStream::new(4);
|
||||
let mut stream_2 = TestStream::new(2);
|
||||
stream_2.new_ops(6);
|
||||
stream_2.new_ops(7);
|
||||
|
||||
store.add(ExecutionPlan {
|
||||
operations: stream_1.operations[0..3].to_vec(),
|
||||
triggers: stream_1.operations[3..4]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))
|
||||
.collect(),
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),
|
||||
});
|
||||
|
||||
let mut policy = Policy::new();
|
||||
// Same path as stream 1
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy,
|
||||
AssertUpdatesOptions::OperationsIndex(0..2),
|
||||
Action::Defer,
|
||||
);
|
||||
|
||||
// But is different.
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy,
|
||||
AssertUpdatesOptions::OperationsIndex(2..4),
|
||||
Action::Explore,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct TestStream {
|
||||
tensors: Vec<TensorIr>,
|
||||
operations: Vec<OperationIr>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum AssertUpdatesOptions {
|
||||
OperationsIndex(Range<usize>),
|
||||
}
|
||||
|
||||
impl TestStream {
|
||||
/// Create a new test stream with `num_ops` operations registered.
|
||||
pub fn new(num_ops: usize) -> Self {
|
||||
let mut stream = Self::default();
|
||||
for id in 0..num_ops {
|
||||
stream.new_ops(id as u64 + 1);
|
||||
}
|
||||
|
||||
stream
|
||||
}
|
||||
|
||||
/// The first follow should only be cache miss.
|
||||
pub fn assert_updates(
|
||||
&self,
|
||||
optimizations: &ExecutionPlanStore<()>,
|
||||
policy: &mut Policy<()>,
|
||||
options: AssertUpdatesOptions,
|
||||
action: Action,
|
||||
) {
|
||||
match options {
|
||||
AssertUpdatesOptions::OperationsIndex(range) => {
|
||||
for i in range {
|
||||
let stream = &self.operations[0..i];
|
||||
let next_ops = &self.operations[i];
|
||||
policy.update(optimizations, next_ops);
|
||||
let result = policy.action(optimizations, stream, ExecutionMode::Lazy);
|
||||
|
||||
assert_eq!(result, action);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a simple operation to the stream.
|
||||
pub fn new_ops(&mut self, out_id: u64) {
|
||||
if self.tensors.is_empty() {
|
||||
// Root node.
|
||||
self.new_empty_node(0);
|
||||
}
|
||||
|
||||
// Out node.
|
||||
self.new_empty_node(out_id);
|
||||
|
||||
self.operations.push(OperationIr::Float(
|
||||
DType::F32,
|
||||
FloatOperationIr::Log(self.unary_description()),
|
||||
));
|
||||
}
|
||||
|
||||
fn new_empty_node(&mut self, id: u64) {
|
||||
self.tensors.push(TensorIr {
|
||||
id: TensorId::new(id),
|
||||
shape: Shape::new([32, 32, 1]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
});
|
||||
}
|
||||
|
||||
fn unary_description(&self) -> UnaryOpIr {
|
||||
let size = self.tensors.len();
|
||||
|
||||
UnaryOpIr {
|
||||
input: self.tensors[size - 2].clone(),
|
||||
out: self.tensors[size - 1].clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
use burn_ir::OperationIr;
|
||||
|
||||
use super::{ExecutionMode, ExplorationAction, Explorer};
|
||||
use crate::search::BlockOptimization;
|
||||
use crate::stream::execution::{Action, Policy};
|
||||
use crate::stream::store::{ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
|
||||
use crate::{NumOperations, OperationFuser};
|
||||
|
||||
/// Process a [stream segment](StreamSegment) following a [policy](Policy).
|
||||
pub(crate) struct Processor<O> {
|
||||
policy: Policy<O>,
|
||||
explorer: Explorer<O>,
|
||||
}
|
||||
|
||||
/// A part of a stream that can be executed partially using [execution plan](ExecutionPlan).
|
||||
pub(crate) trait StreamSegment<O> {
|
||||
/// The operations in the segment.
|
||||
fn operations(&self) -> &[OperationIr];
|
||||
/// Execute part of the segment using the given plan id.
|
||||
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<O>);
|
||||
}
|
||||
|
||||
impl<O: NumOperations> Processor<O> {
|
||||
/// Create a new stream processor.
|
||||
pub fn new(optimizations: Vec<Box<dyn OperationFuser<O>>>) -> Self {
|
||||
Self {
|
||||
policy: Policy::new(),
|
||||
explorer: Explorer::new(optimizations),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the [stream segment](StreamSegment) with the provided [mode](ExecutionMode).
|
||||
pub fn process<Segment>(
|
||||
&mut self,
|
||||
mut segment: Segment,
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
mode: ExecutionMode,
|
||||
) where
|
||||
Segment: StreamSegment<O>,
|
||||
{
|
||||
// We assume that we always register a new operation in lazy mode.
|
||||
if let ExecutionMode::Lazy = mode {
|
||||
self.on_new_operation(&segment, store);
|
||||
}
|
||||
|
||||
loop {
|
||||
if segment.operations().is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let action = self.policy.action(store, segment.operations(), mode);
|
||||
|
||||
match action {
|
||||
Action::Explore => {
|
||||
self.explore(&mut segment, store, mode);
|
||||
|
||||
if self.explorer.is_up_to_date() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Action::Defer => {
|
||||
match mode {
|
||||
ExecutionMode::Lazy => break,
|
||||
ExecutionMode::Sync => panic!("Can't defer while sync"),
|
||||
};
|
||||
}
|
||||
Action::Execute(id) => {
|
||||
if let ExecutionMode::Sync = mode {
|
||||
store.add_trigger(id, ExecutionTrigger::OnSync);
|
||||
}
|
||||
|
||||
segment.execute(id, store);
|
||||
self.reset(store, segment.operations());
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn on_new_operation<Segment>(&mut self, segment: &Segment, store: &mut ExecutionPlanStore<O>)
|
||||
where
|
||||
Segment: StreamSegment<O>,
|
||||
{
|
||||
self.policy.update(
|
||||
store,
|
||||
segment
|
||||
.operations()
|
||||
.last()
|
||||
.expect("At least one operation in the operation list."),
|
||||
);
|
||||
self.explorer.on_new_operation();
|
||||
}
|
||||
|
||||
fn explore<Item: StreamSegment<O>>(
|
||||
&mut self,
|
||||
item: &mut Item,
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
mode: ExecutionMode,
|
||||
) {
|
||||
match self.explorer.explore(item.operations(), mode) {
|
||||
ExplorationAction::Completed(optim) => {
|
||||
let id = Self::on_exploration_completed(
|
||||
&self.policy,
|
||||
item.operations(),
|
||||
store,
|
||||
optim,
|
||||
mode,
|
||||
);
|
||||
item.execute(id, store);
|
||||
self.reset(store, item.operations());
|
||||
}
|
||||
ExplorationAction::Continue => {
|
||||
if let ExecutionMode::Sync = mode {
|
||||
panic!("Can't continue exploring when sync.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, store: &mut ExecutionPlanStore<O>, operations: &[OperationIr]) {
|
||||
self.explorer.reset(operations);
|
||||
self.policy.reset();
|
||||
|
||||
// Reset the policy state with the remaining operations
|
||||
for operation in operations.iter() {
|
||||
self.policy.update(store, operation);
|
||||
}
|
||||
}
|
||||
|
||||
/// We found an optimization (i.e. a new execution plan).
|
||||
/// Cache it in the store.
|
||||
fn on_exploration_completed(
|
||||
policy: &Policy<O>,
|
||||
operations: &[OperationIr],
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
optimization: BlockOptimization<O>,
|
||||
mode: ExecutionMode,
|
||||
) -> ExecutionPlanId {
|
||||
let num_optimized = optimization.ordering.len();
|
||||
let relative = &operations[0..num_optimized];
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Lazy => {
|
||||
let next_ops = &operations[num_optimized..operations.len()];
|
||||
|
||||
let trigger = if next_ops.is_empty() {
|
||||
// Happens if the next ops is included in the fused operation, and there is no
|
||||
// way the builder can still continue fusing.
|
||||
ExecutionTrigger::Always
|
||||
} else {
|
||||
ExecutionTrigger::OnOperations(next_ops.to_vec())
|
||||
};
|
||||
|
||||
match policy.action(store, relative, ExecutionMode::Sync) {
|
||||
Action::Execute(id) => {
|
||||
store.add_trigger(id, trigger);
|
||||
id
|
||||
}
|
||||
_ => {
|
||||
let plan = ExecutionPlan {
|
||||
operations: relative.to_vec(),
|
||||
triggers: vec![trigger],
|
||||
optimization,
|
||||
};
|
||||
store.add(plan)
|
||||
}
|
||||
}
|
||||
}
|
||||
ExecutionMode::Sync => match policy.action(store, relative, ExecutionMode::Sync) {
|
||||
Action::Execute(id) => {
|
||||
store.add_trigger(id, ExecutionTrigger::OnSync);
|
||||
id
|
||||
}
|
||||
_ => {
|
||||
let plan = ExecutionPlan {
|
||||
operations: relative.to_vec(),
|
||||
triggers: vec![ExecutionTrigger::OnSync],
|
||||
optimization,
|
||||
};
|
||||
store.add(plan)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,671 @@
|
||||
//! A testing module that ensures the correctness of the explorer, policy, and processor.
|
||||
//!
|
||||
//! The primary focus is on validating the seamless interaction between these three components to
|
||||
//! execute and optimize a stream of operations accurately.
|
||||
//!
|
||||
//! To test these components effectively, we create mock types for the stream, optimization,
|
||||
//! optimization builder, and stream segment. These mock types aid in comprehensively
|
||||
//! understanding the process of optimizing streams.
|
||||
use std::sync::Arc;
|
||||
|
||||
use burn_backend::{DType, Shape};
|
||||
use burn_ir::{
|
||||
BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarIr, ScalarOpIr, TensorId,
|
||||
TensorIr, TensorStatus, UnaryOpIr,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
FuserProperties, FuserStatus, NumOperations, OperationFuser,
|
||||
search::BlockOptimization,
|
||||
stream::store::{
|
||||
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
/// A fake stream of operations for testing purpose.
|
||||
pub struct TestStream {
|
||||
processor: Processor<TestOptimization>,
|
||||
store: ExecutionPlanStore<TestOptimization>,
|
||||
executed: Vec<ExecutionPlanId>,
|
||||
operations: Vec<OperationIr>,
|
||||
}
|
||||
|
||||
/// A fake [optimization builder](OptimizationBuilder) for testing purpose.
|
||||
///
|
||||
/// The optimizer tries to fuse only the `expected_operations` if they appear
|
||||
/// in the operations queue
|
||||
#[derive(Clone)]
|
||||
pub struct TestOptimizationBuilder {
|
||||
builder_id: usize,
|
||||
expected_operations: Vec<OperationIr>,
|
||||
actual: Vec<OperationIr>,
|
||||
}
|
||||
|
||||
/// A fake optimization for testing purpose.
|
||||
#[derive(new, Debug, PartialEq)]
|
||||
pub struct TestOptimization {
|
||||
builder_id: usize,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl NumOperations for TestOptimization {
|
||||
fn len(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
/// A fake [stream segment](StreamSegment) for testing purpose.
|
||||
#[derive(new)]
|
||||
pub struct TestSegment<'i> {
|
||||
operations: &'i mut Vec<OperationIr>,
|
||||
executed: &'i mut Vec<ExecutionPlanId>,
|
||||
}
|
||||
|
||||
impl<O> ExecutionStrategy<O> {
|
||||
/// Create an ordered execution strategy with the given size.
|
||||
pub fn operations(size: usize) -> Self {
|
||||
Self::Operations {
|
||||
ordering: Arc::new((0..size).collect()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionStrategy<TestOptimization> {
|
||||
/// Only use it for testing, to easily create ordered strategies.
|
||||
pub fn optimization(opt: TestOptimization) -> Self {
|
||||
let ordering = Arc::new((0..opt.size).collect());
|
||||
Self::Optimization { opt, ordering }
|
||||
}
|
||||
}
|
||||
|
||||
/// This is a substantial test case that examines a lengthy scenario with a diverse set of conditions.
|
||||
///
|
||||
/// While it's usually preferable to split tests into multiple independent scenarios, in this case, it is
|
||||
/// crucial to verify that the stream's state is correctly updated when various cases occur consecutively.
|
||||
#[test]
|
||||
fn should_support_complex_stream() {
|
||||
// We have 2 different optimization builders in this test case.
|
||||
let builder_id_1 = 0;
|
||||
let builder_id_2 = 1;
|
||||
|
||||
// We will have a total of 3 execution plans to execute.
|
||||
let plan_id_1 = 0;
|
||||
let plan_id_2 = 1;
|
||||
let plan_id_3 = 2;
|
||||
|
||||
let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);
|
||||
let builder_2 = TestOptimizationBuilder::new(builder_id_2, vec![operation_2(), operation_2()]);
|
||||
let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);
|
||||
|
||||
// builder_1 is still waiting to see next op is operation_2
|
||||
// builder_2 is closed because it's not the right operation
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(0);
|
||||
|
||||
// No optimization found for the first two operations.
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(1);
|
||||
stream.assert_last_executed(plan_id_1);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_1()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(1);
|
||||
|
||||
// Now we should trigger the first optimization builder.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(2);
|
||||
stream.assert_last_executed(plan_id_2);
|
||||
stream.assert_plan(
|
||||
plan_id_2,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization::new(
|
||||
ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
vec![0, 1],
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(2);
|
||||
|
||||
// Now we should trigger the second optimization builder.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(3);
|
||||
stream.assert_last_executed(plan_id_3);
|
||||
stream.assert_plan(
|
||||
plan_id_3,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_2(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_2, 2)),
|
||||
ordering: vec![0, 1],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(3);
|
||||
|
||||
// Now we should trigger the first optimization builder (second plan).
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(4);
|
||||
stream.assert_last_executed(plan_id_2);
|
||||
stream.assert_plan(
|
||||
plan_id_2,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
ordering: vec![0, 1],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
// Now we should trigger the first optimization builder (third plan).
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(5);
|
||||
stream.assert_last_executed(plan_id_3);
|
||||
}
|
||||
|
||||
/// In this scenario we will never use an optimization, but we check that we reuse the execution plan stored.
|
||||
#[test]
|
||||
fn should_reuse_basic_operations() {
|
||||
let builder_id_1 = 0;
|
||||
let plan_id_1 = 0;
|
||||
let plan_id_2 = 1;
|
||||
|
||||
let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);
|
||||
let mut stream = TestStream::new(vec![Box::new(builder_1)]);
|
||||
|
||||
stream.add(operation_3());
|
||||
stream.assert_last_executed(plan_id_1);
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_3()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::operations(1),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
stream.add(operation_3());
|
||||
stream.assert_last_executed(plan_id_1);
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_3()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::operations(1),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Lazy try to build optimization 1.
|
||||
stream.add(operation_1());
|
||||
// But not possible.
|
||||
stream.add(operation_3());
|
||||
|
||||
// Creates a new plan with both operations.
|
||||
stream.assert_plan(
|
||||
plan_id_2,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_3()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::operations(2),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_last_executed(plan_id_2);
|
||||
}
|
||||
|
||||
// In this scenario we validate that we support multiple optimization builders with overlapping
|
||||
// operations.
|
||||
//
|
||||
// This is a very long scenario that validates a lot of things.
|
||||
#[test]
|
||||
fn should_support_overlapping_optimizations() {
|
||||
// We have 2 different optimization builders in this test case.
|
||||
let builder_id_1 = 0;
|
||||
let builder_id_2 = 0;
|
||||
|
||||
// We will have a total of 5 execution plans to execute.
|
||||
let plan_id_1 = 0;
|
||||
let plan_id_2 = 1;
|
||||
let plan_id_3 = 2;
|
||||
let plan_id_4 = 3;
|
||||
let plan_id_5 = 4;
|
||||
|
||||
let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);
|
||||
let builder_2 = TestOptimizationBuilder::new(
|
||||
builder_id_2,
|
||||
vec![operation_1(), operation_2(), operation_1(), operation_1()],
|
||||
);
|
||||
let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(0);
|
||||
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(2);
|
||||
stream.assert_number_of_executions(0);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(3);
|
||||
stream.assert_number_of_executions(0);
|
||||
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(2);
|
||||
stream.assert_number_of_executions(1);
|
||||
stream.assert_last_executed(plan_id_1);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::OnOperations(vec![
|
||||
operation_1(),
|
||||
operation_2(),
|
||||
])],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
ordering: vec![0, 1],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(3);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![
|
||||
ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]),
|
||||
ExecutionTrigger::OnOperations(vec![operation_2()]),
|
||||
],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
ordering: vec![0, 1],
|
||||
},
|
||||
},
|
||||
);
|
||||
stream.assert_plan(
|
||||
plan_id_2,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_2()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::operations(1),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(3);
|
||||
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(2);
|
||||
stream.assert_number_of_executions(3);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(3);
|
||||
stream.assert_number_of_executions(3);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
stream.assert_plan(
|
||||
plan_id_3,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2(), operation_1(), operation_1()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 4)),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(2);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(3);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
stream.sync();
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(6);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![
|
||||
ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]),
|
||||
ExecutionTrigger::OnOperations(vec![operation_2()]),
|
||||
ExecutionTrigger::OnSync,
|
||||
],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
ordering: vec![0, 1],
|
||||
},
|
||||
},
|
||||
);
|
||||
stream.assert_plan(
|
||||
plan_id_4,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1()],
|
||||
triggers: vec![ExecutionTrigger::OnSync],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::operations(1),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
stream.add(operation_3());
|
||||
stream.assert_last_executed(plan_id_5);
|
||||
stream.assert_plan(
|
||||
plan_id_5,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_3()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
optimization: BlockOptimization {
|
||||
strategy: ExecutionStrategy::operations(1),
|
||||
ordering: vec![0],
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
stream.add(operation_3());
|
||||
stream.assert_last_executed(plan_id_5);
|
||||
}
|
||||
|
||||
impl TestStream {
|
||||
/// Create a new stream with the given optimization builders.
|
||||
fn new(optimizations: Vec<Box<dyn OperationFuser<TestOptimization>>>) -> Self {
|
||||
Self {
|
||||
processor: Processor::<TestOptimization>::new(optimizations),
|
||||
store: ExecutionPlanStore::<TestOptimization>::new(),
|
||||
executed: Vec::new(),
|
||||
operations: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an operation to the stream.
|
||||
fn add(&mut self, operation: OperationIr) {
|
||||
self.operations.push(operation);
|
||||
self.processor.process(
|
||||
TestSegment::new(&mut self.operations, &mut self.executed),
|
||||
&mut self.store,
|
||||
ExecutionMode::Lazy,
|
||||
);
|
||||
}
|
||||
|
||||
/// Sync the stream.
|
||||
fn sync(&mut self) {
|
||||
self.processor.process(
|
||||
TestSegment::new(&mut self.operations, &mut self.executed),
|
||||
&mut self.store,
|
||||
ExecutionMode::Sync,
|
||||
);
|
||||
}
|
||||
|
||||
/// Assert that the plan has been executed as provided.
|
||||
fn assert_plan(&self, id: ExecutionPlanId, expected: ExecutionPlan<TestOptimization>) {
|
||||
let actual = self.store.get_unchecked(id);
|
||||
assert_eq!(actual.operations, expected.operations, "Same operations");
|
||||
assert_eq!(actual.triggers, expected.triggers, "Same triggers");
|
||||
}
|
||||
|
||||
/// Assert that the given plan id has been the last executed.
|
||||
fn assert_last_executed(&self, id: ExecutionPlanId) {
|
||||
match self.executed.last() {
|
||||
Some(last_id) => assert_eq!(*last_id, id),
|
||||
None => panic!("No plan has been executed"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Assert the number of executions since the start of the stream.
|
||||
fn assert_number_of_executions(&self, number: usize) {
|
||||
assert_eq!(self.executed.len(), number);
|
||||
}
|
||||
|
||||
/// Assert the number of operations queued.
|
||||
fn assert_number_of_operations(&self, number: usize) {
|
||||
assert_eq!(self.operations.len(), number);
|
||||
}
|
||||
}
|
||||
|
||||
impl TestOptimizationBuilder {
|
||||
/// Create a new optimization builder that follows a pattern with a trigger.
|
||||
pub fn new(builder_id: usize, operations: Vec<OperationIr>) -> Self {
|
||||
Self {
|
||||
builder_id,
|
||||
expected_operations: operations,
|
||||
actual: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationFuser<TestOptimization> for TestOptimizationBuilder {
|
||||
/// Register a new operation.
|
||||
fn fuse(&mut self, operation: &OperationIr) {
|
||||
self.actual.push(operation.clone());
|
||||
}
|
||||
|
||||
/// Build the optimization.
|
||||
fn finish(&mut self) -> TestOptimization {
|
||||
TestOptimization::new(self.builder_id, self.len())
|
||||
}
|
||||
|
||||
/// Reset the state.
|
||||
fn reset(&mut self) {
|
||||
self.actual.clear();
|
||||
}
|
||||
|
||||
/// Return the optimization status.
|
||||
fn status(&self) -> FuserStatus {
|
||||
if self.actual.len() < self.expected_operations.len() {
|
||||
let operations = &self.expected_operations[0..self.actual.len()];
|
||||
|
||||
return match self.actual == operations {
|
||||
// Still optimizing.
|
||||
true => FuserStatus::Open,
|
||||
// Never gonna be possible on that stream.
|
||||
false => FuserStatus::Closed,
|
||||
};
|
||||
}
|
||||
|
||||
FuserStatus::Closed
|
||||
}
|
||||
|
||||
/// Return the properties of this optimization.
|
||||
fn properties(&self) -> FuserProperties {
|
||||
if self.actual.len() < self.expected_operations.len() {
|
||||
// Optimization not possible.
|
||||
return FuserProperties {
|
||||
score: 0,
|
||||
ready: false,
|
||||
};
|
||||
}
|
||||
|
||||
let stream_is_ok =
|
||||
self.actual[0..self.expected_operations.len()] == self.expected_operations;
|
||||
|
||||
if !stream_is_ok {
|
||||
// Optimization not possible.
|
||||
return FuserProperties {
|
||||
score: 0,
|
||||
ready: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Optimization possible.
|
||||
FuserProperties {
|
||||
score: 1,
|
||||
ready: true,
|
||||
}
|
||||
}
|
||||
|
||||
// The number of operations that should be handle by the optimization.
|
||||
fn len(&self) -> usize {
|
||||
self.expected_operations.len()
|
||||
}
|
||||
fn clone_dyn(&self) -> Box<dyn OperationFuser<TestOptimization>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamSegment<TestOptimization> for TestSegment<'_> {
|
||||
// The operations in the process.
|
||||
fn operations(&self) -> &[OperationIr] {
|
||||
self.operations
|
||||
}
|
||||
|
||||
// Execute the process.
|
||||
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<TestOptimization>) {
|
||||
let execution_plan = store.get_unchecked(id);
|
||||
|
||||
self.execute_strategy(&execution_plan.optimization.strategy);
|
||||
|
||||
self.executed.push(id);
|
||||
}
|
||||
}
|
||||
|
||||
impl TestSegment<'_> {
|
||||
fn execute_strategy(&mut self, strategy: &ExecutionStrategy<TestOptimization>) {
|
||||
match strategy {
|
||||
ExecutionStrategy::Optimization { opt, .. } => {
|
||||
self.operations.drain(0..opt.size);
|
||||
}
|
||||
ExecutionStrategy::Operations { ordering } => {
|
||||
self.operations.drain(0..ordering.len());
|
||||
}
|
||||
ExecutionStrategy::Composed(strategies) => {
|
||||
for strategy in strategies {
|
||||
self.execute_strategy(strategy);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
pub fn operation_1() -> OperationIr {
|
||||
OperationIr::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationIr::Add(BinaryOpIr {
|
||||
lhs: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
rhs: TensorIr {
|
||||
id: TensorId::new(1),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
out: TensorIr {
|
||||
id: TensorId::new(2),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
pub fn operation_2() -> OperationIr {
|
||||
OperationIr::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationIr::AddScalar(ScalarOpIr {
|
||||
lhs: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
rhs: ScalarIr::Float(5.0),
|
||||
out: TensorIr {
|
||||
id: TensorId::new(2),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
pub fn operation_3() -> OperationIr {
|
||||
OperationIr::Float(
|
||||
DType::F32,
|
||||
FloatOperationIr::Log(UnaryOpIr {
|
||||
input: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
out: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
use burn_ir::OperationIr;
|
||||
|
||||
use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
|
||||
|
||||
/// Compare each operation in the list of operations provided by the [store](OperationsStore)
|
||||
/// to verify if the newly added operations match the original list.
|
||||
///
|
||||
/// It is used by the [policy](crate::stream::execution::Policy) to check each candidate as well
|
||||
/// as to verify if a list of operations is optimal to execute based on their triggers.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct OperationsValidator<ID> {
|
||||
/// The ID used to retrieve the operation list.
|
||||
pub(crate) id: ID,
|
||||
/// The current [state](MatchingState).
|
||||
pub(crate) state: ValidatorState,
|
||||
}
|
||||
|
||||
/// The state of the validator.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ValidatorState {
|
||||
/// A matching operation list has been found.
|
||||
Found { size: usize },
|
||||
/// No matching operation list has been found.
|
||||
Invalidated,
|
||||
/// Potentially going to find a matching operation list when more operations are added.
|
||||
Validating,
|
||||
}
|
||||
|
||||
/// Provides a list of operations based on an Id.
|
||||
pub(crate) trait OperationsStore {
|
||||
/// The type used for the identifier.
|
||||
type Id: Copy;
|
||||
|
||||
/// retrieve the list of operations corresponding on the provided id.
|
||||
fn get(&self, id: Self::Id) -> &[OperationIr];
|
||||
}
|
||||
|
||||
impl<ID> OperationsValidator<ID> {
|
||||
/// Create a new validator.
|
||||
pub(crate) fn new(id: ID) -> Self {
|
||||
Self {
|
||||
id,
|
||||
state: ValidatorState::Validating,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the state of the validator based on the newly added operation.
|
||||
pub(crate) fn update<S>(&mut self, added: &OperationIr, added_position: usize, store: &S)
|
||||
where
|
||||
S: OperationsStore<Id = ID>,
|
||||
ID: PartialEq + Copy,
|
||||
{
|
||||
match &self.state {
|
||||
ValidatorState::Found { size: _ } => return,
|
||||
ValidatorState::Invalidated => return,
|
||||
ValidatorState::Validating => {}
|
||||
};
|
||||
|
||||
let item = store.get(self.id);
|
||||
let operation_candidate = match item.get(added_position) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
self.state = ValidatorState::Invalidated;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if operation_candidate != added {
|
||||
self.state = ValidatorState::Invalidated;
|
||||
return;
|
||||
}
|
||||
|
||||
// Finished
|
||||
if item.len() == added_position + 1 {
|
||||
self.state = ValidatorState::Found { size: item.len() };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [Operations store](OperationsStore) used to retrieve the list of operations for a trigger.
|
||||
#[derive(new)]
|
||||
pub(crate) struct TriggerOperationsStore<'a, O> {
|
||||
id: ExecutionPlanId,
|
||||
store: &'a ExecutionPlanStore<O>,
|
||||
}
|
||||
|
||||
/// Validates when operations match a trigger.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TriggerValidator {
|
||||
OnOperations {
|
||||
matching: OperationsValidator<TriggerId>,
|
||||
progress: TriggerProgress,
|
||||
},
|
||||
Always,
|
||||
OnSync,
|
||||
}
|
||||
|
||||
/// The progress made into the trigger validation process.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TriggerProgress {
|
||||
/// When the validation hasn't started.
|
||||
NotInit,
|
||||
/// The number of operations that have been checked.
|
||||
NumChecked(usize),
|
||||
}
|
||||
|
||||
/// An execution plan can have many triggers, so we use the position in the list to identify a
|
||||
/// trigger.
|
||||
pub(crate) type TriggerId = usize;
|
||||
|
||||
impl<O: core::fmt::Debug> OperationsStore for TriggerOperationsStore<'_, O> {
|
||||
type Id = TriggerId;
|
||||
|
||||
fn get(&self, id: Self::Id) -> &[OperationIr] {
|
||||
match &self.store.get_unchecked(self.id).triggers[id] {
|
||||
ExecutionTrigger::OnOperations(operations) => operations,
|
||||
ExecutionTrigger::OnSync => &[],
|
||||
ExecutionTrigger::Always => &[],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [Operations store](OperationsStore) used to retrieve the list of operations for an
|
||||
/// [execution plan](crate::stream::store::ExecutionPlan).
|
||||
#[derive(new)]
|
||||
pub(crate) struct ExecutionPlanOperationsStore<'a, O> {
|
||||
store: &'a ExecutionPlanStore<O>,
|
||||
}
|
||||
|
||||
impl<O: core::fmt::Debug> OperationsStore for ExecutionPlanOperationsStore<'_, O> {
|
||||
type Id = ExecutionPlanId;
|
||||
|
||||
fn get(&self, id: Self::Id) -> &[OperationIr] {
|
||||
&self.store.get_unchecked(id).operations
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
use hashbrown::HashMap;
|
||||
use std::{
|
||||
fmt::Display,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
mpsc::SyncSender,
|
||||
},
|
||||
thread::JoinHandle,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use burn_ir::{HandleContainer, TensorId, TensorStatus};
|
||||
use burn_std::id::StreamId;
|
||||
|
||||
use crate::FusionRuntime;
|
||||
|
||||
use super::Stream;
|
||||
|
||||
/// Memory checks struct to validate there is no memory leak with the fusion runtime.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MemoryChecks {
|
||||
sender: SyncSender<Message>,
|
||||
num_queued: Arc<AtomicU64>,
|
||||
// Keeps track of its thread.
|
||||
_handle: Arc<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
enum Message {
|
||||
Register(StreamAnalyses),
|
||||
Check(SyncSender<MemoryReport>),
|
||||
}
|
||||
|
||||
enum MemoryReport {
|
||||
Success,
|
||||
NotReady,
|
||||
NotStarted,
|
||||
Fail(String),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct StreamAnalyses {
|
||||
streams: HashMap<StreamId, Analysis>,
|
||||
num_handles: usize,
|
||||
}
|
||||
|
||||
impl Display for StreamAnalyses {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("\n==== Fusion Memory Report ====\n")?;
|
||||
f.write_fmt(format_args!(" - Handles: {}\n", self.num_handles))?;
|
||||
f.write_fmt(format_args!(" - Streams: {}\n", self.streams.len()))?;
|
||||
|
||||
for (id, analysis) in self.streams.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
" - {} => operations: {} cursor: {}\n",
|
||||
id, analysis.num_operations, analysis.cursor
|
||||
))?;
|
||||
for (tid, (origin, status)) in analysis.variables.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
" - {tid} => origin: {origin} status: {status:?}\n",
|
||||
))?;
|
||||
}
|
||||
}
|
||||
|
||||
f.write_str("==============================\n")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct Analysis {
|
||||
variables: HashMap<TensorId, (StreamId, TensorStatus)>,
|
||||
num_operations: usize,
|
||||
cursor: u64,
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Export memory checks tests.
|
||||
macro_rules! memory_checks {
|
||||
() => {
|
||||
#[cfg(test)]
|
||||
mod memory_checks {
|
||||
#[test]
|
||||
fn test_memory_leaks() {
|
||||
burn_fusion::stream::memory_checks::check_memory_leaks();
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static INSTANCE: spin::Mutex<Option<MemoryChecks>> = spin::Mutex::new(None);
|
||||
|
||||
/// Performs memory checks and panics if a leak is discovered.
|
||||
pub fn check_memory_leaks() {
|
||||
let mut num_try_uninit = 0;
|
||||
let max_try = 25;
|
||||
|
||||
loop {
|
||||
let report = fetch_memory_report();
|
||||
match report {
|
||||
MemoryReport::Success => return,
|
||||
MemoryReport::NotReady => {
|
||||
num_try_uninit = 0;
|
||||
std::thread::sleep(Duration::from_millis(100))
|
||||
}
|
||||
MemoryReport::NotStarted => {
|
||||
if num_try_uninit >= max_try {
|
||||
// Nothing is running on the fusion runtime.
|
||||
return;
|
||||
}
|
||||
num_try_uninit += 1;
|
||||
std::thread::sleep(Duration::from_millis(100))
|
||||
}
|
||||
MemoryReport::Fail(msg) => panic!("{msg}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_memory_report() -> MemoryReport {
|
||||
let report = INSTANCE.lock();
|
||||
|
||||
let report = match report.as_ref() {
|
||||
Some(client) => client,
|
||||
None => return MemoryReport::NotStarted,
|
||||
};
|
||||
|
||||
let (sender, rec) = std::sync::mpsc::sync_channel(1);
|
||||
match report.sender.send(Message::Check(sender)) {
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
panic!("Channel closed can't send the check call: {err:?}")
|
||||
}
|
||||
};
|
||||
|
||||
match rec.recv() {
|
||||
Ok(report) => report,
|
||||
Err(err) => panic!("Received an error from fetching check results: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryChecks {
|
||||
fn default() -> Self {
|
||||
let mut instance = INSTANCE.lock();
|
||||
let result = match instance.as_mut() {
|
||||
Some(client) => client.clone(),
|
||||
None => {
|
||||
let this = Self::spawn_new();
|
||||
*instance = Some(this.clone());
|
||||
this
|
||||
}
|
||||
};
|
||||
core::mem::drop(instance);
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryChecks {
|
||||
pub(crate) fn check<R: FusionRuntime>(
|
||||
&mut self,
|
||||
streams: &HashMap<StreamId, Stream<R>>,
|
||||
handles: &HandleContainer<R::FusionHandle>,
|
||||
) {
|
||||
let mut analyses = StreamAnalyses {
|
||||
num_handles: handles.num_handles(),
|
||||
streams: Default::default(),
|
||||
};
|
||||
|
||||
for (id, s) in streams.iter() {
|
||||
let analysis = Analysis {
|
||||
variables: s.queue.variables.clone(),
|
||||
num_operations: s.queue.global.len(),
|
||||
cursor: s.cursor,
|
||||
};
|
||||
analyses.streams.insert(*id, analysis);
|
||||
}
|
||||
|
||||
self.num_queued.fetch_add(1, Ordering::Relaxed);
|
||||
match self.sender.send(Message::Register(analyses)) {
|
||||
Ok(..) => {}
|
||||
Err(err) => {
|
||||
panic!("Can't register memory checks analysis: {err:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_new() -> Self {
|
||||
let (sender, rec) = std::sync::mpsc::sync_channel(100);
|
||||
let num_queued = Arc::new(AtomicU64::new(0));
|
||||
let num_queued_moved = num_queued.clone();
|
||||
|
||||
let handle = std::thread::spawn(move || {
|
||||
let mut last_analyses = None;
|
||||
|
||||
loop {
|
||||
let payload = match rec.recv() {
|
||||
Err(_err) => {
|
||||
// A client has panic, safe to skip as it may be normal.
|
||||
continue;
|
||||
}
|
||||
Ok(payload) => payload,
|
||||
};
|
||||
match payload {
|
||||
Message::Register(payload) => {
|
||||
last_analyses = Some(payload);
|
||||
num_queued_moved.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
Message::Check(callback) => {
|
||||
if num_queued_moved.load(Ordering::Relaxed) > 1 {
|
||||
callback.send(MemoryReport::NotReady).unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
// We assume that if nothing has been registered in the last second
|
||||
// while being at a count of 1, it's the end.
|
||||
std::thread::sleep(Duration::from_secs(5));
|
||||
|
||||
if num_queued_moved.load(Ordering::Relaxed) <= 1 {
|
||||
match last_analyses.take() {
|
||||
Some(val) => {
|
||||
callback.send(Self::final_check(val)).unwrap();
|
||||
}
|
||||
None => {
|
||||
callback
|
||||
.send(MemoryReport::Fail("No analyses".into()))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
callback.send(MemoryReport::NotReady).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
sender,
|
||||
num_queued,
|
||||
_handle: Arc::new(handle),
|
||||
}
|
||||
}
|
||||
|
||||
fn final_check(analyses: StreamAnalyses) -> MemoryReport {
|
||||
if !analyses.streams.is_empty() || analyses.num_handles > 0 {
|
||||
return MemoryReport::Fail(format!("{analyses}"));
|
||||
}
|
||||
|
||||
MemoryReport::Success
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
pub(crate) mod execution;
|
||||
pub(crate) mod queue;
|
||||
pub(crate) mod shared_tensors;
|
||||
pub(crate) mod store;
|
||||
|
||||
#[cfg(feature = "memory-checks")]
|
||||
/// Memory checks module.
|
||||
pub mod memory_checks;
|
||||
|
||||
#[cfg(not(feature = "memory-checks"))]
|
||||
#[macro_export]
|
||||
/// Export memory checks tests.
|
||||
macro_rules! memory_checks {
|
||||
() => {
|
||||
#[cfg(test)]
|
||||
mod memory_checks {
|
||||
#[ignore = "'memory-checks' disabled"]
|
||||
#[test]
|
||||
fn test_memory_leaks() {
|
||||
//
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mod base;
|
||||
mod context;
|
||||
mod multi;
|
||||
|
||||
pub use base::*;
|
||||
pub use context::*;
|
||||
pub use execution::*;
|
||||
pub use multi::*;
|
||||
@@ -0,0 +1,472 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus};
|
||||
use hashbrown::{HashMap, HashSet};
|
||||
|
||||
use super::{
|
||||
StreamId,
|
||||
execution::{ExecutionMode, Operation, Processor, StreamSegment},
|
||||
queue::OperationQueue,
|
||||
shared_tensors::SharedTensors,
|
||||
store::{ExecutionPlanId, ExecutionPlanStore},
|
||||
};
|
||||
use crate::{
|
||||
DropOp, FusionRuntime,
|
||||
stream::shared_tensors::{SharedTensorAnalysis, SharedTensorDropAction},
|
||||
};
|
||||
|
||||
/// Keep track of multiple concurrent lazy streams of operations.
|
||||
pub struct MultiStream<R: FusionRuntime> {
|
||||
streams: HashMap<StreamId, Stream<R>>,
|
||||
optimizations: ExecutionPlanStore<R::Optimization>,
|
||||
shared_tensors: SharedTensors,
|
||||
device: R::FusionDevice,
|
||||
#[cfg(feature = "memory-checks")]
|
||||
memory_checks: super::memory_checks::MemoryChecks,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum DropAction {
|
||||
SkipSharedTensor,
|
||||
ForceSharedTensor(Vec<StreamId>, TensorId),
|
||||
ContinueDrop,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> MultiStream<R> {
|
||||
pub(crate) fn new(device: R::FusionDevice) -> Self {
|
||||
Self {
|
||||
streams: HashMap::new(),
|
||||
optimizations: ExecutionPlanStore::new(),
|
||||
shared_tensors: SharedTensors::default(),
|
||||
device,
|
||||
#[cfg(feature = "memory-checks")]
|
||||
memory_checks: super::memory_checks::MemoryChecks::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new tensor operation.
|
||||
pub(crate) fn register(
|
||||
&mut self,
|
||||
streams: OperationStreams,
|
||||
mut repr: OperationIr,
|
||||
operation: Arc<dyn Operation<R>>,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
) {
|
||||
let id = self.resolve_streams(&streams, handles, &mut repr);
|
||||
|
||||
let drop_action = match &mut repr {
|
||||
OperationIr::Drop(tensor_ir) => Some(self.handle_drop_op(id, tensor_ir)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let sync = match drop_action {
|
||||
Some(DropAction::SkipSharedTensor) => return,
|
||||
Some(DropAction::ContinueDrop) => true,
|
||||
Some(DropAction::ForceSharedTensor(stream_ids, tid)) => {
|
||||
for stream_id in stream_ids {
|
||||
if let Some(stream) = self.streams.get_mut(&stream_id) {
|
||||
stream.queue.variables.remove(&tid);
|
||||
if stream.queue.variables.is_empty() {
|
||||
self.streams.remove(&stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
None => false,
|
||||
};
|
||||
|
||||
let num_executed = self.enqueue_operation(id, repr, &streams, operation, handles);
|
||||
|
||||
if num_executed > 0
|
||||
&& let Some(stream) = self.streams.get_mut(&id)
|
||||
{
|
||||
let cleared = self.shared_tensors.on_executed_ops(id, stream);
|
||||
self.clear_shared_tensors(&cleared, id);
|
||||
let to_drop = self.shared_tensors.clear_tensors(cleared);
|
||||
self.drop_shared_tensors(to_drop, handles, id);
|
||||
}
|
||||
|
||||
let stream = match self.streams.get(&id) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
#[cfg(feature = "memory-checks")]
|
||||
self.memory_checks.check(&self.streams, handles);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !stream.queue.variables.is_empty() && sync {
|
||||
// Not draining the queue can cause a memory leak when a stream is closing.
|
||||
self.drain(handles, id);
|
||||
}
|
||||
|
||||
#[cfg(feature = "memory-checks")]
|
||||
self.memory_checks.check(&self.streams, handles);
|
||||
}
|
||||
|
||||
/// Checks if the current operation is a drop.
|
||||
///
|
||||
/// When a tensor is shared across multiple concurrent streams, dropping a tensor might cause a
|
||||
/// problem when the same tensor is registered lazily on another stream, but not yet executed.
|
||||
fn handle_drop_op(&mut self, id: StreamId, tensor_ir: &mut TensorIr) -> DropAction {
|
||||
match !matches!(tensor_ir.status, TensorStatus::ReadWrite) {
|
||||
true => {
|
||||
let stream = self.streams.get(&id);
|
||||
let on_drop = self
|
||||
.shared_tensors
|
||||
.on_drop(id, tensor_ir.id, stream.is_none());
|
||||
|
||||
match on_drop {
|
||||
SharedTensorDropAction::ForceDrop(streams) => {
|
||||
tensor_ir.status = TensorStatus::ReadWrite;
|
||||
DropAction::ForceSharedTensor(streams, tensor_ir.id)
|
||||
}
|
||||
SharedTensorDropAction::Skip => DropAction::SkipSharedTensor,
|
||||
}
|
||||
}
|
||||
false => DropAction::ContinueDrop,
|
||||
}
|
||||
}
|
||||
|
||||
/// Enqueue an operation on the queue.
|
||||
fn enqueue_operation(
|
||||
&mut self,
|
||||
id: StreamId,
|
||||
repr: OperationIr,
|
||||
streams: &OperationStreams,
|
||||
operation: Arc<dyn Operation<R>>,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
) -> usize {
|
||||
let stream = match self.streams.get_mut(&id) {
|
||||
Some(stream) => stream,
|
||||
None => {
|
||||
let stream = Stream::new(self.device.clone());
|
||||
self.streams.insert(id, stream);
|
||||
self.streams
|
||||
.get_mut(&id)
|
||||
.expect("Just added, so should be included in the hashmap.")
|
||||
}
|
||||
};
|
||||
|
||||
stream.queue.add(repr, operation, streams, id);
|
||||
|
||||
let len_before = stream.queue.global.len();
|
||||
stream.processor.process(
|
||||
Segment::new(&mut stream.queue, handles),
|
||||
&mut self.optimizations,
|
||||
ExecutionMode::Lazy,
|
||||
);
|
||||
let len_after = stream.queue.global.len();
|
||||
let num_executed = len_before - len_after;
|
||||
|
||||
stream.cursor += num_executed as u64;
|
||||
|
||||
num_executed
|
||||
}
|
||||
|
||||
/// Mark a tensor as read.
|
||||
#[allow(unused_variables)]
|
||||
pub fn mark_read(
|
||||
&mut self,
|
||||
id: StreamId,
|
||||
ir: &TensorIr,
|
||||
handles: &HandleContainer<R::FusionHandle>,
|
||||
) {
|
||||
if !matches!(ir.status, TensorStatus::ReadWrite) {
|
||||
return;
|
||||
};
|
||||
|
||||
let stream = match self.streams.get_mut(&id) {
|
||||
Some(val) => val,
|
||||
None => return,
|
||||
};
|
||||
|
||||
stream.queue.variables.remove(&ir.id);
|
||||
|
||||
if stream.queue.variables.is_empty() {
|
||||
self.streams.remove(&id);
|
||||
}
|
||||
|
||||
#[cfg(feature = "memory-checks")]
|
||||
self.memory_checks.check(&self.streams, handles);
|
||||
}
|
||||
|
||||
/// Drain a stream
|
||||
pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
|
||||
if let Some(stream) = self.streams.get_mut(&id) {
|
||||
let old = unsafe { StreamId::swap(id) };
|
||||
let num_executed = stream.queue.global.len();
|
||||
stream.processor.process(
|
||||
Segment::new(&mut stream.queue, handles),
|
||||
&mut self.optimizations,
|
||||
ExecutionMode::Sync,
|
||||
);
|
||||
stream.cursor += num_executed as u64;
|
||||
|
||||
let cleared = self.shared_tensors.on_executed_ops(id, stream);
|
||||
self.clear_shared_tensors(&cleared, id);
|
||||
let to_drop = self.shared_tensors.clear_tensors(cleared);
|
||||
|
||||
self.drop_shared_tensors(to_drop, handles, id);
|
||||
unsafe {
|
||||
StreamId::swap(old);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// When one of the provided streams is different from the current stream, we drain them.
|
||||
///
|
||||
/// Returns the selected stream id.
|
||||
fn resolve_streams(
|
||||
&mut self,
|
||||
streams: &OperationStreams,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
op: &mut OperationIr,
|
||||
) -> StreamId {
|
||||
let current = streams.current;
|
||||
let nodes = op.nodes();
|
||||
|
||||
let analysis = self.analyse_shared_tensors(&nodes, streams, current);
|
||||
|
||||
self.merge_streams_timelines(handles, &analysis, current, &nodes);
|
||||
self.register_shared_tensors_drop(&analysis, op);
|
||||
|
||||
current
|
||||
}
|
||||
|
||||
/// Drain the stream only if one of the tensor in the given nodes is also included in the
|
||||
/// stream queue.
|
||||
fn resolve_stream(
|
||||
&mut self,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
id: StreamId,
|
||||
nodes: &[&TensorIr],
|
||||
) {
|
||||
if let Some(stream) = self.streams.get(&id) {
|
||||
for node in nodes {
|
||||
if stream.queue.variables.contains_key(&node.id) {
|
||||
self.drain(handles, id);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn analyse_shared_tensors(
|
||||
&mut self,
|
||||
nodes: &[&TensorIr],
|
||||
streams: &OperationStreams,
|
||||
current: StreamId,
|
||||
) -> MultiSharedTensorAnalysis {
|
||||
let mut shared_analysis = MultiSharedTensorAnalysis::default();
|
||||
|
||||
for node in nodes.iter() {
|
||||
let analysis = self
|
||||
.shared_tensors
|
||||
.analyse(current, node, streams, &self.streams);
|
||||
match analysis {
|
||||
SharedTensorAnalysis::SharedFromCurrentStream => {
|
||||
shared_analysis.current.push(node.id);
|
||||
}
|
||||
SharedTensorAnalysis::NotShared => {}
|
||||
SharedTensorAnalysis::SharedFromExistingStream {
|
||||
stream_id,
|
||||
original_cursor,
|
||||
} => {
|
||||
shared_analysis
|
||||
.existing
|
||||
.push((node.id, stream_id, original_cursor));
|
||||
}
|
||||
SharedTensorAnalysis::SharedFromNewStream { stream_id } => {
|
||||
shared_analysis.new.push((node.id, stream_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shared_analysis
|
||||
}
|
||||
|
||||
fn merge_streams_timelines(
|
||||
&mut self,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
analysis: &MultiSharedTensorAnalysis,
|
||||
current: StreamId,
|
||||
nodes: &[&TensorIr],
|
||||
) {
|
||||
// If we only have current tensors that are shared, we're safe to not sync the timelines.
|
||||
if analysis.new.is_empty() && analysis.existing.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut streams_to_sync = HashSet::new();
|
||||
for (_tensor_id, stream_id) in analysis.new.iter() {
|
||||
streams_to_sync.insert(*stream_id);
|
||||
}
|
||||
|
||||
for (_tensor_id, stream_id, original_cursor) in analysis.existing.iter() {
|
||||
if let Some(stream) = self.streams.get(stream_id) {
|
||||
// We only have to sync a stream when the stream isn't up to date with
|
||||
// the original cursor of the current operation.
|
||||
if stream.cursor <= *original_cursor && *stream_id != current {
|
||||
streams_to_sync.insert(*stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for id in streams_to_sync.drain() {
|
||||
log::trace!("Drain stream {id} for use in current {current}");
|
||||
self.resolve_stream(handles, id, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
fn register_shared_tensors_drop(
|
||||
&mut self,
|
||||
analysis: &MultiSharedTensorAnalysis,
|
||||
op: &mut OperationIr,
|
||||
) {
|
||||
let mut readonly_tensors = Vec::new();
|
||||
|
||||
for (tensor_id, _stream_id) in analysis.new.iter() {
|
||||
readonly_tensors.push(*tensor_id);
|
||||
}
|
||||
for (tensor_id, _stream_id, _cursor) in analysis.existing.iter() {
|
||||
readonly_tensors.push(*tensor_id);
|
||||
}
|
||||
for tensor_id in analysis.current.iter() {
|
||||
readonly_tensors.push(*tensor_id);
|
||||
}
|
||||
|
||||
self.shared_tensors
|
||||
.tag_manual_drop(op.mark_read_only(&readonly_tensors));
|
||||
}
|
||||
|
||||
fn drop_shared_tensors(
|
||||
&mut self,
|
||||
tensors: Vec<TensorIr>,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
current: StreamId,
|
||||
) {
|
||||
for (stream_id, s) in self.streams.iter_mut() {
|
||||
for tensor in tensors.iter() {
|
||||
if let Some((original, _status)) = s.queue.variables.get(&tensor.id)
|
||||
&& original != stream_id
|
||||
{
|
||||
s.queue.variables.remove(&tensor.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
for tensor in tensors {
|
||||
let streams = OperationStreams {
|
||||
streams: HashMap::new(),
|
||||
current,
|
||||
};
|
||||
|
||||
let op = Arc::new(DropOp { id: tensor.id });
|
||||
self.register(streams, OperationIr::Drop(tensor), op, handles);
|
||||
}
|
||||
}
|
||||
fn clear_shared_tensors(&mut self, tensors: &[TensorId], current: StreamId) {
|
||||
let mut to_remove = Vec::new();
|
||||
for (stream_id, s) in self.streams.iter_mut() {
|
||||
for tensor in tensors.iter() {
|
||||
s.queue.variables.remove(tensor);
|
||||
}
|
||||
|
||||
if s.queue.variables.is_empty() && current != *stream_id {
|
||||
to_remove.push(*stream_id);
|
||||
}
|
||||
}
|
||||
|
||||
for s in to_remove {
|
||||
self.streams.remove(&s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Stream<R: FusionRuntime> {
|
||||
pub(crate) queue: OperationQueue<R>,
|
||||
processor: Processor<R::Optimization>,
|
||||
pub(crate) cursor: u64,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct Segment<'a, R: FusionRuntime> {
|
||||
queue: &'a mut OperationQueue<R>,
|
||||
handles: &'a mut HandleContainer<R::FusionHandle>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> StreamSegment<R::Optimization> for Segment<'_, R> {
|
||||
fn operations(&self) -> &[OperationIr] {
|
||||
&self.queue.relative
|
||||
}
|
||||
|
||||
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<R::Optimization>) {
|
||||
self.queue.execute(id, self.handles, store)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> Stream<R> {
|
||||
fn new(device: R::FusionDevice) -> Self {
|
||||
Self {
|
||||
processor: Processor::new(R::fusers(device)),
|
||||
queue: OperationQueue::new(),
|
||||
cursor: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Manage the streams used for the current [operation](OperationIr).
|
||||
pub struct OperationStreams {
|
||||
pub(crate) streams: HashMap<TensorId, StreamId>,
|
||||
pub(crate) current: StreamId,
|
||||
}
|
||||
|
||||
impl Default for OperationStreams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
streams: HashMap::new(),
|
||||
current: StreamId::current(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationStreams {
|
||||
/// Register a tensor in the list of streams used for the current [operation](OperationIr).
|
||||
///
|
||||
/// You only need to register input tensors, not the outputs.
|
||||
/// So init tensor operations should have no streams registered.
|
||||
pub fn tensor<R: FusionRuntime>(&mut self, tensor: &crate::FusionTensor<R>) {
|
||||
self.streams.insert(tensor.id, tensor.stream);
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, id: TensorId) -> Option<StreamId> {
|
||||
self.streams.get(&id).cloned()
|
||||
}
|
||||
|
||||
/// Create new operation streams with the given inputs.
|
||||
///
|
||||
/// The inputs are automatically registered.
|
||||
pub fn with_inputs<'a, R: FusionRuntime + 'a, I>(tensors: I) -> Self
|
||||
where
|
||||
I: IntoIterator<Item = &'a crate::FusionTensor<R>>,
|
||||
{
|
||||
let mut streams = OperationStreams::default();
|
||||
for tensor in tensors.into_iter() {
|
||||
streams.tensor(tensor)
|
||||
}
|
||||
streams
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct MultiSharedTensorAnalysis {
|
||||
/// Tensors that are shared with other streams, but we're currently executing on the same stream
|
||||
/// the tensor was originally created.
|
||||
current: Vec<TensorId>,
|
||||
/// Tensors that are shared with new streams.
|
||||
new: Vec<(TensorId, StreamId)>,
|
||||
/// Tensors that are shared with existing streams.
|
||||
existing: Vec<(TensorId, StreamId, u64)>,
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::FusionRuntime;
|
||||
use crate::stream::{OperationConverter, OperationStreams, RelativeOps, execution::Operation};
|
||||
use burn_backend::StreamId;
|
||||
use burn_ir::{OperationIr, TensorId, TensorStatus};
|
||||
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// A growing list of [tensor operation descriptions](OperationIr).
|
||||
pub struct OperationQueue<R: FusionRuntime> {
|
||||
/// List of operation descriptions. These contain the exact tensor IDs
|
||||
/// and shapes so that kernels can be run correctly.
|
||||
///
|
||||
/// The length of this list is the same as the length of the `operations` list.
|
||||
pub(crate) global: Vec<OperationIr>,
|
||||
/// List of operation descriptions. The tensor IDs and shapes are relative
|
||||
/// because we don't need to know the exact values, but they are sufficient to
|
||||
/// determine which operations can be fused.
|
||||
pub(crate) relative: Vec<OperationIr>,
|
||||
pub(crate) converter: OperationConverter,
|
||||
pub(crate) operations: Vec<Arc<dyn Operation<R>>>,
|
||||
pub(crate) variables: HashMap<TensorId, (StreamId, TensorStatus)>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> Default for OperationQueue<R> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> OperationQueue<R> {
|
||||
/// Create a new empty queue.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
global: Vec::new(),
|
||||
relative: Vec::new(),
|
||||
converter: OperationConverter::default(),
|
||||
operations: Vec::new(),
|
||||
variables: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new tensor operation to the queue.
|
||||
///
|
||||
/// The new [operation intermediate representation](OperationIr) will be converted to a local
|
||||
/// representation that can be reused when the same pattern emerge in different but similar
|
||||
/// scenario, so that the same optimization can be used.
|
||||
pub fn add(
|
||||
&mut self,
|
||||
global: OperationIr,
|
||||
operation: Arc<dyn Operation<R>>,
|
||||
streams: &OperationStreams,
|
||||
current: StreamId,
|
||||
) {
|
||||
for node in global.nodes() {
|
||||
if let Some(stream_id) = streams.get(node.id) {
|
||||
self.variables.insert(node.id, (stream_id, node.status));
|
||||
} else {
|
||||
self.variables.insert(node.id, (current, node.status));
|
||||
}
|
||||
}
|
||||
let relative = global.to_relative(&mut self.converter);
|
||||
self.relative.push(relative);
|
||||
self.global.push(global);
|
||||
self.operations.push(operation);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stream_id_from_different_threads() {
|
||||
let current = StreamId::current();
|
||||
|
||||
let thread1 = std::thread::spawn(|| (StreamId::current(), StreamId::current()));
|
||||
let thread2 = std::thread::spawn(StreamId::current);
|
||||
|
||||
let (stream_1, stream_11) = thread1.join().unwrap();
|
||||
let stream_2 = thread2.join().unwrap();
|
||||
|
||||
assert_ne!(current, stream_1, "Should be different from thread 1");
|
||||
assert_ne!(current, stream_2, "Should be different from thread 2");
|
||||
assert_ne!(
|
||||
stream_1, stream_2,
|
||||
"Should be different from different threads"
|
||||
);
|
||||
assert_eq!(
|
||||
stream_1, stream_11,
|
||||
"Should be the same, since same thread."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use burn_ir::{HandleContainer, TensorStatus};
|
||||
|
||||
use crate::{
|
||||
FusionRuntime,
|
||||
search::BlockOptimization,
|
||||
stream::{
|
||||
Context, Operation, OperationConverter, OrderedExecution, RelativeOps,
|
||||
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},
|
||||
},
|
||||
};
|
||||
|
||||
use super::OperationQueue;
|
||||
|
||||
impl<R: FusionRuntime> OperationQueue<R> {
|
||||
/// Execute the queue partially following the execution strategy from the plan.
|
||||
pub(crate) fn execute(
|
||||
&mut self,
|
||||
id: ExecutionPlanId,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
store: &mut ExecutionPlanStore<R::Optimization>,
|
||||
) {
|
||||
let plan = store.get_mut_unchecked(id);
|
||||
self.execute_block_optimization(&mut plan.optimization, handles);
|
||||
}
|
||||
|
||||
fn execute_block_optimization(
|
||||
&mut self,
|
||||
step: &mut BlockOptimization<R::Optimization>,
|
||||
handles: &mut HandleContainer<R::FusionHandle>,
|
||||
) {
|
||||
let mut operations = Vec::new();
|
||||
core::mem::swap(&mut operations, &mut self.operations);
|
||||
let (operations, num_drained) =
|
||||
QueueExecution::run(step, &mut self.converter, handles, operations);
|
||||
|
||||
self.operations = operations;
|
||||
self.drain_queue(num_drained, handles);
|
||||
}
|
||||
|
||||
/// Bookkeeping after executing `num_drained` operations from the queue.
|
||||
fn drain_queue(&mut self, num_drained: usize, handles: &mut HandleContainer<R::FusionHandle>) {
|
||||
self.global[0..num_drained]
|
||||
.iter()
|
||||
.flat_map(|desc| desc.nodes())
|
||||
.for_each(|tensor| {
|
||||
if tensor.status == TensorStatus::ReadWrite {
|
||||
self.variables.remove(&tensor.id);
|
||||
};
|
||||
handles.free(tensor)
|
||||
});
|
||||
|
||||
self.global.drain(0..num_drained);
|
||||
|
||||
self.reset_relative();
|
||||
}
|
||||
|
||||
fn reset_relative(&mut self) {
|
||||
self.relative.clear();
|
||||
self.converter.clear();
|
||||
|
||||
for node in self.global.iter() {
|
||||
let relative = node.to_relative(&mut self.converter);
|
||||
self.relative.push(relative);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A queue execution has the responsibility to run the provided
|
||||
/// [optimization](FusionRuntime::Optimization) without holes.
|
||||
enum QueueExecution<'a, R: FusionRuntime> {
|
||||
Single {
|
||||
handles: &'a mut HandleContainer<R::FusionHandle>,
|
||||
converter: &'a mut OperationConverter,
|
||||
execution: OrderedExecution<R>,
|
||||
},
|
||||
Multiple {
|
||||
context: &'a mut Context<'a, R::FusionHandle>,
|
||||
execution: OrderedExecution<R>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a, R: FusionRuntime> QueueExecution<'a, R> {
|
||||
fn run(
|
||||
optimization: &mut BlockOptimization<R::Optimization>,
|
||||
converter: &'a mut OperationConverter,
|
||||
handles: &'a mut HandleContainer<R::FusionHandle>,
|
||||
operations: Vec<Arc<dyn Operation<R>>>,
|
||||
) -> (Vec<Arc<dyn Operation<R>>>, usize) {
|
||||
let execution = OrderedExecution::new(operations);
|
||||
|
||||
if matches!(&optimization.strategy, ExecutionStrategy::Composed(..)) {
|
||||
let mut context = converter.context(handles);
|
||||
let mut this = QueueExecution::Multiple {
|
||||
context: &mut context,
|
||||
execution,
|
||||
};
|
||||
|
||||
this = this.execute_strategy(&mut optimization.strategy);
|
||||
|
||||
match this {
|
||||
QueueExecution::Multiple { execution, .. } => execution.finish(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
let mut this = QueueExecution::Single {
|
||||
handles,
|
||||
converter,
|
||||
execution,
|
||||
};
|
||||
this = this.execute_strategy(&mut optimization.strategy);
|
||||
|
||||
match this {
|
||||
QueueExecution::Single { execution, .. } => execution.finish(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn execute_strategy(mut self, strategy: &mut ExecutionStrategy<R::Optimization>) -> Self {
|
||||
match &mut self {
|
||||
QueueExecution::Single {
|
||||
handles,
|
||||
converter,
|
||||
execution,
|
||||
} => match strategy {
|
||||
ExecutionStrategy::Optimization { ordering, opt } => {
|
||||
let mut context = converter.context(handles);
|
||||
execution.execute_optimization(opt, &mut context, ordering.clone())
|
||||
}
|
||||
ExecutionStrategy::Operations { ordering } => {
|
||||
execution.execute_operations(handles, ordering)
|
||||
}
|
||||
ExecutionStrategy::Composed(_) => unreachable!(),
|
||||
},
|
||||
QueueExecution::Multiple { context, execution } => match strategy {
|
||||
ExecutionStrategy::Optimization { opt, ordering } => {
|
||||
execution.execute_optimization(opt, context, ordering.clone());
|
||||
}
|
||||
ExecutionStrategy::Operations { ordering } => {
|
||||
execution.execute_operations(context.handles, ordering);
|
||||
}
|
||||
ExecutionStrategy::Composed(items) => {
|
||||
for item in items.iter_mut() {
|
||||
self = self.execute_strategy(item);
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
mod base;
|
||||
mod execution;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,306 @@
|
||||
use burn_backend::StreamId;
|
||||
use burn_ir::{TensorId, TensorIr};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use super::{OperationStreams, Stream};
|
||||
use crate::FusionRuntime;
|
||||
|
||||
#[derive(Default)]
|
||||
/// Manages tensors that are shared between multiple streams.
|
||||
pub struct SharedTensors {
|
||||
shared_tensors: HashMap<TensorId, SharedTensor>,
|
||||
shared_tensors_manual_drop: HashMap<TensorId, TensorIr>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
/// A tensor that is shared between multiple streams.
|
||||
struct SharedTensor {
|
||||
streams: HashMap<StreamId, SharedTensorState>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SharedTensorState {
|
||||
cursor_current: u64,
|
||||
cursor_origin: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// What do to when a tensor is dropped.
|
||||
pub enum SharedTensorDropAction {
|
||||
/// Performs the drop and removes the shared tensor from the provided list of
|
||||
/// stream ids.
|
||||
ForceDrop(Vec<StreamId>),
|
||||
/// Skip the drop.
|
||||
Skip,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Information about a shared tensor.
|
||||
pub enum SharedTensorAnalysis {
|
||||
/// The tensor is not shared.
|
||||
NotShared,
|
||||
/// The tensor is shared, but its original stream is the current one.
|
||||
SharedFromCurrentStream,
|
||||
/// The tensor is shared, and its original stream is an existing stream.
|
||||
SharedFromExistingStream {
|
||||
/// The stream id of the existing stream.
|
||||
stream_id: StreamId,
|
||||
/// The position of execution in the existing stream where the tensor was created.
|
||||
original_cursor: u64,
|
||||
},
|
||||
/// The tensor is shared, and its original stream is a new one without any operation
|
||||
/// executed.
|
||||
SharedFromNewStream {
|
||||
/// The stream id of the new stream.
|
||||
stream_id: StreamId,
|
||||
},
|
||||
}
|
||||
|
||||
impl SharedTensors {
|
||||
/// Function to call when a drop operation is registered on the given stream and tensor.
|
||||
pub fn on_drop(
|
||||
&mut self,
|
||||
stream_id: StreamId,
|
||||
tensor_id: TensorId,
|
||||
stream_completed: bool,
|
||||
) -> SharedTensorDropAction {
|
||||
let mut execute_still = false;
|
||||
|
||||
if let Some(shared) = self.shared_tensors.get_mut(&tensor_id) {
|
||||
if stream_completed {
|
||||
shared.drop(stream_id);
|
||||
execute_still = shared.streams.is_empty();
|
||||
}
|
||||
} else {
|
||||
execute_still = true;
|
||||
}
|
||||
|
||||
if execute_still {
|
||||
let state = self.shared_tensors.remove(&tensor_id);
|
||||
self.shared_tensors_manual_drop.remove(&tensor_id);
|
||||
|
||||
return match state {
|
||||
Some(val) => {
|
||||
let streams = val.streams.keys().copied().collect();
|
||||
SharedTensorDropAction::ForceDrop(streams)
|
||||
}
|
||||
None => SharedTensorDropAction::ForceDrop(Vec::new()),
|
||||
};
|
||||
}
|
||||
|
||||
SharedTensorDropAction::Skip
|
||||
}
|
||||
|
||||
/// Function to call when one or many operations were executed on the stream.
|
||||
///
|
||||
/// Returns the tensor id that can be cleared with [Self::clear_tensors]
|
||||
pub fn on_executed_ops<R: FusionRuntime>(
|
||||
&mut self,
|
||||
id: StreamId,
|
||||
stream: &mut Stream<R>,
|
||||
) -> Vec<TensorId> {
|
||||
let mut cleared = Vec::new();
|
||||
for (tensor_id, state) in self.shared_tensors.iter_mut() {
|
||||
match state.update(id, stream) {
|
||||
SharedTensorUpdate::RemovedFromStream(no_more_stream) => {
|
||||
stream.queue.variables.remove(tensor_id);
|
||||
|
||||
if no_more_stream {
|
||||
cleared.push(*tensor_id);
|
||||
}
|
||||
}
|
||||
SharedTensorUpdate::ReadyForCleanup => {
|
||||
cleared.push(*tensor_id);
|
||||
}
|
||||
SharedTensorUpdate::NoChange => {}
|
||||
}
|
||||
}
|
||||
cleared
|
||||
}
|
||||
|
||||
/// Clear the provided tensors and returns the list of tensors that can be manually dropped.
|
||||
pub fn clear_tensors(&mut self, tensors: Vec<TensorId>) -> Vec<TensorIr> {
|
||||
let mut to_drop = Vec::new();
|
||||
for id in tensors {
|
||||
self.shared_tensors.remove(&id);
|
||||
|
||||
if let Some(tensor) = self.shared_tensors_manual_drop.remove(&id) {
|
||||
to_drop.push(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
self.register_manual_drop(to_drop)
|
||||
}
|
||||
|
||||
/// Analyses the current tensor and updates its state.
|
||||
pub fn analyse<R: FusionRuntime>(
|
||||
&mut self,
|
||||
id: StreamId,
|
||||
node: &TensorIr,
|
||||
streams_op: &OperationStreams,
|
||||
streams: &HashMap<StreamId, Stream<R>>,
|
||||
) -> SharedTensorAnalysis {
|
||||
let stream_id = match streams_op.streams.get(&node.id) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
return match self.shared_tensors.contains_key(&node.id) {
|
||||
true => SharedTensorAnalysis::SharedFromCurrentStream,
|
||||
false => SharedTensorAnalysis::NotShared,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if stream_id == &id {
|
||||
return match self.shared_tensors.contains_key(&node.id) {
|
||||
true => SharedTensorAnalysis::SharedFromCurrentStream,
|
||||
false => SharedTensorAnalysis::NotShared,
|
||||
};
|
||||
}
|
||||
|
||||
// Here the node is tagged as newly shared.
|
||||
let stream_current = streams.get(&id);
|
||||
let stream = streams.get(stream_id);
|
||||
|
||||
let state = match self.shared_tensors.get_mut(&node.id) {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
self.shared_tensors.insert(node.id, SharedTensor::default());
|
||||
self.shared_tensors.get_mut(&node.id).unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
state.register_new_stream(id, stream_current);
|
||||
match state.register_new_stream(*stream_id, stream) {
|
||||
Some(origin) => SharedTensorAnalysis::SharedFromExistingStream {
|
||||
stream_id: *stream_id,
|
||||
original_cursor: origin,
|
||||
},
|
||||
None => SharedTensorAnalysis::SharedFromNewStream {
|
||||
stream_id: *stream_id,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Tag the provided tensors as manually dropped.
|
||||
pub fn tag_manual_drop(&mut self, dropped: Vec<TensorIr>) {
|
||||
for tensor in dropped {
|
||||
self.shared_tensors_manual_drop.insert(tensor.id, tensor);
|
||||
}
|
||||
}
|
||||
|
||||
fn register_manual_drop(&mut self, mut tensors: Vec<TensorIr>) -> Vec<TensorIr> {
|
||||
if self.shared_tensors_manual_drop.is_empty() {
|
||||
return tensors;
|
||||
}
|
||||
|
||||
let mut to_drop = Vec::new();
|
||||
for id in self.shared_tensors_manual_drop.keys() {
|
||||
if !self.shared_tensors.contains_key(id) {
|
||||
to_drop.push(*id);
|
||||
}
|
||||
}
|
||||
|
||||
for id in to_drop {
|
||||
let entry = self.shared_tensors_manual_drop.remove(&id).unwrap();
|
||||
tensors.push(entry);
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
/// The result from a [SharedTensor::update].
|
||||
pub enum SharedTensorUpdate {
|
||||
/// The tensor is removed from the current stream.
|
||||
///
|
||||
/// Also contains if the current stream is empty.
|
||||
RemovedFromStream(bool),
|
||||
/// If the tensor is shared across zero streams.
|
||||
ReadyForCleanup,
|
||||
/// If nothing has been done from the update.
|
||||
NoChange,
|
||||
}
|
||||
|
||||
impl SharedTensor {
|
||||
/// Register the tensor as also part of the given stream.
|
||||
///
|
||||
/// The stream might not exist yet when the current tensor is part of the first operation in
|
||||
/// the newly created stream.
|
||||
fn register_new_stream<R: FusionRuntime>(
|
||||
&mut self,
|
||||
id: StreamId,
|
||||
stream: Option<&Stream<R>>,
|
||||
) -> Option<u64> {
|
||||
let cursor_current = match stream {
|
||||
Some(stream) => stream.cursor + stream.queue.global.len() as u64,
|
||||
None => 1,
|
||||
};
|
||||
|
||||
match self.streams.get_mut(&id) {
|
||||
Some(s) => {
|
||||
s.cursor_current = cursor_current;
|
||||
Some(s.cursor_origin)
|
||||
}
|
||||
None => {
|
||||
let state = SharedTensorState {
|
||||
cursor_current,
|
||||
cursor_origin: cursor_current,
|
||||
};
|
||||
self.streams.insert(id, state);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the current shared tensor state on the given stream.
|
||||
///
|
||||
/// If the shared tensor is no longer needed on the stream, we will remove it from the list of
|
||||
/// shared streams.
|
||||
fn update<R: FusionRuntime>(&mut self, id: StreamId, stream: &Stream<R>) -> SharedTensorUpdate {
|
||||
let entry = match self.streams.remove(&id) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
return if self.streams.is_empty() {
|
||||
SharedTensorUpdate::ReadyForCleanup
|
||||
} else {
|
||||
SharedTensorUpdate::NoChange
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// We can only free the shared tensor if the latest cursor is executed.
|
||||
if entry.cursor_current <= stream.cursor {
|
||||
SharedTensorUpdate::RemovedFromStream(self.streams.is_empty())
|
||||
} else {
|
||||
self.streams.insert(id, entry);
|
||||
SharedTensorUpdate::NoChange
|
||||
}
|
||||
}
|
||||
|
||||
fn drop(&mut self, id: StreamId) {
|
||||
self.streams.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for SharedTensors {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("\n==== Shared Tensors ====\n")?;
|
||||
|
||||
for sh in self.shared_tensors.iter() {
|
||||
f.write_fmt(format_args!(" - Shared {}", sh.0))?;
|
||||
for (id, state) in sh.1.streams.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
" [{}, cursor={}..{}] ",
|
||||
id, state.cursor_origin, state.cursor_current
|
||||
))?;
|
||||
}
|
||||
f.write_str("\n")?;
|
||||
}
|
||||
for sh in self.shared_tensors_manual_drop.iter() {
|
||||
f.write_fmt(format_args!(" - Manual Drop {}", sh.0))?;
|
||||
f.write_str("\n")?;
|
||||
}
|
||||
|
||||
f.write_str("========================\n")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::search::BlockOptimization;
|
||||
|
||||
use super::{ExecutionPlanIndex, InsertQuery, SearchQuery};
|
||||
use burn_ir::OperationIr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The store that contains all explorations done on a device.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ExecutionPlanStore<O> {
|
||||
plans: Vec<ExecutionPlan<O>>,
|
||||
index: ExecutionPlanIndex,
|
||||
}
|
||||
|
||||
/// How a list of operations should be executed.
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub(crate) enum ExecutionStrategy<O> {
|
||||
/// An optimization was found, and therefore should be executed.
|
||||
Optimization { opt: O, ordering: Arc<Vec<usize>> },
|
||||
/// No optimization was found, each operation should be executed individually.
|
||||
Operations { ordering: Arc<Vec<usize>> },
|
||||
/// A composition of multiple execution strategies.
|
||||
Composed(Vec<Box<Self>>),
|
||||
}
|
||||
|
||||
/// The trigger that indicates when to stop exploring.
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub(crate) enum ExecutionTrigger {
|
||||
OnOperations(Vec<OperationIr>),
|
||||
OnSync,
|
||||
Always,
|
||||
}
|
||||
|
||||
/// The unique identifier for an exploration that was executed.
|
||||
pub(crate) type ExecutionPlanId = usize;
|
||||
|
||||
/// The outcome of an exploration that can be stored.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ExecutionPlan<O> {
|
||||
/// The operations on which the exploration is related to.
|
||||
pub(crate) operations: Vec<OperationIr>,
|
||||
/// The criteria that signal when this plan should be executed. Only one trigger is necessary.
|
||||
pub(crate) triggers: Vec<ExecutionTrigger>,
|
||||
/// The optimization that should be used when executing this plan.
|
||||
pub(crate) optimization: BlockOptimization<O>,
|
||||
}
|
||||
|
||||
impl<O: core::fmt::Debug> ExecutionPlanStore<O> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
plans: Vec::new(),
|
||||
index: ExecutionPlanIndex::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
|
||||
self.index.find(query)
|
||||
}
|
||||
|
||||
pub fn add(&mut self, exploration: ExecutionPlan<O>) -> ExecutionPlanId {
|
||||
if exploration.operations.is_empty() {
|
||||
panic!("Can't add an empty optimization.");
|
||||
}
|
||||
|
||||
let id = self.plans.len();
|
||||
|
||||
self.index.insert(InsertQuery::NewPlan {
|
||||
operations: &exploration.operations,
|
||||
id,
|
||||
});
|
||||
|
||||
self.plans.push(exploration);
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_mut_unchecked(&mut self, id: ExecutionPlanId) -> &mut ExecutionPlan<O> {
|
||||
&mut self.plans[id]
|
||||
}
|
||||
|
||||
pub fn get_unchecked(&self, id: ExecutionPlanId) -> &ExecutionPlan<O> {
|
||||
&self.plans[id]
|
||||
}
|
||||
|
||||
/// Add a new end condition for an optimization.
|
||||
pub fn add_trigger(&mut self, id: ExecutionPlanId, trigger: ExecutionTrigger) {
|
||||
let criteria = &mut self.plans[id].triggers;
|
||||
|
||||
if !criteria.contains(&trigger) {
|
||||
criteria.push(trigger);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,293 @@
|
||||
use crate::stream::store::ExecutionPlanId;
|
||||
use burn_ir::OperationIr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{HashMap, hash_map::DefaultHasher},
|
||||
hash::{Hash, Hasher},
|
||||
};
|
||||
|
||||
/// Index used to search optimizations.
|
||||
#[derive(Default, Serialize, Deserialize, Clone)]
|
||||
pub struct ExecutionPlanIndex {
|
||||
/// We can't use `HashMap<OperationIr, Vec<ExecutionPlanId>>` since `OperationIr`
|
||||
/// doesn't implement [`Eq`](core::cmp::Eq).
|
||||
///
|
||||
/// `OperationIr` can't implement `Eq` since float types don't implement it.
|
||||
///
|
||||
/// We rely instead on [`PartialEq`](core::cmp::PartialEq) to manually handle hash collisions.
|
||||
/// This is OK because we use `relative` operations where any scalar values are set to zeros,
|
||||
/// see [`RelativeStreamConverter`](crate::stream::RelativeStreamConverter).
|
||||
///
|
||||
/// Map from the hash of the `OperationIr` to a list of `(OperationIr, index)` pairs,
|
||||
/// where `index` is the index of all the execution plans that start with the `OperationIr`
|
||||
/// in the `starters` list.
|
||||
mapping: HashMap<u64, Vec<(OperationIr, usize)>>,
|
||||
starters: Vec<Vec<ExecutionPlanId>>,
|
||||
}
|
||||
|
||||
pub enum SearchQuery<'a> {
|
||||
PlansStartingWith(&'a OperationIr),
|
||||
}
|
||||
|
||||
pub enum InsertQuery<'a> {
|
||||
NewPlan {
|
||||
operations: &'a [OperationIr],
|
||||
id: ExecutionPlanId,
|
||||
},
|
||||
}
|
||||
|
||||
impl ExecutionPlanIndex {
|
||||
/// Search optimizations with the given [query](SearchQuery).
|
||||
pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
|
||||
match query {
|
||||
SearchQuery::PlansStartingWith(ops) => self.find_starting_with(ops),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new optimization with the given [query](InsertQuery).
|
||||
pub fn insert(&mut self, query: InsertQuery<'_>) {
|
||||
match query {
|
||||
InsertQuery::NewPlan { operations, id } => {
|
||||
if let Some(operation) = operations.first() {
|
||||
self.insert_new_operation(operation, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find execution plans starting with the `OperationIr`
|
||||
fn find_starting_with(&self, operation: &OperationIr) -> Vec<ExecutionPlanId> {
|
||||
let key = self.operation_key(operation);
|
||||
let values = match self.mapping.get(&key) {
|
||||
Some(val) => val,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
if values.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let (_, index) = match values.iter().find(|value| &value.0 == operation) {
|
||||
Some(val) => val,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
match self.starters.get(*index) {
|
||||
Some(value) => value.clone(),
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the index for an execution plan starting with operation `ops`
|
||||
fn insert_new_operation(&mut self, ops: &OperationIr, new_id: ExecutionPlanId) {
|
||||
let key = self.operation_key(ops);
|
||||
let values = match self.mapping.get_mut(&key) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// New starter ops.
|
||||
let index = self.starters.len();
|
||||
self.starters.push(vec![new_id]);
|
||||
self.mapping.insert(key, vec![(ops.clone(), index)]);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
let (_, index) = match values.iter_mut().find(|value| &value.0 == ops) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// New with hash collision.
|
||||
let index = self.starters.len();
|
||||
self.starters.push(vec![new_id]);
|
||||
values.push((ops.clone(), index));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// New optimization for an existing starter.
|
||||
self.starters
|
||||
.get_mut(*index)
|
||||
.expect("Should exist")
|
||||
.push(new_id);
|
||||
}
|
||||
|
||||
// Hash the value of the first operation in a list.
|
||||
fn operation_key(&self, ops: &OperationIr) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
ops.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn_backend::{DType, Shape};
|
||||
use burn_ir::{
|
||||
BinaryOpIr, NumericOperationIr, ScalarIr, ScalarOpIr, TensorId, TensorIr, TensorStatus,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_find_optimization_id_based_on_tensor_ops() {
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1()];
|
||||
let optimization_id_1 = 0;
|
||||
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_multiple_optimization_ids_with_same_starting_ops() {
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1(), ops_2(), ops_1()];
|
||||
let stream_2 = [ops_1(), ops_1(), ops_2()];
|
||||
let optimization_id_1 = 0;
|
||||
let optimization_id_2 = 1;
|
||||
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_2,
|
||||
id: optimization_id_2,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1, optimization_id_2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_only_find_optimization_with_correct_starting_ops() {
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1(), ops_1()];
|
||||
let stream_2 = [ops_2(), ops_1()];
|
||||
let optimization_id_1 = 0;
|
||||
let optimization_id_2 = 1;
|
||||
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_2,
|
||||
id: optimization_id_2,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_handle_hash_collisions() {
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1(), ops_1()];
|
||||
let stream_2 = [ops_3(), ops_1()];
|
||||
let optimization_id_1 = 0;
|
||||
let optimization_id_2 = 1;
|
||||
|
||||
let stream_1_key = index.operation_key(&stream_1[0]);
|
||||
let stream_2_key = index.operation_key(&stream_2[0]);
|
||||
|
||||
assert_ne!(
|
||||
stream_1_key, stream_2_key,
|
||||
"Ops 1 and Ops 3 should not have the same hash"
|
||||
); // ops 1 and 3 have different variants, so the hash differs
|
||||
assert_ne!(stream_1[0], stream_2[0], "Ops 1 and Ops 3 are different.");
|
||||
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_2,
|
||||
id: optimization_id_2,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1]);
|
||||
}
|
||||
|
||||
fn ops_1() -> OperationIr {
|
||||
OperationIr::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationIr::Add(BinaryOpIr {
|
||||
lhs: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
rhs: TensorIr {
|
||||
id: TensorId::new(1),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
out: TensorIr {
|
||||
id: TensorId::new(2),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn ops_2() -> OperationIr {
|
||||
OperationIr::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationIr::AddScalar(ScalarOpIr {
|
||||
lhs: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
rhs: ScalarIr::Float(5.0),
|
||||
out: TensorIr {
|
||||
id: TensorId::new(2),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn ops_3() -> OperationIr {
|
||||
OperationIr::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationIr::Sub(BinaryOpIr {
|
||||
lhs: TensorIr {
|
||||
id: TensorId::new(0),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
rhs: TensorIr {
|
||||
id: TensorId::new(1),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
out: TensorIr {
|
||||
id: TensorId::new(2),
|
||||
shape: Shape::new([32, 32]),
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod index;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(super) use index::*;
|
||||
@@ -0,0 +1,232 @@
|
||||
use crate::{
|
||||
Client, FusionBackend, FusionRuntime,
|
||||
stream::{Operation, OperationStreams, StreamId},
|
||||
};
|
||||
use burn_backend::{
|
||||
DType, ExecutionError, QTensorPrimitive, Shape, TensorData, TensorMetadata,
|
||||
quantization::QuantScheme,
|
||||
};
|
||||
use burn_ir::{OperationIr, TensorId, TensorIr, TensorStatus};
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU32, Ordering},
|
||||
};
|
||||
|
||||
/// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind.
|
||||
pub struct FusionTensor<R: FusionRuntime> {
|
||||
/// Tensor id.
|
||||
pub id: TensorId,
|
||||
/// The shape of the tensor.
|
||||
pub shape: Shape,
|
||||
/// The fusion client.
|
||||
pub client: Client<R>,
|
||||
/// The datatype of the tensor.
|
||||
pub dtype: DType,
|
||||
/// The current stream id this tensor is on.
|
||||
pub stream: StreamId,
|
||||
pub(crate) count: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> Clone for FusionTensor<R> {
|
||||
fn clone(&self) -> Self {
|
||||
self.count.fetch_add(1, Ordering::Acquire);
|
||||
|
||||
Self {
|
||||
id: self.id,
|
||||
shape: self.shape.clone(),
|
||||
client: self.client.clone(),
|
||||
dtype: self.dtype,
|
||||
stream: self.stream,
|
||||
count: self.count.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(
|
||||
format!(
|
||||
"{{ id: {:?}, shape: {:?}, device: {:?} }}",
|
||||
self.id,
|
||||
self.shape,
|
||||
self.client.device().clone(),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {
|
||||
fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn shape(&self) -> Shape {
|
||||
self.shape.clone()
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.shape.num_dims()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> FusionTensor<R> {
|
||||
pub(crate) fn new(
|
||||
id: TensorId,
|
||||
shape: Shape,
|
||||
dtype: DType,
|
||||
client: Client<R>,
|
||||
stream: StreamId,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
shape,
|
||||
client,
|
||||
dtype,
|
||||
stream,
|
||||
count: Arc::new(AtomicU32::new(1)),
|
||||
}
|
||||
}
|
||||
|
||||
fn status(&self, count: u32) -> TensorStatus {
|
||||
if count <= 1 {
|
||||
TensorStatus::ReadWrite
|
||||
} else {
|
||||
TensorStatus::ReadOnly
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate representation to be used when using an uninitialized tensor as output.
|
||||
pub fn to_ir_out(&self) -> TensorIr {
|
||||
TensorIr {
|
||||
status: TensorStatus::NotInit,
|
||||
shape: self.shape.clone(),
|
||||
id: self.id,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate representation to be used when using an initialized tensor used as input.
|
||||
pub fn into_ir(mut self) -> TensorIr {
|
||||
let count = self.count.load(Ordering::Acquire);
|
||||
let status = self.status(count);
|
||||
|
||||
let mut shape_out = Shape::from(Vec::<usize>::new());
|
||||
core::mem::swap(&mut self.shape, &mut shape_out);
|
||||
|
||||
if let TensorStatus::ReadWrite = status {
|
||||
// Avoids an unwanted drop on the same thread.
|
||||
//
|
||||
// Since `drop` is called after `into_ir`, we must not register a drop if the tensor
|
||||
// was consumed with a `ReadWrite` status.
|
||||
self.count.fetch_add(1, Ordering::Acquire);
|
||||
}
|
||||
|
||||
TensorIr {
|
||||
status,
|
||||
shape: shape_out,
|
||||
id: self.id,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn into_data<B>(self) -> Result<TensorData, ExecutionError>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let id = self.stream;
|
||||
let client = self.client.clone();
|
||||
let desc = self.into_ir();
|
||||
client.read_tensor_float::<B>(desc, id).await
|
||||
}
|
||||
|
||||
pub(crate) async fn q_into_data<B>(self) -> Result<TensorData, ExecutionError>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
if let DType::QFloat(_scheme) = self.dtype {
|
||||
let id = self.stream;
|
||||
let client = self.client.clone();
|
||||
let desc = self.into_ir();
|
||||
client.read_tensor_quantized::<B>(desc, id).await
|
||||
} else {
|
||||
panic!("Expected quantized float dtype, got {:?}", self.dtype)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn int_into_data<B>(self) -> Result<TensorData, ExecutionError>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let id = self.stream;
|
||||
let client = self.client.clone();
|
||||
let desc = self.into_ir();
|
||||
client.read_tensor_int::<B>(desc, id).await
|
||||
}
|
||||
|
||||
pub(crate) async fn bool_into_data<B>(self) -> Result<TensorData, ExecutionError>
|
||||
where
|
||||
B: FusionBackend<FusionRuntime = R>,
|
||||
{
|
||||
let id = self.stream;
|
||||
let client = self.client.clone();
|
||||
let desc = self.into_ir();
|
||||
client.read_tensor_bool::<B>(desc, id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub(crate) struct DropOp {
|
||||
pub(crate) id: TensorId,
|
||||
}
|
||||
|
||||
impl<RO: FusionRuntime> Operation<RO> for DropOp {
|
||||
fn execute(&self, handles: &mut burn_ir::HandleContainer<RO::FusionHandle>) {
|
||||
handles.remove_handle(self.id);
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> Drop for FusionTensor<R> {
|
||||
fn drop(&mut self) {
|
||||
let count = self.count.fetch_sub(1, Ordering::Acquire);
|
||||
|
||||
// Workaround to prevent segfaults when an operation panics
|
||||
if std::thread::panicking() {
|
||||
return;
|
||||
}
|
||||
|
||||
match self.status(count) {
|
||||
TensorStatus::ReadWrite => {
|
||||
let mut shape = Shape::from(Vec::<usize>::new());
|
||||
core::mem::swap(&mut shape, &mut self.shape);
|
||||
|
||||
let ir = TensorIr {
|
||||
id: self.id,
|
||||
shape,
|
||||
status: TensorStatus::ReadWrite,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let mut streams = OperationStreams::default();
|
||||
streams.tensor(self);
|
||||
|
||||
self.client
|
||||
.register(streams, OperationIr::Drop(ir), DropOp { id: self.id });
|
||||
}
|
||||
TensorStatus::ReadOnly => {}
|
||||
TensorStatus::NotInit => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {
|
||||
fn scheme(&self) -> &QuantScheme {
|
||||
if let DType::QFloat(scheme) = &self.dtype {
|
||||
scheme
|
||||
} else {
|
||||
panic!(
|
||||
"Quantization scheme is not valid for dtype {:?}",
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user