feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,123 @@
pub use super::RemoteDevice;
use super::worker::{ClientRequest, ClientWorker};
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent};
use async_channel::{RecvError, SendError, Sender};
use burn_communication::ProtocolClient;
use burn_ir::TensorId;
use burn_std::id::StreamId;
use std::{
future::Future,
sync::{Arc, atomic::AtomicU64},
};
#[derive(Clone)]
pub struct RemoteClient {
pub(crate) device: RemoteDevice,
pub(crate) sender: Arc<RemoteSender>,
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
}
impl RemoteClient {
pub fn init<C: ProtocolClient>(device: RemoteDevice) -> Self {
ClientWorker::<C>::start(device)
}
pub(crate) fn new(
device: RemoteDevice,
sender: Sender<ClientRequest>,
runtime: Arc<tokio::runtime::Runtime>,
session_id: SessionId,
) -> Self {
Self {
device,
runtime,
sender: Arc::new(RemoteSender {
sender,
position_counter: AtomicU64::new(0),
tensor_id_counter: AtomicU64::new(0),
session_id,
}),
}
}
}
pub(crate) struct RemoteSender {
sender: Sender<ClientRequest>,
position_counter: AtomicU64,
tensor_id_counter: AtomicU64,
session_id: SessionId,
}
#[allow(unused)]
#[derive(Debug)]
pub enum RemoteSendError {
SendError(SendError<ClientRequest>),
RecvError(RecvError),
}
impl RemoteSender {
/// Generate a new unique (for this [`RemoteSender`] [`TensorId`].
pub(crate) fn new_tensor_id(&self) -> TensorId {
TensorId::new(
self.tensor_id_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
)
}
/// Give the next operation sequence number.
fn next_position(&self) -> u64 {
self.position_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
pub(crate) fn send(&self, task: ComputeTask) {
self.sender
.send_blocking(ClientRequest::WithoutCallback(Task::Compute(
task,
ConnectionId::new(self.next_position(), StreamId::current()),
)))
.unwrap();
}
pub(crate) fn send_async(
&self,
task: ComputeTask,
) -> impl Future<Output = Result<TaskResponseContent, RemoteSendError>> + Send + use<> {
let stream_id = StreamId::current();
let position = self.next_position();
let sender = self.sender.clone();
async move {
let (tx, rx) = async_channel::bounded(1);
if let Err(e) = sender
.send(ClientRequest::WithSyncCallback(
Task::Compute(task, ConnectionId::new(position, stream_id)),
tx,
))
.await
{
return Err(RemoteSendError::SendError(e));
}
match rx.recv().await {
Ok(response) => Ok(response),
Err(e) => Err(RemoteSendError::RecvError(e)),
}
}
}
pub(crate) fn close(&mut self) {
let sender = self.sender.clone();
let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id));
sender.send_blocking(close_task).unwrap();
}
}
impl Drop for RemoteSender {
fn drop(&mut self) {
self.close();
}
}

View File

@@ -0,0 +1,82 @@
use std::marker::PhantomData;
use burn_backend::Shape;
use burn_communication::ProtocolClient;
use burn_ir::TensorIr;
use burn_router::{RouterTensor, RunnerChannel, get_client};
use super::{
RemoteClient,
runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle},
};
/// A local channel with direct connection to the backend runner clients.
pub struct RemoteChannel<C: ProtocolClient> {
_p: PhantomData<C>,
}
impl<C: ProtocolClient> RunnerChannel for RemoteChannel<C> {
type Device = RemoteDevice;
type Bridge = RemoteBridge<C>;
type Client = RemoteClient;
type FloatElem = f32;
type IntElem = i32;
type BoolElem = u32;
fn name(device: &Self::Device) -> String {
format!("remote-{device:?}")
}
fn init_client(device: &Self::Device) -> Self::Client {
RemoteClient::init::<C>(device.clone())
}
fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle<C> {
RemoteTensorHandle {
client: client.clone(),
tensor: tensor.clone(),
_p: PhantomData,
}
}
fn register_tensor(
_client: &Self::Client,
_handle: RemoteTensorHandle<C>,
_shape: Shape,
_dtype: burn_backend::DType,
) -> RouterTensor<Self::Client> {
// This function is normally only used to move a tensor from a device to another.
//
// In other words, to change the client.
panic!("Can't register manually a tensor on a remote channel.");
}
fn change_client_backend(
tensor: RouterTensor<Self::Client>,
target_device: &Self::Device, // target device
) -> RouterTensor<Self::Client> {
// Get tensor handle from current client
let original_client = tensor.client.clone();
let desc = tensor.into_ir();
let handle = Self::get_tensor_handle(&desc, &original_client);
let handle = handle.change_backend(target_device);
let id = handle.tensor.id;
let target_client = get_client::<Self>(target_device);
let router_tensor: RouterTensor<RemoteClient> =
RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client);
router_tensor
}
}
impl<C: ProtocolClient> Clone for RemoteChannel<C> {
fn clone(&self) -> Self {
RemoteChannel { _p: PhantomData }
}
}

