feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,123 @@
pub use super::RemoteDevice;
use super::worker::{ClientRequest, ClientWorker};
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent};
use async_channel::{RecvError, SendError, Sender};
use burn_communication::ProtocolClient;
use burn_ir::TensorId;
use burn_std::id::StreamId;
use std::{
future::Future,
sync::{Arc, atomic::AtomicU64},
};
/// Client-side handle to a remote backend server.
///
/// Clones share the same request sender and the same Tokio runtime, since both
/// are stored behind an [`Arc`].
#[derive(Clone)]
pub struct RemoteClient {
/// Device (server address and local id) this client was created for.
pub(crate) device: RemoteDevice,
/// Shared sender used to queue tasks for the worker.
pub(crate) sender: Arc<RemoteSender>,
/// Tokio runtime on which the worker tasks are spawned and on which blocking
/// calls wait for a response.
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
}
impl RemoteClient {
    /// Starts a worker speaking protocol `C` for `device` and returns the client
    /// bound to it.
    pub fn init<C: ProtocolClient>(device: RemoteDevice) -> Self {
        ClientWorker::<C>::start(device)
    }

    /// Assembles a client from already-created parts.
    ///
    /// Both the operation position counter and the tensor id counter of the new
    /// sender start at zero.
    pub(crate) fn new(
        device: RemoteDevice,
        sender: Sender<ClientRequest>,
        runtime: Arc<tokio::runtime::Runtime>,
        session_id: SessionId,
    ) -> Self {
        let remote_sender = RemoteSender {
            sender,
            session_id,
            position_counter: AtomicU64::new(0),
            tensor_id_counter: AtomicU64::new(0),
        };

        Self {
            device,
            sender: Arc::new(remote_sender),
            runtime,
        }
    }
}
/// Sending side of the channel between a [`RemoteClient`] and its worker.
///
/// Besides the channel itself it owns the counters used to number operations
/// and to allocate tensor ids for this session.
pub(crate) struct RemoteSender {
/// Channel on which requests are queued for the worker.
sender: Sender<ClientRequest>,
/// Next operation sequence number, used to build each `ConnectionId`.
position_counter: AtomicU64,
/// Next tensor id to hand out.
tensor_id_counter: AtomicU64,
/// Session this sender belongs to; sent with the close task.
session_id: SessionId,
}
/// Failure modes of `RemoteSender::send_async`.
// NOTE(review): `allow(unused)` presumably silences dead-code warnings about the
// wrapped payloads, which are only read through `Debug` — confirm before removing.
#[allow(unused)]
#[derive(Debug)]
pub enum RemoteSendError {
/// The request could not be queued on the channel to the worker.
SendError(SendError<ClientRequest>),
/// Receiving the response on the callback channel failed.
RecvError(RecvError),
}
impl RemoteSender {
    /// Generate a new unique (for this [`RemoteSender`]) [`TensorId`].
    pub(crate) fn new_tensor_id(&self) -> TensorId {
        TensorId::new(
            self.tensor_id_counter
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
        )
    }

