feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,44 @@
[package]
authors = ["Guilhem Ané (@Cielbird)", "Nathaniel Simard (@nathanielsimard)"]
description = "Abstractions for network communication for Burn"
edition.workspace = true
license.workspace = true
name = "burn-communication"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-communication"
version.workspace = true
[lints]
workspace = true
[features]
tracing = [
"burn-std/tracing",
"burn-tensor?/tracing",
]
data-service = ["burn-tensor"]
websocket = ["axum", "tokio-tungstenite", "futures"]
[dependencies]
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true }
bytes = { workspace = true }
derive-new = { workspace = true }
futures-util = { workspace = true }
log = { workspace = true }
rmp-serde = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_bytes = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "sync", "signal", "tracing"] }
tokio-util = { workspace = true }
tracing = { workspace = true, features = ["default"] }
tracing-core = { workspace = true, features = ["default"] }
tracing-subscriber = { workspace = true, features = ["default", "fmt", "env-filter"] }
# Tensor Data Service
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", optional = true }
# Websocket
axum = { workspace = true, features = ["ws"], optional = true }
tokio-tungstenite = { workspace = true, optional = true }
futures = { workspace = true, optional = true }

View File

@@ -0,0 +1,15 @@
# Burn Communication
Abstractions for network communication
The Protocol trait defines how to communicate in a server/client style.
The server can set up routes with callbacks upon connection.
## WebSocket
Communication with WebSockets is implemented with the `websocket` feature.
## Tensor Data Service
The tensor data service provides easy utilities to share tensors peer-to-peer.
One peer can expose a tensor, and another can download it. Each peer is both a client and a server.

View File