View File

@@ -0,0 +1,8 @@
mod base;
mod channel;
mod runner;
mod worker;
pub use base::*;
pub use channel::*;
pub use runner::RemoteDevice;

View File

@@ -0,0 +1,294 @@
use super::{RemoteChannel, RemoteClient};
use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
use burn_backend::{DeviceId, DeviceOps, ExecutionError, Shape, TensorData};
use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
use burn_ir::TensorIr;
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
use burn_std::{backtrace::BackTrace, future::DynFut};
use std::sync::OnceLock;
use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
// TODO: we should work with the parsed structure of Address, not the string.
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
ADDRESS_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}
/// Map a string network address to a (local runtime) global unique u32.
///
/// Globally stable over the lifetime of the process, shared between threads,
/// If the address has never been seen, a new id will be created.
/// If the address has been seen, the previous id will be returned.
pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
let registry = get_address_registry();
let mut registry = registry.lock().unwrap();
let next_id = registry.len() as u32;
*registry
.entry(address.as_ref().to_string())
.or_insert_with(|| next_id)
}
/// Look up an address by id.
///
/// Returns the same address given ids by [`address_to_id`].
pub fn id_to_address(id: u32) -> Option<String> {
let registry = get_address_registry();
let registry = registry.lock().unwrap();
for entry in registry.iter() {
if entry.1 == &id {
return Some(entry.0.clone());
}
}
None
}
// It is very important to block on any request made with the sender, since ordering is crucial
// when registering operation or creating tensors.
//
// The overhead is minimal, since we only wait for the task to be sent to the async
// channel, but not sent to the server and even less processed by the server.
impl RunnerClient for RemoteClient {
type Device = RemoteDevice;
fn register_op(&self, op: burn_ir::OperationIr) {
self.sender
.send(ComputeTask::RegisterOperation(Box::new(op)));
}
fn read_tensor_async(
&self,
tensor: burn_ir::TensorIr,
) -> DynFut<Result<TensorData, ExecutionError>> {
// Important for ordering to call the creation of the future sync.
let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
Box::pin(async move {
match fut.await {
Ok(response) => match response {
TaskResponseContent::ReadTensor(res) => res,
_ => panic!("Invalid message type"),
},
Err(e) => Err(ExecutionError::Generic {
reason: format!("Failed to read tensor: {:?}", e),
backtrace: BackTrace::capture(),
}),
}
})
}
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
let id = self.sender.new_tensor_id();
let shape = data.shape.clone();
let dtype = data.dtype;
self.sender.send(ComputeTask::RegisterTensor(id, data));
RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
}
fn device(&self) -> Self::Device {
self.device.clone()
}
fn sync(&self) -> Result<(), ExecutionError> {
// Important for ordering to call the creation of the future sync.
let fut = self.sender.send_async(ComputeTask::SyncBackend);
match self.runtime.block_on(fut) {
Ok(response) => match response {
TaskResponseContent::SyncBackend(res) => res,
_ => panic!("Invalid message type"),
},
Err(e) => Err(ExecutionError::Generic {
reason: format!("Failed to sync: {:?}", e),
backtrace: BackTrace::capture(),
}),
}
}
fn seed(&self, seed: u64) {
self.sender.send(ComputeTask::Seed(seed));
}
fn create_empty_handle(&self) -> burn_ir::TensorId {
self.sender.new_tensor_id()
}
fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype));
match self.runtime.block_on(fut) {
Ok(_response) => panic!("Invalid message type"),
Err(e) => panic!("Failed to check dtype support: {:?}", e),
}
}
}
#[derive(Clone, PartialEq, Eq, Debug)]
/// The device contains the connection information of the server.
pub struct RemoteDevice {
pub(crate) address: Address,
/// The id of the device in the local registry, see [`address_to_id`].
pub(crate) id: u32,
}
impl RemoteDevice {
/// Create a device from an url.
pub fn new(address: &str) -> Self {
let id = address_to_id(address);
Self {
address: Address::from_str(address).unwrap(),
id,
}
}
}
impl Default for RemoteDevice {
fn default() -> Self {
let address = match std::env::var("BURN_REMOTE_ADDRESS") {
Ok(address) => address,
Err(_) => String::from("ws://127.0.0.1:3000"),
};
Self::new(&address)
}
}
impl burn_std::device::Device for RemoteDevice {
fn from_id(device_id: DeviceId) -> Self {
if device_id.type_id != 0 {
panic!("Invalid device id: {device_id} (expected type 0)");
}
let address = id_to_address(device_id.index_id)
.unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
Self::new(&address)
}
fn to_id(&self) -> DeviceId {
DeviceId {
type_id: 0,
index_id: self.id,
}
}
fn device_count(_type_id: u16) -> usize {
1
}
}
impl DeviceOps for RemoteDevice {}
pub struct RemoteBridge<C: ProtocolClient> {
_p: PhantomData<C>,
}
pub struct RemoteTensorHandle<C: ProtocolClient> {
pub(crate) client: RemoteClient,
pub(crate) tensor: TensorIr,
pub(crate) _p: PhantomData<C>,
}
static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
fn get_next_transfer_id() -> TensorTransferId {
let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
if transfer_counter.is_none() {
*transfer_counter = Some(0.into());
transfer_counter.unwrap()
} else {
let mut transfer_counter = transfer_counter.unwrap();
transfer_counter.next();
transfer_counter
}
}
impl<C: ProtocolClient> RemoteTensorHandle<C> {
/// Changes the backend of the tensor via a dWebSocket.
/// We ask the original server to expose the tensor, then ask the target server to fetch
/// the tensor. The target server will open a new network connection to the original server
/// to download the data.
/// This way the client never sees the tensor's data, and we avoid a bottleneck.
pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
let transfer_id = get_next_transfer_id();
self.client.sender.send(ComputeTask::ExposeTensorRemote {
tensor: self.tensor.clone(),
count: 1,
transfer_id,
});
let target_client = get_client::<RemoteChannel<C>>(target_device);
let new_id = target_client.sender.new_tensor_id();
let remote_tensor = TensorRemote {
transfer_id,
address: self.client.device.address.clone(),
};
target_client
.sender
.send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
self.tensor.id = new_id;
self.client = target_client;
self
}
}
impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
type TensorHandle = RemoteTensorHandle<C>;
type Device = RemoteDevice;
fn change_backend_float(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
fn change_backend_int(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
fn change_backend_bool(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_address_to_id() {
let address1 = "ws://127.0.0.1:3000";
let address2 = "ws://127.0.0.1:3001";
let id1 = address_to_id(address1);
let id2 = address_to_id(address2);
assert_ne!(id1, id2);
assert_eq!(address_to_id(address1), id1);
assert_eq!(id_to_address(id1), Some(address1.to_string()));
assert_eq!(address_to_id(address2), id2);
assert_eq!(id_to_address(id2), Some(address2.to_string()));
let unused_id = u32::MAX;
assert_eq!(id_to_address(unused_id), None);
}
}

View File

@@ -0,0 +1,129 @@
use super::{RemoteClient, runner::RemoteDevice};
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
use burn_communication::{CommunicationChannel, Message, ProtocolClient};
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
pub type CallbackSender = async_channel::Sender<TaskResponseContent>;
#[derive(Debug)]
pub enum ClientRequest {
WithSyncCallback(Task, CallbackSender),
WithoutCallback(Task),
}
pub(crate) struct ClientWorker<C: ProtocolClient> {
requests: HashMap<ConnectionId, CallbackSender>,
_p: PhantomData<C>,
}
impl<C: ProtocolClient> ClientWorker<C> {
async fn on_response(&mut self, response: TaskResponse) {
match self.requests.remove(&response.id) {
Some(request) => {
request.send(response.content).await.unwrap();
}
None => {
panic!("Can't ignore message from the server.");
}
}
}
fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
self.requests.insert(id, callback);
}
}
impl<C: ProtocolClient> ClientWorker<C> {
pub fn start(device: RemoteDevice) -> RemoteClient {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_io()
.build()
.unwrap(),
);
let (sender, rec) = async_channel::bounded(10);
let session_id = SessionId::new();
let address = device.address.clone();
#[allow(deprecated)]
runtime.spawn(async move {
log::info!("Connecting to {} ...", address.clone());
let mut stream_request = C::connect(address.clone(), "request")
.await
.expect("Server to be accessible");
let mut stream_response = C::connect(address, "response")
.await
.expect("Server to be accessible");
let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::<C>::default()));
// Init the connection.
let bytes: bytes::Bytes = rmp_serde::to_vec(&Task::Init(session_id))
.expect("Can serialize tasks to bytes.")
.into();
stream_request
.send(Message::new(bytes.clone()))
.await
.expect("Can send the message over the comms channel.");
stream_response
.send(Message::new(bytes))
.await
.expect("Can send the message on the websocket.");
// Async worker loading callbacks from the server.
let state_ws = state.clone();
tokio::spawn(async move {
while let Ok(msg) = stream_response.recv().await {
let msg = match msg {
Some(msg) => msg,
None => {
log::warn!("Closed connection");
return;
}
};
let response: TaskResponse = rmp_serde::from_slice(&msg.data)
.expect("Can deserialize messages from the websocket.");
let mut state = state_ws.lock().await;
state.on_response(response).await;
}
});
// Channel async worker sending operations to the server.
tokio::spawn(async move {
while let Ok(req) = rec.recv().await {
let task = match req {
ClientRequest::WithSyncCallback(task, callback) => {
if let Task::Compute(_content, id) = &task {
let mut state = state.lock().await;
state.register_callback(*id, callback);
}
task
}
ClientRequest::WithoutCallback(task) => task,
};
let bytes = rmp_serde::to_vec(&task)
.expect("Can serialize tasks to bytes.")
.into();
stream_request
.send(Message::new(bytes))
.await
.expect("Can send the message on the websocket.");
}
});
});
RemoteClient::new(device, sender, runtime, session_id)
}
}
impl<C: ProtocolClient> Default for ClientWorker<C> {
fn default() -> Self {
Self {
requests: Default::default(),
_p: PhantomData,
}
}
}

