feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,44 @@
|
||||
[package]
|
||||
authors = ["Guilhem Ané (@Cielbird)", "Nathaniel Simard (@nathanielsimard)"]
|
||||
description = "Abstractions for network communication for Burn"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "burn-communication"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-communication"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
tracing = [
|
||||
"burn-std/tracing",
|
||||
"burn-tensor?/tracing",
|
||||
]
|
||||
|
||||
data-service = ["burn-tensor"]
|
||||
websocket = ["axum", "tokio-tungstenite", "futures"]
|
||||
|
||||
[dependencies]
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true }
|
||||
bytes = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
log = { workspace = true }
|
||||
rmp-serde = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_bytes = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "sync", "signal", "tracing"] }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true, features = ["default"] }
|
||||
tracing-core = { workspace = true, features = ["default"] }
|
||||
tracing-subscriber = { workspace = true, features = ["default", "fmt", "env-filter"] }
|
||||
|
||||
# Tensor Data Service
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", optional = true }
|
||||
|
||||
# Websocket
|
||||
axum = { workspace = true, features = ["ws"], optional = true }
|
||||
tokio-tungstenite = { workspace = true, optional = true }
|
||||
futures = { workspace = true, optional = true }
|
||||
@@ -0,0 +1,15 @@
|
||||
# Burn Communication
|
||||
|
||||
Abstractions for network communication
|
||||
|
||||
The Protocol trait defines how to communicate in a server/client style.
|
||||
The server can set up routes with callbacks upon connection.
|
||||
|
||||
## WebSocket
|
||||
|
||||
Communication with WebSockets is implemented with the `websocket` feature.
|
||||
|
||||
## Tensor Data Service
|
||||
|
||||
The tensor data service provides easy utilities to share tensors peer-to-peer.
|
||||
One peer can expose a tensor, and another can download it. Each peer is both a client and a server.
|
||||
@@ -0,0 +1,104 @@
|
||||
use burn_std::future::DynFut;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::hash::Hash;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Allows nodes to find each other
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
|
||||
pub struct Address {
|
||||
pub(crate) inner: String,
|
||||
}
|
||||
|
||||
impl FromStr for Address {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Self {
|
||||
inner: s.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Address {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
/// The protocol used for the communications.
|
||||
pub trait Protocol: Clone + Send + Sync + 'static {
|
||||
/// The client implementation for the current protocol.
|
||||
type Client: ProtocolClient;
|
||||
/// The server implementation for the current protocol.
|
||||
type Server: ProtocolServer;
|
||||
}
|
||||
|
||||
/// Error that happens during a communication.
|
||||
pub trait CommunicationError: Debug + Send + 'static {}
|
||||
|
||||
/// The client is only used to create a [channel](CommunicationChannel), which should be use to
|
||||
/// transmit information with the [server](ProtocolServer).
|
||||
pub trait ProtocolClient: Send + Sync + 'static {
|
||||
/// Channel used by this protocol.
|
||||
type Channel: CommunicationChannel<Error = Self::Error>;
|
||||
/// The error type.
|
||||
type Error: CommunicationError;
|
||||
|
||||
/// Opens a new [channel](CommunicationChannel) with the current protocol at the given
|
||||
/// [address](Address) and route.
|
||||
///
|
||||
/// * `address` - Address to connect to
|
||||
/// * `route` - The name of the route (no slashes)
|
||||
///
|
||||
/// Returns None if the connection can't be done.
|
||||
fn connect(address: Address, route: &str) -> DynFut<Option<Self::Channel>>;
|
||||
}
|
||||
|
||||
/// Data sent and received by the client and server.
|
||||
#[derive(new)]
|
||||
pub struct Message {
|
||||
/// The data is always encoded as bytes.
|
||||
pub data: bytes::Bytes,
|
||||
}
|
||||
|
||||
/// Defines how to create a server that respond to a [channel](CommunicationChannel).
|
||||
pub trait ProtocolServer: Sized + Send + Sync + 'static {
|
||||
/// Channel used by this protocol.
|
||||
type Channel: CommunicationChannel<Error = Self::Error>;
|
||||
/// The error type.
|
||||
type Error: CommunicationError;
|
||||
|
||||
/// Defines an endpoint with the function that responds.
|
||||
/// TODO Docs: does it need a slash?
|
||||
fn route<C, Fut>(self, path: &str, callback: C) -> Self
|
||||
where
|
||||
C: FnOnce(Self::Channel) -> Fut + Clone + Send + Sync + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static;
|
||||
|
||||
/// Start the server.
|
||||
fn serve<F>(
|
||||
self,
|
||||
shutdown: F,
|
||||
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static;
|
||||
}
|
||||
|
||||
/// Handles communications.
|
||||
pub trait CommunicationChannel: Send + 'static {
|
||||
type Error: CommunicationError;
|
||||
|
||||
/// Send a [message](Message) on the channel.
|
||||
fn send(
|
||||
&mut self,
|
||||
message: Message,
|
||||
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
|
||||
|
||||
/// Receive a [message](Message) on the channel and returns a new [response message](Message).
|
||||
fn recv(
|
||||
&mut self,
|
||||
) -> impl std::future::Future<Output = Result<Option<Message>, Self::Error>> + Send;
|
||||
|
||||
fn close(&mut self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
//! This module enables direct data transfer between servers without blocking the client or any server.
|
||||
//!
|
||||
//! It eliminates the need for intermediate data transfer through the client, avoiding the process of downloading data from one server and reuploading it to another.
|
||||
//!
|
||||
//! The module provides an optimized mechanism for servers to communicate directly, streamlining data movement between them without involving the client.
|
||||
|
||||
use crate::Message;
|
||||
use crate::base::Protocol;
|
||||
use crate::base::{Address, CommunicationChannel, ProtocolClient, ProtocolServer};
|
||||
use burn_tensor::{TensorData, backend::Backend};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct TensorTransferId(u64);
|
||||
|
||||
impl From<u64> for TensorTransferId {
|
||||
fn from(value: u64) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorTransferId {
|
||||
pub fn next(&mut self) {
|
||||
self.0 += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
enum DataServiceMessage {
|
||||
TensorRequest(TensorTransferId),
|
||||
Tensor(TensorData),
|
||||
}
|
||||
|
||||
type ClientChannelRef<C> = Arc<Mutex<<C as ProtocolClient>::Channel>>;
|
||||
|
||||
pub struct TensorDataService<B: Backend, P: Protocol<Client: ProtocolClient>> {
|
||||
/// Maps tensor transfer IDs to their exposed state.
|
||||
pub exposed_tensors: Mutex<HashMap<TensorTransferId, TensorExposeState>>,
|
||||
/// Maps node addresses to their channels.
|
||||
pub channels: Mutex<HashMap<Address, ClientChannelRef<P::Client>>>,
|
||||
/// Notify when a new tensor is exposed.
|
||||
pub new_tensor_notify: Arc<Notify>,
|
||||
|
||||
cancel_token: CancellationToken,
|
||||
|
||||
_phantom_data: PhantomData<B>,
|
||||
}
|
||||
|
||||
pub struct TensorExposeState {
|
||||
/// The bytes of the tensor data message. Message::Data(...) serialized with rmp_serde
|
||||
pub bytes: bytes::Bytes,
|
||||
/// How many times the tensor will be downloaded
|
||||
pub max_downloads: u32,
|
||||
/// How man times the tensor has been downloaded
|
||||
pub cur_download_count: u32,
|
||||
}
|
||||
|
||||
/// Provides a routing function for a tensor data service for a communications server
|
||||
pub trait TensorDataServer<B: Backend, P: Protocol> {
|
||||
/// Routes the tensor data service to the "/data" route
|
||||
fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self;
|
||||
}
|
||||
|
||||
impl<B: Backend, S: ProtocolServer + Sized, P: Protocol<Server = S> + 'static>
|
||||
TensorDataServer<B, P> for S
|
||||
{
|
||||
fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self {
|
||||
self.route("/data", async move |stream: S::Channel| {
|
||||
state.handle_data_channel(stream).await;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, P: Protocol> TensorDataService<B, P> {
|
||||
pub fn new(cancel_token: CancellationToken) -> Self {
|
||||
Self {
|
||||
exposed_tensors: Mutex::new(HashMap::new()),
|
||||
channels: Mutex::new(HashMap::new()),
|
||||
new_tensor_notify: Arc::new(Notify::new()),
|
||||
cancel_token,
|
||||
_phantom_data: PhantomData::<B>,
|
||||
}
|
||||
}
|
||||
|
||||
/// Exposes a tensor to the data server, allowing it to be downloaded by other nodes.
|
||||
pub async fn expose(
|
||||
&self,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
max_downloads: u32,
|
||||
transfer_id: TensorTransferId,
|
||||
) {
|
||||
let data = B::float_into_data(tensor).await.unwrap();
|
||||
self.expose_data(data, max_downloads, transfer_id).await
|
||||
}
|
||||
|
||||
/// Exposes a tensor data to the data server, allowing it to be downloaded by other nodes.
|
||||
pub async fn expose_data(
|
||||
&self,
|
||||
tensor_data: TensorData,
|
||||
max_downloads: u32,
|
||||
transfer_id: TensorTransferId,
|
||||
) {
|
||||
let bytes: bytes::Bytes = rmp_serde::to_vec(&DataServiceMessage::Tensor(tensor_data))
|
||||
.unwrap()
|
||||
.into();
|
||||
let mut exposed_tensors = self.exposed_tensors.lock().await;
|
||||
exposed_tensors.insert(
|
||||
transfer_id,
|
||||
TensorExposeState {
|
||||
bytes,
|
||||
max_downloads,
|
||||
cur_download_count: 0,
|
||||
},
|
||||
);
|
||||
core::mem::drop(exposed_tensors);
|
||||
self.new_tensor_notify.notify_waiters();
|
||||
}
|
||||
|
||||
pub async fn close(&self) {
|
||||
// Send a closing message to every open WebSocket stream
|
||||
|
||||
let mut streams = self.channels.lock().await;
|
||||
for (_, stream) in streams.drain() {
|
||||
let mut stream = stream.lock().await;
|
||||
|
||||
stream
|
||||
.close()
|
||||
.await
|
||||
.expect("Failed to close WebSocket stream");
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a tensor that is exposed on another server. Requires a Tokio 1.x runtime
|
||||
///
|
||||
/// Returns None if the peer closes the connection
|
||||
pub async fn download_tensor(
|
||||
&self,
|
||||
remote: Address,
|
||||
transfer_id: TensorTransferId,
|
||||
) -> Option<TensorData> {
|
||||
log::info!("Downloading tensor from {remote:?}");
|
||||
|
||||
let stream = self.get_data_stream(remote).await;
|
||||
let mut stream = stream.lock().await;
|
||||
|
||||
// Send the download request with the download id
|
||||
let bytes: bytes::Bytes =
|
||||
rmp_serde::to_vec(&DataServiceMessage::TensorRequest(transfer_id))
|
||||
.unwrap()
|
||||
.into();
|
||||
stream
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Failed to send download id");
|
||||
|
||||
if let Ok(msg) = stream.recv().await {
|
||||
let Some(msg) = msg else {
|
||||
log::warn!("Received None message from the websocket, closing connection.");
|
||||
return None;
|
||||
};
|
||||
|
||||
let DataServiceMessage::Tensor(data) = rmp_serde::from_slice(&msg.data)
|
||||
.expect("Can deserialize messages from the websocket.")
|
||||
else {
|
||||
panic!("Message should have been TensorData")
|
||||
};
|
||||
return Some(data);
|
||||
}
|
||||
log::warn!("Closed connection");
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the WebSocket stream for the given address, or create a new one if it doesn't exist.
|
||||
async fn get_data_stream(
|
||||
&self,
|
||||
address: Address,
|
||||
) -> Arc<Mutex<<P::Client as ProtocolClient>::Channel>> {
|
||||
let mut streams = self.channels.lock().await;
|
||||
match streams.get(&address) {
|
||||
Some(stream) => stream.clone(),
|
||||
None => {
|
||||
// Open a new WebSocket connection to the address
|
||||
let stream = P::Client::connect(address.clone(), "data").await;
|
||||
|
||||
let Some(stream) = stream else {
|
||||
panic!("Failed to connect to data server at {address:?}");
|
||||
};
|
||||
|
||||
let stream = Arc::new(Mutex::new(stream));
|
||||
streams.insert(address.clone(), stream.clone());
|
||||
|
||||
stream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the requested exposed tensor data, and update download counter
|
||||
async fn get_exposed_tensor_bytes(
|
||||
&self,
|
||||
transfer_id: TensorTransferId,
|
||||
) -> Option<bytes::Bytes> {
|
||||
loop {
|
||||
{
|
||||
let mut exposed_tensors = self.exposed_tensors.lock().await;
|
||||
// take the tensor out of the hashmap while we download
|
||||
if let Some(mut exposed_state) = exposed_tensors.remove(&transfer_id) {
|
||||
exposed_state.cur_download_count += 1;
|
||||
let bytes = if exposed_state.cur_download_count == exposed_state.max_downloads {
|
||||
exposed_state.bytes
|
||||
} else {
|
||||
let bytes = exposed_state.bytes.clone();
|
||||
exposed_tensors.insert(transfer_id, exposed_state);
|
||||
bytes
|
||||
};
|
||||
return Some(bytes);
|
||||
}
|
||||
}
|
||||
// No matching tensor, wait for a new one to come in.
|
||||
self.new_tensor_notify.notified().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle incoming connections for downloading tensors.
|
||||
pub(crate) async fn handle_data_channel(
|
||||
&self,
|
||||
mut channel: <P::Server as ProtocolServer>::Channel,
|
||||
) {
|
||||
log::info!("[Data Handler] New connection for download.");
|
||||
|
||||
while !self.cancel_token.is_cancelled() {
|
||||
match channel.recv().await {
|
||||
Ok(message) => {
|
||||
if let Some(msg) = message {
|
||||
let bytes = msg.data;
|
||||
let msg: DataServiceMessage = rmp_serde::from_slice(&bytes)
|
||||
.expect("Can deserialize messages from the websocket.");
|
||||
let DataServiceMessage::TensorRequest(transfer_id) = msg else {
|
||||
panic!("Received a message that wasn't a tensor request! {msg:?}");
|
||||
};
|
||||
|
||||
let bytes = self.get_exposed_tensor_bytes(transfer_id).await.unwrap();
|
||||
|
||||
channel.send(Message::new(bytes)).await.unwrap();
|
||||
} else {
|
||||
log::info!("Closed connection");
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(err) => panic!("Failed to receive message from websocket: {err:?}"),
|
||||
};
|
||||
}
|
||||
log::info!("[Data Service] Closing connection for download.");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
mod base;
|
||||
pub use base::*;
|
||||
|
||||
pub mod util;
|
||||
|
||||
#[cfg(feature = "websocket")]
|
||||
pub mod websocket;
|
||||
|
||||
#[cfg(feature = "data-service")]
|
||||
pub mod data_service;
|
||||
@@ -0,0 +1,46 @@
|
||||
use tracing_core::{Level, LevelFilter};
|
||||
use tracing_subscriber::{
|
||||
Layer, filter::filter_fn, layer::SubscriberExt, registry, util::SubscriberInitExt,
|
||||
};
|
||||
|
||||
/// Utilities to help handle communication termination.
|
||||
pub async fn os_shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn init_logging() {
|
||||
let layer = tracing_subscriber::fmt::layer()
|
||||
.with_filter(LevelFilter::INFO)
|
||||
.with_filter(filter_fn(|m| {
|
||||
if let Some(path) = m.module_path() {
|
||||
// The wgpu crate is logging too much, so we skip `info` level.
|
||||
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}));
|
||||
|
||||
// If we start multiple servers in the same process, this will fail, it's ok
|
||||
let _ = registry().with(layer).try_init();
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
use crate::{
|
||||
base::{Address, Protocol},
|
||||
websocket::{client::WsClient, server::WsServer},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
/// A websocket implements a [communication protocol](Protocol) that can be used to communicate
|
||||
/// over the internet.
|
||||
pub struct WebSocket {}
|
||||
|
||||
impl Protocol for WebSocket {
|
||||
type Client = WsClient;
|
||||
type Server = WsServer;
|
||||
}
|
||||
|
||||
/// Parse an address, add the ws:// prefix if needed, and return an error if the address is invalid
|
||||
pub(crate) fn parse_ws_address(mut address: Address) -> Result<Address, String> {
|
||||
let s = &address.inner;
|
||||
let parts = s.split("://").collect::<Vec<&str>>();
|
||||
let num_parts = parts.len();
|
||||
let url = if num_parts == 2 {
|
||||
if parts[0] == "ws" {
|
||||
s.to_owned()
|
||||
} else {
|
||||
return Err(format!("Invalid prefix: {}", parts[0]));
|
||||
}
|
||||
} else if num_parts == 1 {
|
||||
return Err(format!("ws://{s}"));
|
||||
} else {
|
||||
return Err(format!("Invalid url: {s}"));
|
||||
};
|
||||
|
||||
address.inner = url;
|
||||
Ok(address)
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
use crate::{
|
||||
base::{Address, CommunicationChannel, CommunicationError, Message, ProtocolClient},
|
||||
websocket::base::parse_ws_address,
|
||||
};
|
||||
use burn_std::future::DynFut;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite::{
|
||||
MaybeTlsStream, WebSocketStream, connect_async_with_config,
|
||||
tungstenite::{self, protocol::WebSocketConfig},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClient;
|
||||
|
||||
impl ProtocolClient for WsClient {
|
||||
type Channel = WsClientChannel;
|
||||
type Error = WsClientError;
|
||||
|
||||
fn connect(address: Address, route: &str) -> DynFut<Option<WsClientChannel>> {
|
||||
Box::pin(connect_ws(address, route.to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Open a new WebSocket connection to the address
|
||||
async fn connect_ws(address: Address, route: String) -> Option<WsClientChannel> {
|
||||
let address = parse_ws_address(address).ok()?;
|
||||
let address = format!("{address}/{route}");
|
||||
const MB: usize = 1024 * 1024;
|
||||
let (stream, _) = connect_async_with_config(
|
||||
address.clone(),
|
||||
Some(
|
||||
WebSocketConfig::default()
|
||||
.write_buffer_size(0)
|
||||
.max_message_size(None)
|
||||
.max_frame_size(Some(MB * 512))
|
||||
.accept_unmasked_frames(true)
|
||||
.read_buffer_size(64 * 1024), // 64 KiB (previous default)
|
||||
),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.ok()?;
|
||||
|
||||
Some(WsClientChannel { inner: stream })
|
||||
}
|
||||
pub struct WsClientChannel {
|
||||
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
}
|
||||
|
||||
impl CommunicationChannel for WsClientChannel {
|
||||
type Error = WsClientError;
|
||||
|
||||
async fn send(&mut self, msg: Message) -> Result<(), WsClientError> {
|
||||
self.inner
|
||||
.send(tungstenite::Message::Binary(msg.data))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<Option<Message>, WsClientError> {
|
||||
match self.inner.next().await {
|
||||
Some(next) => match next {
|
||||
Ok(tungstenite::Message::Binary(data)) => Ok(Some(Message { data })),
|
||||
Ok(tungstenite::Message::Close(_close_frame)) => Ok(None),
|
||||
Err(err) => Err(WsClientError::Tungstenite(err)),
|
||||
msg => Err(WsClientError::UnknownMessage(format!("{msg:?}"))),
|
||||
},
|
||||
None => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<(), WsClientError> {
|
||||
let reason = "Peer is closing".to_string();
|
||||
|
||||
self.inner
|
||||
.send(tungstenite::Message::Close(Some(
|
||||
tungstenite::protocol::CloseFrame {
|
||||
code: tungstenite::protocol::frame::coding::CloseCode::Normal,
|
||||
reason: reason.clone().into(),
|
||||
},
|
||||
)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WsClientError {
|
||||
Io(std::io::Error),
|
||||
Tungstenite(tungstenite::Error),
|
||||
UnknownMessage(String),
|
||||
Other(String),
|
||||
}
|
||||
impl CommunicationError for WsClientError {}
|
||||
|
||||
impl From<std::io::Error> for WsClientError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
Self::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tungstenite::Error> for WsClientError {
|
||||
fn from(err: tungstenite::Error) -> Self {
|
||||
Self::Tungstenite(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod base;
|
||||
mod client;
|
||||
mod server;
|
||||
|
||||
pub use base::*;
|
||||
pub use client::*;
|
||||
pub use server::*;
|
||||
@@ -0,0 +1,141 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::{
|
||||
base::{CommunicationChannel, CommunicationError, Message, ProtocolServer},
|
||||
util::init_logging,
|
||||
};
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{
|
||||
State, WebSocketUpgrade,
|
||||
ws::{self, WebSocket},
|
||||
},
|
||||
routing::get,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WsServer {
|
||||
port: u16,
|
||||
router: Router<()>,
|
||||
}
|
||||
|
||||
pub struct WsServerChannel {
|
||||
inner: WebSocket,
|
||||
}
|
||||
|
||||
impl WsServer {
|
||||
pub fn new(port: u16) -> Self {
|
||||
Self {
|
||||
port,
|
||||
router: Router::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolServer for WsServer {
|
||||
type Channel = WsServerChannel;
|
||||
type Error = WsServerError;
|
||||
|
||||
async fn serve<F>(self, shutdown: F) -> Result<(), Self::Error>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
init_logging();
|
||||
|
||||
let address = format!("0.0.0.0:{}", self.port);
|
||||
log::info!("Starting server {address}");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(address).await?;
|
||||
|
||||
axum::serve(
|
||||
listener,
|
||||
self.router
|
||||
.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(shutdown)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn route<C, Fut>(mut self, path: &str, callback: C) -> Self
|
||||
where
|
||||
C: FnOnce(WsServerChannel) -> Fut + Clone + Send + Sync + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
// Format path: should start with a /
|
||||
let path = if path.starts_with("/") {
|
||||
path.to_owned()
|
||||
} else {
|
||||
format!("/{path}")
|
||||
};
|
||||
|
||||
let method = get(|ws: WebSocketUpgrade, _: State<()>| async {
|
||||
ws.on_upgrade(async move |socket| {
|
||||
callback(WsServerChannel { inner: socket }).await;
|
||||
})
|
||||
});
|
||||
|
||||
self.router = self.router.route(&path, method);
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl CommunicationChannel for WsServerChannel {
|
||||
type Error = WsServerError;
|
||||
|
||||
async fn send(&mut self, message: Message) -> Result<(), WsServerError> {
|
||||
self.inner.send(ws::Message::Binary(message.data)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<Option<Message>, WsServerError> {
|
||||
match self.inner.next().await {
|
||||
Some(next) => match next {
|
||||
Ok(ws::Message::Binary(data)) => Ok(Some(Message { data })),
|
||||
Ok(ws::Message::Close(_close_frame)) => Ok(None),
|
||||
Err(err) => Err(WsServerError::Axum(err)),
|
||||
msg => Err(WsServerError::UnknownMessage(format!("{msg:?}"))),
|
||||
},
|
||||
None => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<(), WsServerError> {
|
||||
let reason = "Peer is closing".to_string();
|
||||
|
||||
self.inner
|
||||
.send(ws::Message::Close(Some(ws::CloseFrame {
|
||||
code: 1000, // code: Normal
|
||||
reason: reason.clone().into(),
|
||||
})))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WsServerError {
|
||||
Io(std::io::Error),
|
||||
Axum(axum::Error),
|
||||
UnknownMessage(String),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl CommunicationError for WsServerError {}
|
||||
|
||||
impl From<std::io::Error> for WsServerError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
Self::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<axum::Error> for WsServerError {
|
||||
fn from(err: axum::Error) -> Self {
|
||||
Self::Axum(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user