@@ -0,0 +1,104 @@
use burn_std::future::DynFut;
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::str::FromStr;
/// Allows nodes to find each other
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
pub struct Address {
pub(crate) inner: String,
}
impl FromStr for Address {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self {
inner: s.to_string(),
})
}
}
impl Display for Address {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.inner)
}
}
/// The protocol used for the communications.
pub trait Protocol: Clone + Send + Sync + 'static {
/// The client implementation for the current protocol.
type Client: ProtocolClient;
/// The server implementation for the current protocol.
type Server: ProtocolServer;
}
/// Error that happens during a communication.
pub trait CommunicationError: Debug + Send + 'static {}
/// The client is only used to create a [channel](CommunicationChannel), which should be use to
/// transmit information with the [server](ProtocolServer).
pub trait ProtocolClient: Send + Sync + 'static {
/// Channel used by this protocol.
type Channel: CommunicationChannel<Error = Self::Error>;
/// The error type.
type Error: CommunicationError;
/// Opens a new [channel](CommunicationChannel) with the current protocol at the given
/// [address](Address) and route.
///
/// * `address` - Address to connect to
/// * `route` - The name of the route (no slashes)
///
/// Returns None if the connection can't be done.
fn connect(address: Address, route: &str) -> DynFut<Option<Self::Channel>>;
}
/// Data sent and received by the client and server.
#[derive(new)]
pub struct Message {
/// The data is always encoded as bytes.
pub data: bytes::Bytes,
}
/// Defines how to create a server that respond to a [channel](CommunicationChannel).
pub trait ProtocolServer: Sized + Send + Sync + 'static {
/// Channel used by this protocol.
type Channel: CommunicationChannel<Error = Self::Error>;
/// The error type.
type Error: CommunicationError;
/// Defines an endpoint with the function that responds.
/// TODO Docs: does it need a slash?
fn route<C, Fut>(self, path: &str, callback: C) -> Self
where
C: FnOnce(Self::Channel) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static;
/// Start the server.
fn serve<F>(
self,
shutdown: F,
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static
where
F: Future<Output = ()> + Send + 'static;
}
/// Handles communications.
pub trait CommunicationChannel: Send + 'static {
type Error: CommunicationError;
/// Send a [message](Message) on the channel.
fn send(
&mut self,
message: Message,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
/// Receive a [message](Message) on the channel and returns a new [response message](Message).
fn recv(
&mut self,
) -> impl std::future::Future<Output = Result<Option<Message>, Self::Error>> + Send;
fn close(&mut self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
}

View File

@@ -0,0 +1,258 @@
//! This module enables direct data transfer between servers without blocking the client or any server.
//!
//! It eliminates the need for intermediate data transfer through the client, avoiding the process of downloading data from one server and reuploading it to another.
//!
//! The module provides an optimized mechanism for servers to communicate directly, streamlining data movement between them without involving the client.
use crate::Message;
use crate::base::Protocol;
use crate::base::{Address, CommunicationChannel, ProtocolClient, ProtocolServer};
use burn_tensor::{TensorData, backend::Backend};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorTransferId(u64);
impl From<u64> for TensorTransferId {
fn from(value: u64) -> Self {
Self(value)
}
}
impl TensorTransferId {
pub fn next(&mut self) {
self.0 += 1;
}
}
#[derive(Debug, Serialize, Deserialize)]
enum DataServiceMessage {
TensorRequest(TensorTransferId),
Tensor(TensorData),
}
type ClientChannelRef<C> = Arc<Mutex<<C as ProtocolClient>::Channel>>;
pub struct TensorDataService<B: Backend, P: Protocol<Client: ProtocolClient>> {
/// Maps tensor transfer IDs to their exposed state.
pub exposed_tensors: Mutex<HashMap<TensorTransferId, TensorExposeState>>,
/// Maps node addresses to their channels.
pub channels: Mutex<HashMap<Address, ClientChannelRef<P::Client>>>,
/// Notify when a new tensor is exposed.
pub new_tensor_notify: Arc<Notify>,
cancel_token: CancellationToken,
_phantom_data: PhantomData<B>,
}
pub struct TensorExposeState {
/// The bytes of the tensor data message. Message::Data(...) serialized with rmp_serde
pub bytes: bytes::Bytes,
/// How many times the tensor will be downloaded
pub max_downloads: u32,
/// How man times the tensor has been downloaded
pub cur_download_count: u32,
}
/// Provides a routing function for a tensor data service for a communications server
pub trait TensorDataServer<B: Backend, P: Protocol> {
/// Routes the tensor data service to the "/data" route
fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self;
}
impl<B: Backend, S: ProtocolServer + Sized, P: Protocol<Server = S> + 'static>
TensorDataServer<B, P> for S
{
fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self {
self.route("/data", async move |stream: S::Channel| {
state.handle_data_channel(stream).await;
})
}
}
impl<B: Backend, P: Protocol> TensorDataService<B, P> {
pub fn new(cancel_token: CancellationToken) -> Self {
Self {
exposed_tensors: Mutex::new(HashMap::new()),
channels: Mutex::new(HashMap::new()),
new_tensor_notify: Arc::new(Notify::new()),
cancel_token,
_phantom_data: PhantomData::<B>,
}
}
/// Exposes a tensor to the data server, allowing it to be downloaded by other nodes.
pub async fn expose(
&self,
tensor: B::FloatTensorPrimitive,
max_downloads: u32,
transfer_id: TensorTransferId,
) {
let data = B::float_into_data(tensor).await.unwrap();
self.expose_data(data, max_downloads, transfer_id).await
}
/// Exposes a tensor data to the data server, allowing it to be downloaded by other nodes.
pub async fn expose_data(
&self,
tensor_data: TensorData,
max_downloads: u32,
transfer_id: TensorTransferId,
) {
let bytes: bytes::Bytes = rmp_serde::to_vec(&DataServiceMessage::Tensor(tensor_data))
.unwrap()
.into();
let mut exposed_tensors = self.exposed_tensors.lock().await;
exposed_tensors.insert(
transfer_id,
TensorExposeState {
bytes,
max_downloads,
cur_download_count: 0,
},
);
core::mem::drop(exposed_tensors);
self.new_tensor_notify.notify_waiters();
}
pub async fn close(&self) {
// Send a closing message to every open WebSocket stream
let mut streams = self.channels.lock().await;
for (_, stream) in streams.drain() {
let mut stream = stream.lock().await;
stream
.close()
.await
.expect("Failed to close WebSocket stream");
}
}
/// Downloads a tensor that is exposed on another server. Requires a Tokio 1.x runtime
///
/// Returns None if the peer closes the connection
pub async fn download_tensor(
&self,
remote: Address,
transfer_id: TensorTransferId,
) -> Option<TensorData> {
log::info!("Downloading tensor from {remote:?}");
let stream = self.get_data_stream(remote).await;
let mut stream = stream.lock().await;
// Send the download request with the download id
let bytes: bytes::Bytes =
rmp_serde::to_vec(&DataServiceMessage::TensorRequest(transfer_id))
.unwrap()
.into();
stream
.send(Message::new(bytes))
.await
.expect("Failed to send download id");
if let Ok(msg) = stream.recv().await {
let Some(msg) = msg else {
log::warn!("Received None message from the websocket, closing connection.");
return None;
};
let DataServiceMessage::Tensor(data) = rmp_serde::from_slice(&msg.data)
.expect("Can deserialize messages from the websocket.")
else {
panic!("Message should have been TensorData")
};
return Some(data);
}
log::warn!("Closed connection");
None
}
/// Get the WebSocket stream for the given address, or create a new one if it doesn't exist.
async fn get_data_stream(
&self,
address: Address,
) -> Arc<Mutex<<P::Client as ProtocolClient>::Channel>> {
let mut streams = self.channels.lock().await;
match streams.get(&address) {
Some(stream) => stream.clone(),
None => {
// Open a new WebSocket connection to the address
let stream = P::Client::connect(address.clone(), "data").await;
let Some(stream) = stream else {
panic!("Failed to connect to data server at {address:?}");
};
let stream = Arc::new(Mutex::new(stream));
streams.insert(address.clone(), stream.clone());
stream
}
}
}
/// Get the requested exposed tensor data, and update download counter
async fn get_exposed_tensor_bytes(
&self,
transfer_id: TensorTransferId,
) -> Option<bytes::Bytes> {
loop {
{
let mut exposed_tensors = self.exposed_tensors.lock().await;
// take the tensor out of the hashmap while we download
if let Some(mut exposed_state) = exposed_tensors.remove(&transfer_id) {
exposed_state.cur_download_count += 1;
let bytes = if exposed_state.cur_download_count == exposed_state.max_downloads {
exposed_state.bytes
} else {
let bytes = exposed_state.bytes.clone();
exposed_tensors.insert(transfer_id, exposed_state);
bytes
};
return Some(bytes);
}
}
// No matching tensor, wait for a new one to come in.
self.new_tensor_notify.notified().await;
}
}
/// Handle incoming connections for downloading tensors.
pub(crate) async fn handle_data_channel(
&self,
mut channel: <P::Server as ProtocolServer>::Channel,
) {
log::info!("[Data Handler] New connection for download.");
while !self.cancel_token.is_cancelled() {
match channel.recv().await {
Ok(message) => {
if let Some(msg) = message {
let bytes = msg.data;
let msg: DataServiceMessage = rmp_serde::from_slice(&bytes)
.expect("Can deserialize messages from the websocket.");
let DataServiceMessage::TensorRequest(transfer_id) = msg else {
panic!("Received a message that wasn't a tensor request! {msg:?}");
};
let bytes = self.get_exposed_tensor_bytes(transfer_id).await.unwrap();
channel.send(Message::new(bytes)).await.unwrap();
} else {
log::info!("Closed connection");
return;
}
}
Err(err) => panic!("Failed to receive message from websocket: {err:?}"),
};
}
log::info!("[Data Service] Closing connection for download.");
}
}

View File

@@ -0,0 +1,13 @@
#[macro_use]
extern crate derive_new;
mod base;
pub use base::*;
pub mod util;
#[cfg(feature = "websocket")]
pub mod websocket;
#[cfg(feature = "data-service")]
pub mod data_service;

View File

@@ -0,0 +1,46 @@
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::{
Layer, filter::filter_fn, layer::SubscriberExt, registry, util::SubscriberInitExt,
};
/// Utilities to help handle communication termination.
pub async fn os_shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}
pub(crate) fn init_logging() {
let layer = tracing_subscriber::fmt::layer()
.with_filter(LevelFilter::INFO)
.with_filter(filter_fn(|m| {
if let Some(path) = m.module_path() {
// The wgpu crate is logging too much, so we skip `info` level.
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
return false;
}
}
true
}));
// If we start multiple servers in the same process, this will fail, it's ok
let _ = registry().with(layer).try_init();
}