View File

@@ -0,0 +1,86 @@
#[macro_use]
extern crate derive_new;
#[cfg(feature = "client")]
pub(crate) mod client;
#[cfg(feature = "server")]
pub mod server;
pub(crate) mod shared;
#[cfg(feature = "client")]
mod __client {
use super::*;
use crate::{client::RemoteChannel, shared::RemoteProtocol};
use burn_communication::Protocol;
use burn_router::BackendRouter;
/// The remote backend allows you to run computation on a remote device.
///
/// Make sure there is a running server before trying to connect to it.
///
/// ```rust, ignore
/// fn main() {
/// let device = Default::default();
/// let port = 3000;
///
/// // You need to activate the `server` feature flag to have access to this function.
/// burn::server::start::<burn::backend::Wgpu>(device, port);
/// }
///```
pub type RemoteBackend = BackendRouter<RemoteChannel<<RemoteProtocol as Protocol>::Client>>;
pub use client::RemoteDevice;
}
#[cfg(feature = "client")]
pub use __client::*;
#[cfg(all(test, feature = "client", feature = "server"))]
mod tests {
use crate::RemoteBackend;
use burn_ndarray::NdArray;
use burn_tensor::{Distribution, Tensor};
#[test]
pub fn test_to_device_over_websocket() {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_io()
.build()
.unwrap();
rt.spawn(crate::server::start_websocket_async::<NdArray>(
Default::default(),
3000,
));
rt.spawn(crate::server::start_websocket_async::<NdArray>(
Default::default(),
3010,
));
let remote_device_1 = super::RemoteDevice::new("ws://localhost:3000");
let remote_device_2 = super::RemoteDevice::new("ws://localhost:3010");
// Some random input
let input_shape = [1, 28, 28];
let input = Tensor::<RemoteBackend, 3>::random(
input_shape,
Distribution::Default,
&remote_device_1,
);
let numbers_expected: Vec<f32> = input.to_data().to_vec().unwrap();
// Move tensor to device 2
let input = input.to_device(&remote_device_2);
let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
assert_eq!(numbers, numbers_expected);
// Move tensor back to device 1
let input = input.to_device(&remote_device_1);
let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
assert_eq!(numbers, numbers_expected);
rt.shutdown_background();
}
}