    /// Give the next operation sequence number.
    fn next_position(&self) -> u64 {
        self.position_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Queue `task` for the worker without expecting a response.
    ///
    /// The call blocks only until the request has been accepted by the channel.
    ///
    /// # Panics
    ///
    /// Panics if the channel to the worker is closed.
    pub(crate) fn send(&self, task: ComputeTask) {
        self.sender
            .send_blocking(ClientRequest::WithoutCallback(Task::Compute(
                task,
                ConnectionId::new(self.next_position(), StreamId::current()),
            )))
            .unwrap();
    }

    /// Queue `task` and return a future resolving to the server's response.
    ///
    /// The stream id and the sequence number are captured synchronously, when
    /// this method is called, so the position of the task is fixed by call order
    /// rather than by when the returned future is first polled.
    ///
    /// # Errors
    ///
    /// The future yields [`RemoteSendError::SendError`] when the request cannot
    /// be queued and [`RemoteSendError::RecvError`] when the response channel
    /// fails before a response is received.
    pub(crate) fn send_async(
        &self,
        task: ComputeTask,
    ) -> impl Future<Output = Result<TaskResponseContent, RemoteSendError>> + Send + use<> {
        let stream_id = StreamId::current();
        let position = self.next_position();
        let sender = self.sender.clone();

        async move {
            let (tx, rx) = async_channel::bounded(1);
            if let Err(e) = sender
                .send(ClientRequest::WithSyncCallback(
                    Task::Compute(task, ConnectionId::new(position, stream_id)),
                    tx,
                ))
                .await
            {
                return Err(RemoteSendError::SendError(e));
            }

            match rx.recv().await {
                Ok(response) => Ok(response),
                Err(e) => Err(RemoteSendError::RecvError(e)),
            }
        }
    }

    /// Queue the close task for this session.
    ///
    /// This is best effort: when the channel to the worker is already closed
    /// there is nobody left to notify, so the failure is ignored.
    pub(crate) fn close(&mut self) {
        let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id));

        // This runs from `Drop`, possibly while the thread is already unwinding
        // from a failed `send`. Panicking here would then be a double panic and
        // abort the process, so a closed channel is deliberately not an error.
        let _ = self.sender.send_blocking(close_task);
    }
}
impl Drop for RemoteSender {
    /// Runs [`RemoteSender::close`] when the sender goes away, so the close task
    /// of the session is queued without callers having to remember it.
    fn drop(&mut self) {
        Self::close(self);
    }
}

View File

@@ -0,0 +1,82 @@
use std::marker::PhantomData;
use burn_backend::Shape;
use burn_communication::ProtocolClient;
use burn_ir::TensorIr;
use burn_router::{RouterTensor, RunnerChannel, get_client};
use super::{
RemoteClient,
runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle},
};
/// Router channel whose clients are [`RemoteClient`]s speaking protocol `C`.
///
/// The type carries no data; `C` only selects the protocol client implementation.
pub struct RemoteChannel<C: ProtocolClient> {
/// Ties the otherwise unused protocol parameter `C` to the type.
_p: PhantomData<C>,
}
impl<C: ProtocolClient> RunnerChannel for RemoteChannel<C> {
    type Device = RemoteDevice;
    type Bridge = RemoteBridge<C>;
    type Client = RemoteClient;
    type FloatElem = f32;
    type IntElem = i32;
    type BoolElem = u32;

    /// Channel name: `remote-` followed by the debug form of the device.
    fn name(device: &Self::Device) -> String {
        format!("remote-{device:?}")
    }

    /// Creates a client (and its worker) for a copy of `device`.
    fn init_client(device: &Self::Device) -> Self::Client {
        RemoteClient::init::<C>(device.clone())
    }

    /// Pairs a copy of the tensor description with a copy of the client.
    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle<C> {
        RemoteTensorHandle {
            tensor: tensor.clone(),
            client: client.clone(),
            _p: PhantomData,
        }
    }

    /// Not supported on a remote channel.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn register_tensor(
        _client: &Self::Client,
        _handle: RemoteTensorHandle<C>,
        _shape: Shape,
        _dtype: burn_backend::DType,
    ) -> RouterTensor<Self::Client> {
        // Registering a handle is normally only needed to move a tensor from one
        // device to another, i.e. to change its client.
        panic!("Can't register manually a tensor on a remote channel.");
    }

    /// Moves `tensor` to `target_device` and returns the tensor as seen by the
    /// client of that device.
    fn change_client_backend(
        tensor: RouterTensor<Self::Client>,
        target_device: &Self::Device,
    ) -> RouterTensor<Self::Client> {
        // The client must be grabbed before the tensor is consumed into its IR.
        let source_client = tensor.client.clone();
        let ir = tensor.into_ir();

        let moved = Self::get_tensor_handle(&ir, &source_client).change_backend(target_device);
        let new_id = moved.tensor.id;
        let destination = get_client::<Self>(target_device);

        RouterTensor::new(new_id, moved.tensor.shape, moved.tensor.dtype, destination)
    }
}
impl<C: ProtocolClient> Clone for RemoteChannel<C> {
    /// The channel holds no data, so a clone is simply a fresh marker value.
    fn clone(&self) -> Self {
        Self { _p: PhantomData }
    }
}