View File

@@ -0,0 +1,35 @@
use crate::{
base::{Address, Protocol},
websocket::{client::WsClient, server::WsServer},
};
#[derive(Clone)]
/// A websocket implements a [communication protocol](Protocol) that can be used to communicate
/// over the internet.
pub struct WebSocket {}
impl Protocol for WebSocket {
type Client = WsClient;
type Server = WsServer;
}
/// Parse an address, add the ws:// prefix if needed, and return an error if the address is invalid
pub(crate) fn parse_ws_address(mut address: Address) -> Result<Address, String> {
let s = &address.inner;
let parts = s.split("://").collect::<Vec<&str>>();
let num_parts = parts.len();
let url = if num_parts == 2 {
if parts[0] == "ws" {
s.to_owned()
} else {
return Err(format!("Invalid prefix: {}", parts[0]));
}
} else if num_parts == 1 {
return Err(format!("ws://{s}"));
} else {
return Err(format!("Invalid url: {s}"));
};
address.inner = url;
Ok(address)
}

View File

@@ -0,0 +1,109 @@
use crate::{
base::{Address, CommunicationChannel, CommunicationError, Message, ProtocolClient},
websocket::base::parse_ws_address,
};
use burn_std::future::DynFut;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async_with_config,
tungstenite::{self, protocol::WebSocketConfig},
};
#[derive(Clone)]
pub struct WsClient;
impl ProtocolClient for WsClient {
type Channel = WsClientChannel;
type Error = WsClientError;
fn connect(address: Address, route: &str) -> DynFut<Option<WsClientChannel>> {
Box::pin(connect_ws(address, route.to_owned()))
}
}
/// Open a new WebSocket connection to the address
async fn connect_ws(address: Address, route: String) -> Option<WsClientChannel> {
let address = parse_ws_address(address).ok()?;
let address = format!("{address}/{route}");
const MB: usize = 1024 * 1024;
let (stream, _) = connect_async_with_config(
address.clone(),
Some(
WebSocketConfig::default()
.write_buffer_size(0)
.max_message_size(None)
.max_frame_size(Some(MB * 512))
.accept_unmasked_frames(true)
.read_buffer_size(64 * 1024), // 64 KiB (previous default)
),
true,
)
.await
.ok()?;
Some(WsClientChannel { inner: stream })
}
pub struct WsClientChannel {
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
impl CommunicationChannel for WsClientChannel {
type Error = WsClientError;
async fn send(&mut self, msg: Message) -> Result<(), WsClientError> {
self.inner
.send(tungstenite::Message::Binary(msg.data))
.await?;
Ok(())
}
async fn recv(&mut self) -> Result<Option<Message>, WsClientError> {
match self.inner.next().await {
Some(next) => match next {
Ok(tungstenite::Message::Binary(data)) => Ok(Some(Message { data })),
Ok(tungstenite::Message::Close(_close_frame)) => Ok(None),
Err(err) => Err(WsClientError::Tungstenite(err)),
msg => Err(WsClientError::UnknownMessage(format!("{msg:?}"))),
},
None => todo!(),
}
}
async fn close(&mut self) -> Result<(), WsClientError> {
let reason = "Peer is closing".to_string();
self.inner
.send(tungstenite::Message::Close(Some(
tungstenite::protocol::CloseFrame {
code: tungstenite::protocol::frame::coding::CloseCode::Normal,
reason: reason.clone().into(),
},
)))
.await?;
Ok(())
}
}
#[derive(Debug)]
pub enum WsClientError {
Io(std::io::Error),
Tungstenite(tungstenite::Error),
UnknownMessage(String),
Other(String),
}
impl CommunicationError for WsClientError {}
impl From<std::io::Error> for WsClientError {
fn from(err: std::io::Error) -> Self {
Self::Io(err)
}
}
impl From<tungstenite::Error> for WsClientError {
fn from(err: tungstenite::Error) -> Self {
Self::Tungstenite(err)
}
}

View File

@@ -0,0 +1,7 @@
mod base;
mod client;
mod server;
pub use base::*;
pub use client::*;
pub use server::*;

View File

@@ -0,0 +1,141 @@
use std::net::SocketAddr;
use crate::{
base::{CommunicationChannel, CommunicationError, Message, ProtocolServer},
util::init_logging,
};
use axum::{
Router,
extract::{
State, WebSocketUpgrade,
ws::{self, WebSocket},
},
routing::get,
};
use futures::StreamExt;
#[derive(Clone, Debug)]
pub struct WsServer {
port: u16,
router: Router<()>,
}
pub struct WsServerChannel {
inner: WebSocket,
}
impl WsServer {
pub fn new(port: u16) -> Self {
Self {
port,
router: Router::new(),
}
}
}
impl ProtocolServer for WsServer {
type Channel = WsServerChannel;
type Error = WsServerError;
async fn serve<F>(self, shutdown: F) -> Result<(), Self::Error>
where
F: Future<Output = ()> + Send + 'static,
{
init_logging();
let address = format!("0.0.0.0:{}", self.port);
log::info!("Starting server {address}");
let listener = tokio::net::TcpListener::bind(address).await?;
axum::serve(
listener,
self.router
.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown)
.await?;
Ok(())
}
fn route<C, Fut>(mut self, path: &str, callback: C) -> Self
where
C: FnOnce(WsServerChannel) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
// Format path: should start with a /
let path = if path.starts_with("/") {
path.to_owned()
} else {
format!("/{path}")
};
let method = get(|ws: WebSocketUpgrade, _: State<()>| async {
ws.on_upgrade(async move |socket| {
callback(WsServerChannel { inner: socket }).await;
})
});
self.router = self.router.route(&path, method);
self
}
}
impl CommunicationChannel for WsServerChannel {
type Error = WsServerError;
async fn send(&mut self, message: Message) -> Result<(), WsServerError> {
self.inner.send(ws::Message::Binary(message.data)).await?;
Ok(())
}
async fn recv(&mut self) -> Result<Option<Message>, WsServerError> {
match self.inner.next().await {
Some(next) => match next {
Ok(ws::Message::Binary(data)) => Ok(Some(Message { data })),
Ok(ws::Message::Close(_close_frame)) => Ok(None),
Err(err) => Err(WsServerError::Axum(err)),
msg => Err(WsServerError::UnknownMessage(format!("{msg:?}"))),
},
None => todo!(),
}
}
async fn close(&mut self) -> Result<(), WsServerError> {
let reason = "Peer is closing".to_string();
self.inner
.send(ws::Message::Close(Some(ws::CloseFrame {
code: 1000, // code: Normal
reason: reason.clone().into(),
})))
.await?;
Ok(())
}
}
#[derive(Debug)]
pub enum WsServerError {
Io(std::io::Error),
Axum(axum::Error),
UnknownMessage(String),
Other(String),
}
impl CommunicationError for WsServerError {}
impl From<std::io::Error> for WsServerError {
fn from(err: std::io::Error) -> Self {
Self::Io(err)
}
}
impl From<axum::Error> for WsServerError {
fn from(err: axum::Error) -> Self {
Self::Axum(err)
}
}