View File

@@ -0,0 +1,182 @@
use burn_communication::{
CommunicationChannel, Message, Protocol, ProtocolServer,
data_service::{TensorDataServer, TensorDataService},
util::os_shutdown_signal,
websocket::{WebSocket, WsServer},
};
use std::{marker::PhantomData, sync::Arc};
use tokio_util::sync::CancellationToken;
use burn_backend::tensor::Device;
use burn_ir::BackendIr;
use crate::shared::{ComputeTask, Task};
use super::session::SessionManager;
pub struct RemoteServer<B, P>
where
B: BackendIr,
P: Protocol,
{
_b: PhantomData<B>,
_n: PhantomData<P>,
}
impl<B, P> RemoteServer<B, P>
where
B: BackendIr,
P: Protocol,
{
/// Start the server on the given address.
pub async fn start(device: Device<B>, server: P::Server) {
let cancel_token = CancellationToken::new();
let data_service = Arc::new(TensorDataService::<B, P>::new(cancel_token));
let session_manager = Arc::new(SessionManager::<B, P>::new(device, data_service.clone()));
let _server = server
.route("/response", {
let session_manager = session_manager.clone();
move |stream| Self::handle_socket_response(session_manager, stream)
})
.route("/request", {
let session_manager = session_manager.clone();
move |stream| Self::handle_socket_request(session_manager, stream)
})
.route_tensor_data_service(data_service)
.serve(os_shutdown_signal())
.await;
}
async fn handle_socket_response(
session_manager: Arc<SessionManager<B, P>>,
mut socket: <P::Server as ProtocolServer>::Channel,
) {
log::info!("[Response Handler] On new connection.");
let packet = socket.recv().await;
let msg = match packet {
Ok(Some(msg)) => msg,
Ok(None) => {
log::info!("Response stream closed");
return;
}
Err(e) => {
log::info!("Response stream error on init: {e:?}");
return;
}
};
let id = match rmp_serde::from_slice::<Task>(&msg.data) {
Ok(Task::Init(session_id)) => session_id,
msg => {
log::error!("Message is not a valid initialization task {msg:?}");
return;
}
};
let mut receiver = session_manager.register_responder(id).await;
log::info!("Response handler connection active");
while let Some(mut callback) = receiver.recv().await {
let response = callback.recv().await.unwrap();
let bytes = rmp_serde::to_vec(&response).unwrap();
socket.send(Message::new(bytes.into())).await.unwrap();
}
}
async fn handle_socket_request(
session_manager: Arc<SessionManager<B, P>>,
mut socket: <P::Server as ProtocolServer>::Channel,
) {
log::info!("[Request Handler] On new connection.");
let mut session_id = None;
loop {
let packet = socket.recv().await;
let msg = match packet {
Ok(Some(msg)) => msg,
Ok(None) => {
log::info!("Request stream closed");
break;
}
Err(e) => {
log::info!("Request stream error: {e:?}, Closing.");
break;
}
};
let task = match rmp_serde::from_slice::<Task>(&msg.data) {
Ok(val) => val,
Err(err) => {
log::info!("Only bytes message in the json format are supported {err:?}");
break;
}
};
if let Task::Close(id) = task {
session_id = Some(id);
break;
}
let (stream, connection_id, task) =
match session_manager.stream(&mut session_id, task).await {
Some(val) => val,
None => {
log::info!("Ops session activated {session_id:?}");
continue;
}
};
match task {
ComputeTask::RegisterOperation(op) => {
stream.register_operation(op).await;
}
ComputeTask::RegisterTensor(id, data) => {
stream.register_tensor(id, data).await;
}
ComputeTask::ReadTensor(tensor) => {
stream.read_tensor(connection_id, tensor).await;
}
ComputeTask::SyncBackend => {
stream.sync(connection_id).await;
}
ComputeTask::RegisterTensorRemote(tensor, new_id) => {
stream.register_tensor_remote(tensor, new_id).await;
}
ComputeTask::ExposeTensorRemote {
tensor,
count,
transfer_id,
} => {
stream
.expose_tensor_remote(tensor, count, transfer_id)
.await;
}
ComputeTask::Seed(seed) => {
stream.seed(seed).await;
}
ComputeTask::SupportsDType(dtype) => {
stream.supports_dtype(connection_id, dtype).await
}
}
}
log::info!("Closing session {session_id:?}");
session_manager.close(session_id).await;
}
}
/// Start the server on the given port and [device](Device).
pub async fn start_websocket_async<B: BackendIr>(device: Device<B>, port: u16) {
let server = WsServer::new(port);
RemoteServer::<B, WebSocket>::start(device, server).await;
}
#[tokio::main]
/// Start the server on the given port and [device](Device).
pub async fn start_websocket<B: BackendIr>(device: Device<B>, port: u16) {
start_websocket_async::<B>(device, port).await;
}