View File

@@ -0,0 +1,8 @@
mod base;
mod channel;
mod runner;
mod worker;
pub use base::*;
pub use channel::*;
pub use runner::RemoteDevice;

View File

@@ -0,0 +1,294 @@
use super::{RemoteChannel, RemoteClient};
use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
use burn_backend::{DeviceId, DeviceOps, ExecutionError, Shape, TensorData};
use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
use burn_ir::TensorIr;
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
use burn_std::{backtrace::BackTrace, future::DynFut};
use std::sync::OnceLock;
use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
// TODO: we should work with the parsed structure of Address, not the string.
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();

/// Returns the process-wide address registry, creating it empty on first use.
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
    ADDRESS_REGISTRY.get_or_init(Mutex::default)
}

/// Map a string network address to a (local runtime) global unique u32.
///
/// The mapping is stable for the lifetime of the process and shared between
/// threads: an address seen before yields the id it was given the first time,
/// an unseen address is assigned a new id.
pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
    let key = address.as_ref();
    let mut ids = get_address_registry().lock().unwrap();

    if let Some(&known) = ids.get(key) {
        return known;
    }

    // Entries are never removed, so the current size is an id not yet in use.
    let fresh = ids.len() as u32;
    ids.insert(key.to_owned(), fresh);
    fresh
}

/// Look up the address registered under `id`.
///
/// This is the inverse of [`address_to_id`]; `None` is returned when no address
/// has been given that id.
pub fn id_to_address(id: u32) -> Option<String> {
    let ids = get_address_registry().lock().unwrap();

    ids.iter()
        .find(|(_, known)| **known == id)
        .map(|(address, _)| address.clone())
}
// Every request made through the sender is queued synchronously: ordering is
// crucial when registering operations or creating tensors.
//
// The overhead is minimal, since we only wait for the task to be handed to the
// async channel, not for it to reach the server, let alone be processed there.
impl RunnerClient for RemoteClient {
    type Device = RemoteDevice;

    /// Queues an operation for registration on the server.
    fn register_op(&self, op: burn_ir::OperationIr) {
        let task = ComputeTask::RegisterOperation(Box::new(op));
        self.sender.send(task);
    }

    /// Requests the data of `tensor`; the returned future resolves to the
    /// server's answer.
    ///
    /// # Panics
    ///
    /// The future panics if the response is not a `ReadTensor` response.
    fn read_tensor_async(
        &self,
        tensor: burn_ir::TensorIr,
    ) -> DynFut<Result<TensorData, ExecutionError>> {
        // The request future is created here, outside of the async block, so the
        // task takes its position in the sequence at call time.
        let pending = self.sender.send_async(ComputeTask::ReadTensor(tensor));

        Box::pin(async move {
            match pending.await {
                Ok(TaskResponseContent::ReadTensor(data)) => data,
                Ok(_) => panic!("Invalid message type"),
                Err(e) => Err(ExecutionError::Generic {
                    reason: format!("Failed to read tensor: {:?}", e),
                    backtrace: BackTrace::capture(),
                }),
            }
        })
    }

    /// Allocates a tensor id, queues the upload of `data` under that id and
    /// returns the matching router tensor.
    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
        let id = self.sender.new_tensor_id();
        // Shape and dtype are read before `data` is moved into the task.
        let dtype = data.dtype;
        let shape = Shape::from(data.shape.clone());

        self.sender.send(ComputeTask::RegisterTensor(id, data));