View File

@@ -0,0 +1,7 @@
pub(crate) mod processor;
pub(crate) mod session;
pub(crate) mod stream;
mod base;
pub use base::{start_websocket, start_websocket_async};

View File

@@ -0,0 +1,132 @@
use burn_backend::TensorData;
use burn_communication::{
Protocol,
data_service::{TensorDataService, TensorTransferId},
};
use burn_ir::{BackendIr, OperationIr, TensorId, TensorIr};
use burn_router::{Runner, RunnerClient};
use burn_std::DType;
use core::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent, TensorRemote};
/// The goal of the processor is to asynchronously process compute tasks on it own thread.
pub struct Processor<B, P>
where
B: BackendIr,
P: Protocol,
{
p: PhantomData<B>,
n: PhantomData<P>,
}
pub type Callback<M> = Sender<M>;
pub enum ProcessorTask {
RegisterOperation(Box<OperationIr>),
RegisterTensor(TensorId, TensorData),
RegisterTensorRemote(TensorRemote, TensorId),
ExposeTensorRemote {
tensor: TensorIr,
transfer_id: TensorTransferId,
count: u32,
},
ReadTensor(ConnectionId, TensorIr, Callback<TaskResponse>),
Sync(ConnectionId, Callback<TaskResponse>),
Seed(u64),
SupportsDType(ConnectionId, DType, Callback<TaskResponse>),
Close,
}
impl<B: BackendIr, P> Processor<B, P>
where
B: BackendIr,
P: Protocol,
{
pub async fn start(
runner: Runner<B>,
data_service: Arc<TensorDataService<B, P>>,
) -> Sender<ProcessorTask> {
// channel for tasks to execute
let (task_sender, mut task_rec) = tokio::sync::mpsc::channel(1);
tokio::spawn(async move {
while let Some(item) = task_rec.recv().await {
match item {
ProcessorTask::RegisterOperation(op) => {
runner.register_op(*op);
}
ProcessorTask::Sync(id, callback) => {
let result = runner.sync();
callback
.send(TaskResponse {
content: TaskResponseContent::SyncBackend(result),
id,
})
.await
.unwrap();
}
ProcessorTask::RegisterTensor(id, data) => {
runner.register_tensor_data_id(id, data);
}
ProcessorTask::RegisterTensorRemote(remote_tensor, new_id) => {
log::info!(
"Registering remote tensor...(id: {:?})",
remote_tensor.transfer_id
);
let data = data_service
.download_tensor(remote_tensor.address, remote_tensor.transfer_id)
.await
.expect("Can't download tensor: error"); // TODO all these panics should be server errors
runner.register_tensor_data_id(new_id, data);
}
ProcessorTask::ExposeTensorRemote {
tensor,
transfer_id,
count,
} => {
log::info!("Exposing tensor: (id: {transfer_id:?})");
let data = runner.read_tensor_async(tensor).await;
data_service
.expose_data(data.unwrap(), count, transfer_id)
.await;
}
ProcessorTask::ReadTensor(id, tensor, callback) => {
let tensor = runner.read_tensor_async(tensor).await;
callback
.send(TaskResponse {
content: TaskResponseContent::ReadTensor(tensor),
id,
})
.await
.unwrap();
}
ProcessorTask::Close => {
let device = runner.device();
runner.sync().unwrap();
core::mem::drop(runner);
B::sync(&device).unwrap();
break;
}
ProcessorTask::Seed(seed) => runner.seed(seed),
ProcessorTask::SupportsDType(id, dtype, callback) => {
let _result = runner.dtype_usage(dtype);
callback
.send(TaskResponse {
// content: TaskResponseContent::SupportsDType(result),
// TODO: Update to result.
content: TaskResponseContent::SupportsDType(()),
id,
})
.await
.unwrap();
}
}
}
});
task_sender
}
}

View File

@@ -0,0 +1,170 @@
use burn_backend::tensor::Device;
use burn_communication::{Protocol, data_service::TensorDataService};
use burn_ir::BackendIr;
use burn_router::Runner;
use burn_std::id::StreamId;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{
Mutex,
mpsc::{Receiver, Sender},
};
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse};
use super::stream::Stream;
/// A session manager control the creation of sessions.
///
/// Each session manages its own stream, spawning one thread per stream to mimic the same behavior
/// a native backend would have.
pub struct SessionManager<B, P>
where
B: BackendIr,
P: Protocol,
{
runner: Runner<B>,
sessions: Mutex<HashMap<SessionId, Session<B, P>>>,
data_service: Arc<TensorDataService<B, P>>,
}
struct Session<B, P>
where
B: BackendIr,
P: Protocol,
{
runner: Runner<B>,
streams: HashMap<StreamId, Stream<B, P>>,
sender: Sender<Receiver<TaskResponse>>,
receiver: Option<Receiver<Receiver<TaskResponse>>>,
data_service: Arc<TensorDataService<B, P>>,
}
impl<B, P> SessionManager<B, P>
where
B: BackendIr,
P: Protocol,
{
pub fn new(device: Device<B>, data_service: Arc<TensorDataService<B, P>>) -> Self {
Self {
runner: Runner::new(device),
sessions: Mutex::new(Default::default()),
data_service,
}
}
/// Register a new responder for the session. Only one responder can exist for a session for
/// now.
pub async fn register_responder(
&self,
session_id: SessionId,
) -> Receiver<Receiver<TaskResponse>> {
log::info!("Register responder for session {session_id}");
let mut sessions = self.sessions.lock().await;
self.register_session(&mut sessions, session_id);
let session = sessions.get_mut(&session_id).unwrap();
session.init_responder()
}
/// Get the stream for the current session and task.
pub async fn stream(
&self,
session_id: &mut Option<SessionId>,
task: Task,
) -> Option<(Stream<B, P>, ConnectionId, ComputeTask)> {
let mut sessions = self.sessions.lock().await;
let session_id = match session_id {
Some(id) => *id,
None => match task {
Task::Init(id) => {
log::info!("Init requester for session {id}");
*session_id = Some(id);
self.register_session(&mut sessions, id);
return None;
}
_ => panic!("The first message should initialize the session"),
},
};
match sessions.get_mut(&session_id) {
Some(session) => {
let (task, connection_id) = match task {
Task::Compute(task, connection_id) => (task, connection_id),
_ => panic!("Only support compute tasks."),
};
let stream = session.select(connection_id.stream_id).await;
Some((stream, connection_id, task))
}
None => panic!("To be initialized"),
}
}
/// Close the session with the given id.
pub async fn close(&self, session_id: Option<SessionId>) {
if let Some(id) = session_id {
let mut sessions = self.sessions.lock().await;
if let Some(session) = sessions.get_mut(&id) {
session.close().await;
}
}
}
fn register_session(&self, sessions: &mut HashMap<SessionId, Session<B, P>>, id: SessionId) {
sessions.entry(id).or_insert_with(|| {
log::info!("Creating a new session {id}");
Session::new(self.runner.clone(), self.data_service.clone())
});
}
}
impl<B, P> Session<B, P>
where
B: BackendIr,
P: Protocol,
{
fn new(runner: Runner<B>, data_service: Arc<TensorDataService<B, P>>) -> Self {
let (sender, receiver) = tokio::sync::mpsc::channel(1);
Self {
runner,
streams: Default::default(),
sender,
receiver: Some(receiver),
data_service,
}
}
fn init_responder(&mut self) -> Receiver<Receiver<TaskResponse>> {
let mut receiver = None;
core::mem::swap(&mut receiver, &mut self.receiver);
receiver.expect("Only one responder per session is possible.")
}
/// Select the current [stream](Stream) based on the given task.
async fn select(&mut self, stream_id: StreamId) -> Stream<B, P> {
// We return the stream.
match self.streams.get(&stream_id) {
Some(stream) => stream.clone(),
None => {
let stream = Stream::<B, P>::new(
self.runner.clone(),
self.sender.clone(),
self.data_service.clone(),
)
.await;
self.streams.insert(stream_id, stream.clone());
stream
}
}
}
// Close all streams created in the session.
async fn close(&mut self) {
for (id, stream) in self.streams.drain() {
log::info!("Closing stream {id}");
stream.close().await;
}
}
}