        RouterTensor::new(id, shape, dtype, self.clone())
    }

    /// Returns a copy of the device this client was created for.
    fn device(&self) -> Self::Device {
        self.device.clone()
    }

    /// Blocks until the server has answered a `SyncBackend` task.
    ///
    /// # Panics
    ///
    /// Panics if the response is not a `SyncBackend` response.
    fn sync(&self) -> Result<(), ExecutionError> {
        // The request future is created before blocking, so the task takes its
        // position in the sequence at call time.
        let pending = self.sender.send_async(ComputeTask::SyncBackend);

        match self.runtime.block_on(pending) {
            Ok(TaskResponseContent::SyncBackend(outcome)) => outcome,
            Ok(_) => panic!("Invalid message type"),
            Err(e) => Err(ExecutionError::Generic {
                reason: format!("Failed to sync: {:?}", e),
                backtrace: BackTrace::capture(),
            }),
        }
    }

    /// Queues a task seeding the server with `seed`.
    fn seed(&self, seed: u64) {
        self.sender.send(ComputeTask::Seed(seed));
    }

    /// Reserves a tensor id without sending anything to the server.
    fn create_empty_handle(&self) -> burn_ir::TensorId {
        self.sender.new_tensor_id()
    }

    /// Asks the server about `dtype`.
    ///
    /// # Panics
    ///
    /// Always panics, whether or not a response is received.
    // NOTE(review): every outcome panics, so this can never return a value —
    // presumably the arm extracting the usage set from the response is missing;
    // confirm against the `TaskResponseContent` variants.
    fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
        let pending = self.sender.send_async(ComputeTask::SupportsDType(dtype));

        if let Err(e) = self.runtime.block_on(pending) {
            panic!("Failed to check dtype support: {:?}", e);
        }
        panic!("Invalid message type")
    }
}
#[derive(Clone, PartialEq, Eq, Debug)]
/// The device contains the connection information of the server.
///
/// Equality compares both the parsed address and the local id.
pub struct RemoteDevice {
/// Parsed network address of the server.
pub(crate) address: Address,
/// The id of the device in the local registry, see [`address_to_id`].
pub(crate) id: u32,
}
impl RemoteDevice {
    /// Create a device from an url.
    ///
    /// The address string is first given a local id through [`address_to_id`]
    /// and then parsed into an [`Address`].
    ///
    /// # Panics
    ///
    /// Panics if `address` cannot be parsed into an [`Address`].
    pub fn new(address: &str) -> Self {
        let id = address_to_id(address);
        let address = Address::from_str(address).unwrap();

        Self { address, id }
    }
}
impl Default for RemoteDevice {
    /// Uses the address found in the `BURN_REMOTE_ADDRESS` environment variable,
    /// falling back to `ws://127.0.0.1:3000` when it cannot be read.
    fn default() -> Self {
        let address = std::env::var("BURN_REMOTE_ADDRESS")
            .unwrap_or_else(|_| String::from("ws://127.0.0.1:3000"));

        Self::new(&address)
    }
}
impl burn_std::device::Device for RemoteDevice {
    /// Rebuilds the device whose address is registered under `device_id.index_id`.
    ///
    /// # Panics
    ///
    /// Panics if `device_id.type_id` is not `0`, or if no address is registered
    /// under the index.
    fn from_id(device_id: DeviceId) -> Self {
        if device_id.type_id != 0 {
            panic!("Invalid device id: {device_id} (expected type 0)");
        }

        let Some(address) = id_to_address(device_id.index_id) else {
            panic!("Invalid device id: {device_id}");
        };

        Self::new(&address)
    }

    /// Type `0`, indexed by the local registry id of the address.
    fn to_id(&self) -> DeviceId {
        DeviceId {
            type_id: 0,
            index_id: self.id,
        }
    }