View File

@@ -0,0 +1,134 @@
use core::marker::PhantomData;
use std::sync::Arc;
use crate::shared::{ConnectionId, TaskResponse, TensorRemote};
use super::processor::{Processor, ProcessorTask};
use burn_backend::TensorData;
use burn_communication::{
Protocol,
data_service::{TensorDataService, TensorTransferId},
};
use burn_ir::{BackendIr, OperationIr, TensorId, TensorIr};
use burn_router::Runner;
use burn_std::DType;
use tokio::sync::mpsc::{Receiver, Sender};
/// A stream makes sure all operations registered are executed in the order they were sent to the
/// server, potentially waiting to reconstruct consistency.
#[derive(Clone)]
pub struct Stream<B, P>
where
B: BackendIr,
P: Protocol,
{
compute_sender: Sender<ProcessorTask>,
writer_sender: Sender<Receiver<TaskResponse>>,
_p: PhantomData<B>,
_n: PhantomData<P>,
}
impl<B, P> Stream<B, P>
where
B: BackendIr,
P: Protocol,
{
pub async fn new(
runner: Runner<B>,
writer_sender: Sender<Receiver<TaskResponse>>,
data_service: Arc<TensorDataService<B, P>>,
) -> Self {
let sender = Processor::<B, P>::start(runner, data_service).await;
Self {
compute_sender: sender,
writer_sender,
_p: PhantomData,
_n: PhantomData,
}
}
pub async fn register_operation(&self, op: Box<OperationIr>) {
self.compute_sender
.send(ProcessorTask::RegisterOperation(op))
.await
.unwrap();
}
pub async fn register_tensor(&self, tensor_id: TensorId, data: TensorData) {
self.compute_sender
.send(ProcessorTask::RegisterTensor(tensor_id, data))
.await
.unwrap();
}
pub async fn register_tensor_remote(&self, tensor: TensorRemote, new_id: TensorId) {
self.compute_sender
.send(ProcessorTask::RegisterTensorRemote(tensor, new_id))
.await
.unwrap();
}
pub async fn expose_tensor_remote(
&self,
tensor: TensorIr,
count: u32,
transfer_id: TensorTransferId,
) {
self.compute_sender
.send(ProcessorTask::ExposeTensorRemote {
tensor,
count,
transfer_id,
})
.await
.unwrap();
}
pub async fn read_tensor(&self, id: ConnectionId, desc: TensorIr) {
let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);
self.compute_sender
.send(ProcessorTask::ReadTensor(id, desc, callback_sender))
.await
.unwrap();
self.writer_sender.send(callback_rec).await.unwrap();
}
pub async fn sync(&self, id: ConnectionId) {
let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);
self.compute_sender
.send(ProcessorTask::Sync(id, callback_sender))
.await
.unwrap();
self.writer_sender.send(callback_rec).await.unwrap();
}
pub async fn close(&self) {
self.compute_sender
.send(ProcessorTask::Close)
.await
.unwrap();
}
pub async fn seed(&self, seed: u64) {
self.compute_sender
.send(ProcessorTask::Seed(seed))
.await
.unwrap();
}
pub async fn supports_dtype(&self, id: ConnectionId, dtype: DType) {
let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);
self.compute_sender
.send(ProcessorTask::SupportsDType(id, dtype, callback_sender))
.await
.unwrap();
self.writer_sender.send(callback_rec).await.unwrap();
}
}

View File

@@ -0,0 +1,7 @@
mod task;
#[allow(unused_imports)]
pub(crate) use task::*;
/// We define the communication protocol here
pub(crate) type RemoteProtocol = burn_communication::websocket::WebSocket;

View File

@@ -0,0 +1,87 @@
use burn_backend::{ExecutionError, TensorData};
use burn_communication::{Address, data_service::TensorTransferId};
use burn_ir::{OperationIr, TensorId, TensorIr};
use burn_std::{
DType,
id::{IdGenerator, StreamId},
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
#[allow(missing_docs)]
#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct ConnectionId {
pub position: u64,
pub stream_id: StreamId,
}
/// Unique identifier that can represent a session.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct SessionId {
id: u64,
}
impl Display for SessionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "SessionId({})", self.id)
}
}
impl SessionId {
/// Create a new [session id](SessionId).
#[allow(dead_code)]
pub fn new() -> Self {
Self {
id: IdGenerator::generate(),
}
}
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum Task {
Compute(ComputeTask, ConnectionId),
Init(SessionId),
Close(SessionId),
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TensorRemote {
pub transfer_id: TensorTransferId,
pub address: Address,
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum ComputeTask {
Seed(u64),
RegisterOperation(Box<OperationIr>),
RegisterTensor(TensorId, TensorData),
RegisterTensorRemote(TensorRemote, TensorId),
ExposeTensorRemote {
tensor: TensorIr,
count: u32,
transfer_id: TensorTransferId,
},
ReadTensor(TensorIr),
SyncBackend,
SupportsDType(DType),
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub struct TaskResponse {
pub content: TaskResponseContent,
pub id: ConnectionId,
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum TaskResponseContent {
ReadTensor(Result<TensorData, ExecutionError>),
SyncBackend(Result<(), ExecutionError>),
// SupportsDType(DTypeUsageSet),
// TODO: Update to `DTypeUsageSet` when it implements `serde`.
SupportsDType(()),
}