    /// Always reports a single device, whatever the type id.
    fn device_count(_type_id: u16) -> usize {
        1
    }
}
// Empty impl: no `DeviceOps` item is overridden for `RemoteDevice`.
impl DeviceOps for RemoteDevice {}
/// Bridge used to move tensors between remote devices.
///
/// The type carries no data; `C` only selects the protocol client implementation.
pub struct RemoteBridge<C: ProtocolClient> {
/// Ties the otherwise unused protocol parameter `C` to the type.
_p: PhantomData<C>,
}
/// A tensor description paired with the client it currently belongs to.
pub struct RemoteTensorHandle<C: ProtocolClient> {
/// Client of the server the tensor currently belongs to.
pub(crate) client: RemoteClient,
/// Intermediate representation of the tensor (its id, shape and dtype are
/// read from here).
pub(crate) tensor: TensorIr,
/// Ties the otherwise unused protocol parameter `C` to the type.
pub(crate) _p: PhantomData<C>,
}
static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);

/// Returns a transfer id that has not been handed out before in this process.
///
/// The first call returns the id built from `0`; every later call advances the
/// stored counter with `TensorTransferId::next` and returns the new value.
///
/// # Panics
///
/// Panics if the counter mutex is poisoned.
fn get_next_transfer_id() -> TensorTransferId {
    let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();

    let next = match *transfer_counter {
        Some(mut current) => {
            current.next();
            current
        }
        None => 0.into(),
    };

    // `current` above is a copy of the stored id, so the advanced value has to
    // be written back; otherwise the counter stays at its initial value and
    // every call after the first returns the same id.
    *transfer_counter = Some(next);
    next
}
impl<C: ProtocolClient> RemoteTensorHandle<C> {
    /// Moves the tensor to the server behind `target_device`.
    ///
    /// The current server is asked to expose the tensor under a fresh transfer
    /// id, then the target server is asked to register a tensor from that
    /// transfer id and the current server's address. The intent is that the
    /// target server fetches the data from the original server itself, so the
    /// tensor's data never goes through this client and the client does not
    /// become a bottleneck.
    ///
    /// The returned handle carries the id allocated on the target client and
    /// points at that client.
    pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
        let transfer_id = get_next_transfer_id();

        // Step 1: the current server exposes the tensor for a single fetch.
        let expose = ComputeTask::ExposeTensorRemote {
            tensor: self.tensor.clone(),
            count: 1,
            transfer_id,
        };
        self.client.sender.send(expose);

        // Step 2: the target server registers the tensor under a new id.
        let target_client = get_client::<RemoteChannel<C>>(target_device);
        let new_id = target_client.sender.new_tensor_id();
        let source = TensorRemote {
            transfer_id,
            address: self.client.device.address.clone(),
        };
        let register = ComputeTask::RegisterTensorRemote(source, new_id);
        target_client.sender.send(register);

        // Step 3: re-point the handle at the target.
        self.tensor.id = new_id;
        self.client = target_client;
        self
    }
}
impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
type TensorHandle = RemoteTensorHandle<C>;
type Device = RemoteDevice;
fn change_backend_float(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
fn change_backend_int(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
fn change_backend_bool(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct addresses get distinct ids, ids are stable across calls, and the
    /// reverse lookup returns the original address (or `None` for an unknown id).
    #[test]
    fn test_address_to_id() {
        let first = "ws://127.0.0.1:3000";
        let second = "ws://127.0.0.1:3001";

        let first_id = address_to_id(first);
        let second_id = address_to_id(second);
        assert_ne!(first_id, second_id);

        for (address, id) in [(first, first_id), (second, second_id)] {
            assert_eq!(address_to_id(address), id);
            assert_eq!(id_to_address(id), Some(address.to_string()));
        }

        assert_eq!(id_to_address(u32::MAX), None);
    }
}

View File

@@ -0,0 +1,129 @@
use super::{RemoteClient, runner::RemoteDevice};
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
use burn_communication::{CommunicationChannel, Message, ProtocolClient};
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
/// Sending half on which the worker delivers the response to a pending request.
pub type CallbackSender = async_channel::Sender<TaskResponseContent>;
/// A request queued by the client for the worker to forward to the server.
#[derive(Debug)]
pub enum ClientRequest {
/// A task whose response must be delivered through the given callback.
WithSyncCallback(Task, CallbackSender),
/// A task for which no response is awaited.
WithoutCallback(Task),
}
/// State shared by the worker tasks: the callbacks of the requests that are
/// still waiting for a response.
pub(crate) struct ClientWorker<C: ProtocolClient> {
/// Pending callbacks, keyed by the id of the request they belong to.
requests: HashMap<ConnectionId, CallbackSender>,
/// Ties the otherwise unused protocol parameter `C` to the type.
_p: PhantomData<C>,
}
impl<C: ProtocolClient> ClientWorker<C> {
async fn on_response(&mut self, response: TaskResponse) {
match self.requests.remove(&response.id) {
Some(request) => {
request.send(response.content).await.unwrap();
}
None => {
panic!("Can't ignore message from the server.");
}
}
}
fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
self.requests.insert(id, callback);
}
}
impl<C: ProtocolClient> ClientWorker<C> {
/// Creates the runtime and the background tasks serving `device`, and returns
/// the client used to talk to them.
///
/// The client is returned right away: connecting to the server happens inside
/// a spawned task, so a connection failure panics in that task rather than in
/// the caller.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be built.
pub fn start(device: RemoteDevice) -> RemoteClient {
// Dedicated multi-threaded runtime with the I/O driver enabled.
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_io()
.build()
.unwrap(),
);
// Bounded queue between the client (`sender`) and the request task (`rec`);
// at most 10 requests can be waiting before senders have to wait.
let (sender, rec) = async_channel::bounded(10);
let session_id = SessionId::new();
let address = device.address.clone();
// NOTE(review): it is not visible here which deprecated item this allows —
// confirm it is still needed.
#[allow(deprecated)]
runtime.spawn(async move {
log::info!("Connecting to {} ...", address.clone());
// Two connections to the same server: one to send tasks on the "request"
// route, one to receive answers on the "response" route.
let mut stream_request = C::connect(address.clone(), "request")
.await
.expect("Server to be accessible");
let mut stream_response = C::connect(address, "response")
.await
.expect("Server to be accessible");
// Callback table shared by the two tasks spawned below.
let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::<C>::default()));
// Init the connection: the same `Task::Init` message, carrying the session
// id, is sent on both streams.
let bytes: bytes::Bytes = rmp_serde::to_vec(&Task::Init(session_id))
.expect("Can serialize tasks to bytes.")
.into();
stream_request
.send(Message::new(bytes.clone()))
.await
.expect("Can send the message over the comms channel.");
stream_response
.send(Message::new(bytes))
.await
.expect("Can send the message on the websocket.");
// Response task: reads messages from the server and hands each one to the
// callback registered for it. It ends when receiving fails or when the
// stream yields `None`.
let state_ws = state.clone();
tokio::spawn(async move {
while let Ok(msg) = stream_response.recv().await {
let msg = match msg {
Some(msg) => msg,
None => {
log::warn!("Closed connection");
return;
}
};
let response: TaskResponse = rmp_serde::from_slice(&msg.data)
.expect("Can deserialize messages from the websocket.");
let mut state = state_ws.lock().await;
state.on_response(response).await;
}
});
// Request task: forwards queued requests to the server, in queue order. It
// ends once receiving from the queue fails.
tokio::spawn(async move {
while let Ok(req) = rec.recv().await {
let task = match req {
ClientRequest::WithSyncCallback(task, callback) => {
// The callback is registered before the task is sent. Only `Compute`
// tasks carry an id; for any other task the callback is dropped.
if let Task::Compute(_content, id) = &task {
let mut state = state.lock().await;
state.register_callback(*id, callback);
}
task
}
ClientRequest::WithoutCallback(task) => task,
};
let bytes = rmp_serde::to_vec(&task)
.expect("Can serialize tasks to bytes.")
.into();
stream_request
.send(Message::new(bytes))
.await
.expect("Can send the message on the websocket.");
}
});
});
RemoteClient::new(device, sender, runtime, session_id)
}
}
impl<C: ProtocolClient> Default for ClientWorker<C> {
fn default() -> Self {
Self {
requests: Default::default(),
_p: PhantomData,
}
}
}