feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,123 @@
|
||||
pub use super::RemoteDevice;
|
||||
use super::worker::{ClientRequest, ClientWorker};
|
||||
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent};
|
||||
use async_channel::{RecvError, SendError, Sender};
|
||||
use burn_communication::ProtocolClient;
|
||||
use burn_ir::TensorId;
|
||||
use burn_std::id::StreamId;
|
||||
use std::{
|
||||
future::Future,
|
||||
sync::{Arc, atomic::AtomicU64},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RemoteClient {
|
||||
pub(crate) device: RemoteDevice,
|
||||
pub(crate) sender: Arc<RemoteSender>,
|
||||
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
|
||||
}
|
||||
|
||||
impl RemoteClient {
|
||||
pub fn init<C: ProtocolClient>(device: RemoteDevice) -> Self {
|
||||
ClientWorker::<C>::start(device)
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
device: RemoteDevice,
|
||||
sender: Sender<ClientRequest>,
|
||||
runtime: Arc<tokio::runtime::Runtime>,
|
||||
session_id: SessionId,
|
||||
) -> Self {
|
||||
Self {
|
||||
device,
|
||||
runtime,
|
||||
sender: Arc::new(RemoteSender {
|
||||
sender,
|
||||
position_counter: AtomicU64::new(0),
|
||||
tensor_id_counter: AtomicU64::new(0),
|
||||
session_id,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RemoteSender {
|
||||
sender: Sender<ClientRequest>,
|
||||
position_counter: AtomicU64,
|
||||
tensor_id_counter: AtomicU64,
|
||||
session_id: SessionId,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug)]
|
||||
pub enum RemoteSendError {
|
||||
SendError(SendError<ClientRequest>),
|
||||
RecvError(RecvError),
|
||||
}
|
||||
|
||||
impl RemoteSender {
|
||||
/// Generate a new unique (for this [`RemoteSender`] [`TensorId`].
|
||||
pub(crate) fn new_tensor_id(&self) -> TensorId {
|
||||
TensorId::new(
|
||||
self.tensor_id_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
)
|
||||
}
|
||||
|
||||
/// Give the next operation sequence number.
|
||||
fn next_position(&self) -> u64 {
|
||||
self.position_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub(crate) fn send(&self, task: ComputeTask) {
|
||||
self.sender
|
||||
.send_blocking(ClientRequest::WithoutCallback(Task::Compute(
|
||||
task,
|
||||
ConnectionId::new(self.next_position(), StreamId::current()),
|
||||
)))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub(crate) fn send_async(
|
||||
&self,
|
||||
task: ComputeTask,
|
||||
) -> impl Future<Output = Result<TaskResponseContent, RemoteSendError>> + Send + use<> {
|
||||
let stream_id = StreamId::current();
|
||||
let position = self.next_position();
|
||||
let sender = self.sender.clone();
|
||||
|
||||
async move {
|
||||
let (tx, rx) = async_channel::bounded(1);
|
||||
|
||||
if let Err(e) = sender
|
||||
.send(ClientRequest::WithSyncCallback(
|
||||
Task::Compute(task, ConnectionId::new(position, stream_id)),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
{
|
||||
return Err(RemoteSendError::SendError(e));
|
||||
}
|
||||
|
||||
match rx.recv().await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(e) => Err(RemoteSendError::RecvError(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn close(&mut self) {
|
||||
let sender = self.sender.clone();
|
||||
|
||||
let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id));
|
||||
|
||||
sender.send_blocking(close_task).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RemoteSender {
|
||||
fn drop(&mut self) {
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_backend::Shape;
|
||||
use burn_communication::ProtocolClient;
|
||||
use burn_ir::TensorIr;
|
||||
use burn_router::{RouterTensor, RunnerChannel, get_client};
|
||||
|
||||
use super::{
|
||||
RemoteClient,
|
||||
runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle},
|
||||
};
|
||||
|
||||
/// A local channel with direct connection to the backend runner clients.
|
||||
pub struct RemoteChannel<C: ProtocolClient> {
|
||||
_p: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> RunnerChannel for RemoteChannel<C> {
|
||||
type Device = RemoteDevice;
|
||||
type Bridge = RemoteBridge<C>;
|
||||
type Client = RemoteClient;
|
||||
|
||||
type FloatElem = f32;
|
||||
|
||||
type IntElem = i32;
|
||||
|
||||
type BoolElem = u32;
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
format!("remote-{device:?}")
|
||||
}
|
||||
|
||||
fn init_client(device: &Self::Device) -> Self::Client {
|
||||
RemoteClient::init::<C>(device.clone())
|
||||
}
|
||||
|
||||
fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle<C> {
|
||||
RemoteTensorHandle {
|
||||
client: client.clone(),
|
||||
tensor: tensor.clone(),
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_tensor(
|
||||
_client: &Self::Client,
|
||||
_handle: RemoteTensorHandle<C>,
|
||||
_shape: Shape,
|
||||
_dtype: burn_backend::DType,
|
||||
) -> RouterTensor<Self::Client> {
|
||||
// This function is normally only used to move a tensor from a device to another.
|
||||
//
|
||||
// In other words, to change the client.
|
||||
panic!("Can't register manually a tensor on a remote channel.");
|
||||
}
|
||||
|
||||
fn change_client_backend(
|
||||
tensor: RouterTensor<Self::Client>,
|
||||
target_device: &Self::Device, // target device
|
||||
) -> RouterTensor<Self::Client> {
|
||||
// Get tensor handle from current client
|
||||
let original_client = tensor.client.clone();
|
||||
let desc = tensor.into_ir();
|
||||
let handle = Self::get_tensor_handle(&desc, &original_client);
|
||||
|
||||
let handle = handle.change_backend(target_device);
|
||||
|
||||
let id = handle.tensor.id;
|
||||
|
||||
let target_client = get_client::<Self>(target_device);
|
||||
let router_tensor: RouterTensor<RemoteClient> =
|
||||
RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client);
|
||||
|
||||
router_tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> Clone for RemoteChannel<C> {
|
||||
fn clone(&self) -> Self {
|
||||
RemoteChannel { _p: PhantomData }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod base;
|
||||
mod channel;
|
||||
mod runner;
|
||||
mod worker;
|
||||
|
||||
pub use base::*;
|
||||
pub use channel::*;
|
||||
pub use runner::RemoteDevice;
|
||||
@@ -0,0 +1,294 @@
|
||||
use super::{RemoteChannel, RemoteClient};
|
||||
use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
|
||||
use burn_backend::{DeviceId, DeviceOps, ExecutionError, Shape, TensorData};
|
||||
use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
|
||||
use burn_ir::TensorIr;
|
||||
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
|
||||
use burn_std::{backtrace::BackTrace, future::DynFut};
|
||||
use std::sync::OnceLock;
|
||||
use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
|
||||
|
||||
// TODO: we should work with the parsed structure of Address, not the string.
|
||||
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();
|
||||
|
||||
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
|
||||
ADDRESS_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Map a string network address to a (local runtime) global unique u32.
|
||||
///
|
||||
/// Globally stable over the lifetime of the process, shared between threads,
|
||||
/// If the address has never been seen, a new id will be created.
|
||||
/// If the address has been seen, the previous id will be returned.
|
||||
pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
|
||||
let registry = get_address_registry();
|
||||
let mut registry = registry.lock().unwrap();
|
||||
let next_id = registry.len() as u32;
|
||||
*registry
|
||||
.entry(address.as_ref().to_string())
|
||||
.or_insert_with(|| next_id)
|
||||
}
|
||||
|
||||
/// Look up an address by id.
|
||||
///
|
||||
/// Returns the same address given ids by [`address_to_id`].
|
||||
pub fn id_to_address(id: u32) -> Option<String> {
|
||||
let registry = get_address_registry();
|
||||
let registry = registry.lock().unwrap();
|
||||
for entry in registry.iter() {
|
||||
if entry.1 == &id {
|
||||
return Some(entry.0.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// It is very important to block on any request made with the sender, since ordering is crucial
|
||||
// when registering operation or creating tensors.
|
||||
//
|
||||
// The overhead is minimal, since we only wait for the task to be sent to the async
|
||||
// channel, but not sent to the server and even less processed by the server.
|
||||
impl RunnerClient for RemoteClient {
|
||||
type Device = RemoteDevice;
|
||||
|
||||
fn register_op(&self, op: burn_ir::OperationIr) {
|
||||
self.sender
|
||||
.send(ComputeTask::RegisterOperation(Box::new(op)));
|
||||
}
|
||||
|
||||
fn read_tensor_async(
|
||||
&self,
|
||||
tensor: burn_ir::TensorIr,
|
||||
) -> DynFut<Result<TensorData, ExecutionError>> {
|
||||
// Important for ordering to call the creation of the future sync.
|
||||
let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
|
||||
|
||||
Box::pin(async move {
|
||||
match fut.await {
|
||||
Ok(response) => match response {
|
||||
TaskResponseContent::ReadTensor(res) => res,
|
||||
_ => panic!("Invalid message type"),
|
||||
},
|
||||
Err(e) => Err(ExecutionError::Generic {
|
||||
reason: format!("Failed to read tensor: {:?}", e),
|
||||
backtrace: BackTrace::capture(),
|
||||
}),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
|
||||
let id = self.sender.new_tensor_id();
|
||||
let shape = data.shape.clone();
|
||||
let dtype = data.dtype;
|
||||
|
||||
self.sender.send(ComputeTask::RegisterTensor(id, data));
|
||||
|
||||
RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
|
||||
}
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
self.device.clone()
|
||||
}
|
||||
|
||||
fn sync(&self) -> Result<(), ExecutionError> {
|
||||
// Important for ordering to call the creation of the future sync.
|
||||
let fut = self.sender.send_async(ComputeTask::SyncBackend);
|
||||
|
||||
match self.runtime.block_on(fut) {
|
||||
Ok(response) => match response {
|
||||
TaskResponseContent::SyncBackend(res) => res,
|
||||
_ => panic!("Invalid message type"),
|
||||
},
|
||||
Err(e) => Err(ExecutionError::Generic {
|
||||
reason: format!("Failed to sync: {:?}", e),
|
||||
backtrace: BackTrace::capture(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn seed(&self, seed: u64) {
|
||||
self.sender.send(ComputeTask::Seed(seed));
|
||||
}
|
||||
|
||||
fn create_empty_handle(&self) -> burn_ir::TensorId {
|
||||
self.sender.new_tensor_id()
|
||||
}
|
||||
|
||||
fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
|
||||
let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype));
|
||||
|
||||
match self.runtime.block_on(fut) {
|
||||
Ok(_response) => panic!("Invalid message type"),
|
||||
Err(e) => panic!("Failed to check dtype support: {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
/// The device contains the connection information of the server.
|
||||
pub struct RemoteDevice {
|
||||
pub(crate) address: Address,
|
||||
/// The id of the device in the local registry, see [`address_to_id`].
|
||||
pub(crate) id: u32,
|
||||
}
|
||||
|
||||
impl RemoteDevice {
|
||||
/// Create a device from an url.
|
||||
pub fn new(address: &str) -> Self {
|
||||
let id = address_to_id(address);
|
||||
Self {
|
||||
address: Address::from_str(address).unwrap(),
|
||||
id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RemoteDevice {
|
||||
fn default() -> Self {
|
||||
let address = match std::env::var("BURN_REMOTE_ADDRESS") {
|
||||
Ok(address) => address,
|
||||
Err(_) => String::from("ws://127.0.0.1:3000"),
|
||||
};
|
||||
|
||||
Self::new(&address)
|
||||
}
|
||||
}
|
||||
|
||||
impl burn_std::device::Device for RemoteDevice {
|
||||
fn from_id(device_id: DeviceId) -> Self {
|
||||
if device_id.type_id != 0 {
|
||||
panic!("Invalid device id: {device_id} (expected type 0)");
|
||||
}
|
||||
let address = id_to_address(device_id.index_id)
|
||||
.unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
|
||||
Self::new(&address)
|
||||
}
|
||||
|
||||
fn to_id(&self) -> DeviceId {
|
||||
DeviceId {
|
||||
type_id: 0,
|
||||
index_id: self.id,
|
||||
}
|
||||
}
|
||||
|
||||
fn device_count(_type_id: u16) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for RemoteDevice {}
|
||||
|
||||
pub struct RemoteBridge<C: ProtocolClient> {
|
||||
_p: PhantomData<C>,
|
||||
}
|
||||
|
||||
pub struct RemoteTensorHandle<C: ProtocolClient> {
|
||||
pub(crate) client: RemoteClient,
|
||||
pub(crate) tensor: TensorIr,
|
||||
pub(crate) _p: PhantomData<C>,
|
||||
}
|
||||
|
||||
static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
|
||||
|
||||
fn get_next_transfer_id() -> TensorTransferId {
|
||||
let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
|
||||
if transfer_counter.is_none() {
|
||||
*transfer_counter = Some(0.into());
|
||||
|
||||
transfer_counter.unwrap()
|
||||
} else {
|
||||
let mut transfer_counter = transfer_counter.unwrap();
|
||||
transfer_counter.next();
|
||||
|
||||
transfer_counter
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> RemoteTensorHandle<C> {
|
||||
/// Changes the backend of the tensor via a dWebSocket.
|
||||
/// We ask the original server to expose the tensor, then ask the target server to fetch
|
||||
/// the tensor. The target server will open a new network connection to the original server
|
||||
/// to download the data.
|
||||
/// This way the client never sees the tensor's data, and we avoid a bottleneck.
|
||||
pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
|
||||
let transfer_id = get_next_transfer_id();
|
||||
self.client.sender.send(ComputeTask::ExposeTensorRemote {
|
||||
tensor: self.tensor.clone(),
|
||||
count: 1,
|
||||
transfer_id,
|
||||
});
|
||||
|
||||
let target_client = get_client::<RemoteChannel<C>>(target_device);
|
||||
|
||||
let new_id = target_client.sender.new_tensor_id();
|
||||
|
||||
let remote_tensor = TensorRemote {
|
||||
transfer_id,
|
||||
address: self.client.device.address.clone(),
|
||||
};
|
||||
target_client
|
||||
.sender
|
||||
.send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
|
||||
|
||||
self.tensor.id = new_id;
|
||||
self.client = target_client;
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
|
||||
type TensorHandle = RemoteTensorHandle<C>;
|
||||
type Device = RemoteDevice;
|
||||
|
||||
fn change_backend_float(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_backend::Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor.change_backend(target_device)
|
||||
}
|
||||
|
||||
fn change_backend_int(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_backend::Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor.change_backend(target_device)
|
||||
}
|
||||
|
||||
fn change_backend_bool(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_backend::Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor.change_backend(target_device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_address_to_id() {
|
||||
let address1 = "ws://127.0.0.1:3000";
|
||||
let address2 = "ws://127.0.0.1:3001";
|
||||
|
||||
let id1 = address_to_id(address1);
|
||||
let id2 = address_to_id(address2);
|
||||
|
||||
assert_ne!(id1, id2);
|
||||
|
||||
assert_eq!(address_to_id(address1), id1);
|
||||
assert_eq!(id_to_address(id1), Some(address1.to_string()));
|
||||
|
||||
assert_eq!(address_to_id(address2), id2);
|
||||
assert_eq!(id_to_address(id2), Some(address2.to_string()));
|
||||
|
||||
let unused_id = u32::MAX;
|
||||
|
||||
assert_eq!(id_to_address(unused_id), None);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
use super::{RemoteClient, runner::RemoteDevice};
|
||||
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
|
||||
use burn_communication::{CommunicationChannel, Message, ProtocolClient};
|
||||
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
|
||||
|
||||
pub type CallbackSender = async_channel::Sender<TaskResponseContent>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ClientRequest {
|
||||
WithSyncCallback(Task, CallbackSender),
|
||||
WithoutCallback(Task),
|
||||
}
|
||||
|
||||
pub(crate) struct ClientWorker<C: ProtocolClient> {
|
||||
requests: HashMap<ConnectionId, CallbackSender>,
|
||||
_p: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> ClientWorker<C> {
|
||||
async fn on_response(&mut self, response: TaskResponse) {
|
||||
match self.requests.remove(&response.id) {
|
||||
Some(request) => {
|
||||
request.send(response.content).await.unwrap();
|
||||
}
|
||||
None => {
|
||||
panic!("Can't ignore message from the server.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
|
||||
self.requests.insert(id, callback);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> ClientWorker<C> {
|
||||
pub fn start(device: RemoteDevice) -> RemoteClient {
|
||||
let runtime = Arc::new(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_io()
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let (sender, rec) = async_channel::bounded(10);
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let address = device.address.clone();
|
||||
|
||||
#[allow(deprecated)]
|
||||
runtime.spawn(async move {
|
||||
log::info!("Connecting to {} ...", address.clone());
|
||||
let mut stream_request = C::connect(address.clone(), "request")
|
||||
.await
|
||||
.expect("Server to be accessible");
|
||||
let mut stream_response = C::connect(address, "response")
|
||||
.await
|
||||
.expect("Server to be accessible");
|
||||
|
||||
let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::<C>::default()));
|
||||
|
||||
// Init the connection.
|
||||
let bytes: bytes::Bytes = rmp_serde::to_vec(&Task::Init(session_id))
|
||||
.expect("Can serialize tasks to bytes.")
|
||||
.into();
|
||||
stream_request
|
||||
.send(Message::new(bytes.clone()))
|
||||
.await
|
||||
.expect("Can send the message over the comms channel.");
|
||||
stream_response
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Can send the message on the websocket.");
|
||||
|
||||
// Async worker loading callbacks from the server.
|
||||
let state_ws = state.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(msg) = stream_response.recv().await {
|
||||
let msg = match msg {
|
||||
Some(msg) => msg,
|
||||
None => {
|
||||
log::warn!("Closed connection");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let response: TaskResponse = rmp_serde::from_slice(&msg.data)
|
||||
.expect("Can deserialize messages from the websocket.");
|
||||
let mut state = state_ws.lock().await;
|
||||
state.on_response(response).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Channel async worker sending operations to the server.
|
||||
tokio::spawn(async move {
|
||||
while let Ok(req) = rec.recv().await {
|
||||
let task = match req {
|
||||
ClientRequest::WithSyncCallback(task, callback) => {
|
||||
if let Task::Compute(_content, id) = &task {
|
||||
let mut state = state.lock().await;
|
||||
state.register_callback(*id, callback);
|
||||
}
|
||||
task
|
||||
}
|
||||
ClientRequest::WithoutCallback(task) => task,
|
||||
};
|
||||
let bytes = rmp_serde::to_vec(&task)
|
||||
.expect("Can serialize tasks to bytes.")
|
||||
.into();
|
||||
stream_request
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Can send the message on the websocket.");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
RemoteClient::new(device, sender, runtime, session_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> Default for ClientWorker<C> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
requests: Default::default(),
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
pub(crate) mod client;
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
pub mod server;
|
||||
|
||||
pub(crate) mod shared;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
mod __client {
|
||||
use super::*;
|
||||
|
||||
use crate::{client::RemoteChannel, shared::RemoteProtocol};
|
||||
use burn_communication::Protocol;
|
||||
use burn_router::BackendRouter;
|
||||
|
||||
/// The remote backend allows you to run computation on a remote device.
|
||||
///
|
||||
/// Make sure there is a running server before trying to connect to it.
|
||||
///
|
||||
/// ```rust, ignore
|
||||
/// fn main() {
|
||||
/// let device = Default::default();
|
||||
/// let port = 3000;
|
||||
///
|
||||
/// // You need to activate the `server` feature flag to have access to this function.
|
||||
/// burn::server::start::<burn::backend::Wgpu>(device, port);
|
||||
/// }
|
||||
///```
|
||||
pub type RemoteBackend = BackendRouter<RemoteChannel<<RemoteProtocol as Protocol>::Client>>;
|
||||
|
||||
pub use client::RemoteDevice;
|
||||
}
|
||||
#[cfg(feature = "client")]
|
||||
pub use __client::*;
|
||||
|
||||
#[cfg(all(test, feature = "client", feature = "server"))]
|
||||
mod tests {
|
||||
use crate::RemoteBackend;
|
||||
use burn_ndarray::NdArray;
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
pub fn test_to_device_over_websocket() {
|
||||
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_io()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
rt.spawn(crate::server::start_websocket_async::<NdArray>(
|
||||
Default::default(),
|
||||
3000,
|
||||
));
|
||||
rt.spawn(crate::server::start_websocket_async::<NdArray>(
|
||||
Default::default(),
|
||||
3010,
|
||||
));
|
||||
|
||||
let remote_device_1 = super::RemoteDevice::new("ws://localhost:3000");
|
||||
let remote_device_2 = super::RemoteDevice::new("ws://localhost:3010");
|
||||
|
||||
// Some random input
|
||||
let input_shape = [1, 28, 28];
|
||||
let input = Tensor::<RemoteBackend, 3>::random(
|
||||
input_shape,
|
||||
Distribution::Default,
|
||||
&remote_device_1,
|
||||
);
|
||||
let numbers_expected: Vec<f32> = input.to_data().to_vec().unwrap();
|
||||
|
||||
// Move tensor to device 2
|
||||
let input = input.to_device(&remote_device_2);
|
||||
let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
|
||||
assert_eq!(numbers, numbers_expected);
|
||||
|
||||
// Move tensor back to device 1
|
||||
let input = input.to_device(&remote_device_1);
|
||||
let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
|
||||
assert_eq!(numbers, numbers_expected);
|
||||
|
||||
rt.shutdown_background();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
use burn_communication::{
|
||||
CommunicationChannel, Message, Protocol, ProtocolServer,
|
||||
data_service::{TensorDataServer, TensorDataService},
|
||||
util::os_shutdown_signal,
|
||||
websocket::{WebSocket, WsServer},
|
||||
};
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use burn_backend::tensor::Device;
|
||||
use burn_ir::BackendIr;
|
||||
|
||||
use crate::shared::{ComputeTask, Task};
|
||||
|
||||
use super::session::SessionManager;
|
||||
|
||||
pub struct RemoteServer<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
_b: PhantomData<B>,
|
||||
_n: PhantomData<P>,
|
||||
}
|
||||
|
||||
impl<B, P> RemoteServer<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
/// Start the server on the given address.
|
||||
pub async fn start(device: Device<B>, server: P::Server) {
|
||||
let cancel_token = CancellationToken::new();
|
||||
let data_service = Arc::new(TensorDataService::<B, P>::new(cancel_token));
|
||||
let session_manager = Arc::new(SessionManager::<B, P>::new(device, data_service.clone()));
|
||||
|
||||
let _server = server
|
||||
.route("/response", {
|
||||
let session_manager = session_manager.clone();
|
||||
move |stream| Self::handle_socket_response(session_manager, stream)
|
||||
})
|
||||
.route("/request", {
|
||||
let session_manager = session_manager.clone();
|
||||
move |stream| Self::handle_socket_request(session_manager, stream)
|
||||
})
|
||||
.route_tensor_data_service(data_service)
|
||||
.serve(os_shutdown_signal())
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn handle_socket_response(
|
||||
session_manager: Arc<SessionManager<B, P>>,
|
||||
mut socket: <P::Server as ProtocolServer>::Channel,
|
||||
) {
|
||||
log::info!("[Response Handler] On new connection.");
|
||||
|
||||
let packet = socket.recv().await;
|
||||
let msg = match packet {
|
||||
Ok(Some(msg)) => msg,
|
||||
Ok(None) => {
|
||||
log::info!("Response stream closed");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
log::info!("Response stream error on init: {e:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let id = match rmp_serde::from_slice::<Task>(&msg.data) {
|
||||
Ok(Task::Init(session_id)) => session_id,
|
||||
msg => {
|
||||
log::error!("Message is not a valid initialization task {msg:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut receiver = session_manager.register_responder(id).await;
|
||||
|
||||
log::info!("Response handler connection active");
|
||||
|
||||
while let Some(mut callback) = receiver.recv().await {
|
||||
let response = callback.recv().await.unwrap();
|
||||
let bytes = rmp_serde::to_vec(&response).unwrap();
|
||||
|
||||
socket.send(Message::new(bytes.into())).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_socket_request(
|
||||
session_manager: Arc<SessionManager<B, P>>,
|
||||
mut socket: <P::Server as ProtocolServer>::Channel,
|
||||
) {
|
||||
log::info!("[Request Handler] On new connection.");
|
||||
let mut session_id = None;
|
||||
|
||||
loop {
|
||||
let packet = socket.recv().await;
|
||||
let msg = match packet {
|
||||
Ok(Some(msg)) => msg,
|
||||
Ok(None) => {
|
||||
log::info!("Request stream closed");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
log::info!("Request stream error: {e:?}, Closing.");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let task = match rmp_serde::from_slice::<Task>(&msg.data) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
log::info!("Only bytes message in the json format are supported {err:?}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if let Task::Close(id) = task {
|
||||
session_id = Some(id);
|
||||
break;
|
||||
}
|
||||
|
||||
let (stream, connection_id, task) =
|
||||
match session_manager.stream(&mut session_id, task).await {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
log::info!("Ops session activated {session_id:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match task {
|
||||
ComputeTask::RegisterOperation(op) => {
|
||||
stream.register_operation(op).await;
|
||||
}
|
||||
ComputeTask::RegisterTensor(id, data) => {
|
||||
stream.register_tensor(id, data).await;
|
||||
}
|
||||
ComputeTask::ReadTensor(tensor) => {
|
||||
stream.read_tensor(connection_id, tensor).await;
|
||||
}
|
||||
ComputeTask::SyncBackend => {
|
||||
stream.sync(connection_id).await;
|
||||
}
|
||||
ComputeTask::RegisterTensorRemote(tensor, new_id) => {
|
||||
stream.register_tensor_remote(tensor, new_id).await;
|
||||
}
|
||||
ComputeTask::ExposeTensorRemote {
|
||||
tensor,
|
||||
count,
|
||||
transfer_id,
|
||||
} => {
|
||||
stream
|
||||
.expose_tensor_remote(tensor, count, transfer_id)
|
||||
.await;
|
||||
}
|
||||
ComputeTask::Seed(seed) => {
|
||||
stream.seed(seed).await;
|
||||
}
|
||||
ComputeTask::SupportsDType(dtype) => {
|
||||
stream.supports_dtype(connection_id, dtype).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Closing session {session_id:?}");
|
||||
session_manager.close(session_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the server on the given port and [device](Device).
|
||||
pub async fn start_websocket_async<B: BackendIr>(device: Device<B>, port: u16) {
|
||||
let server = WsServer::new(port);
|
||||
RemoteServer::<B, WebSocket>::start(device, server).await;
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
/// Start the server on the given port and [device](Device).
|
||||
pub async fn start_websocket<B: BackendIr>(device: Device<B>, port: u16) {
|
||||
start_websocket_async::<B>(device, port).await;
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
pub(crate) mod processor;
|
||||
pub(crate) mod session;
|
||||
pub(crate) mod stream;
|
||||
|
||||
mod base;
|
||||
|
||||
pub use base::{start_websocket, start_websocket_async};
|
||||
@@ -0,0 +1,132 @@
|
||||
use burn_backend::TensorData;
|
||||
use burn_communication::{
|
||||
Protocol,
|
||||
data_service::{TensorDataService, TensorTransferId},
|
||||
};
|
||||
use burn_ir::{BackendIr, OperationIr, TensorId, TensorIr};
|
||||
use burn_router::{Runner, RunnerClient};
|
||||
use burn_std::DType;
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent, TensorRemote};
|
||||
|
||||
/// The goal of the processor is to asynchronously process compute tasks on it own thread.
|
||||
pub struct Processor<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
p: PhantomData<B>,
|
||||
n: PhantomData<P>,
|
||||
}
|
||||
|
||||
pub type Callback<M> = Sender<M>;
|
||||
|
||||
pub enum ProcessorTask {
|
||||
RegisterOperation(Box<OperationIr>),
|
||||
RegisterTensor(TensorId, TensorData),
|
||||
RegisterTensorRemote(TensorRemote, TensorId),
|
||||
ExposeTensorRemote {
|
||||
tensor: TensorIr,
|
||||
transfer_id: TensorTransferId,
|
||||
count: u32,
|
||||
},
|
||||
ReadTensor(ConnectionId, TensorIr, Callback<TaskResponse>),
|
||||
Sync(ConnectionId, Callback<TaskResponse>),
|
||||
Seed(u64),
|
||||
SupportsDType(ConnectionId, DType, Callback<TaskResponse>),
|
||||
Close,
|
||||
}
|
||||
|
||||
impl<B: BackendIr, P> Processor<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
pub async fn start(
|
||||
runner: Runner<B>,
|
||||
data_service: Arc<TensorDataService<B, P>>,
|
||||
) -> Sender<ProcessorTask> {
|
||||
// channel for tasks to execute
|
||||
let (task_sender, mut task_rec) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(item) = task_rec.recv().await {
|
||||
match item {
|
||||
ProcessorTask::RegisterOperation(op) => {
|
||||
runner.register_op(*op);
|
||||
}
|
||||
ProcessorTask::Sync(id, callback) => {
|
||||
let result = runner.sync();
|
||||
callback
|
||||
.send(TaskResponse {
|
||||
content: TaskResponseContent::SyncBackend(result),
|
||||
id,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
ProcessorTask::RegisterTensor(id, data) => {
|
||||
runner.register_tensor_data_id(id, data);
|
||||
}
|
||||
ProcessorTask::RegisterTensorRemote(remote_tensor, new_id) => {
|
||||
log::info!(
|
||||
"Registering remote tensor...(id: {:?})",
|
||||
remote_tensor.transfer_id
|
||||
);
|
||||
let data = data_service
|
||||
.download_tensor(remote_tensor.address, remote_tensor.transfer_id)
|
||||
.await
|
||||
.expect("Can't download tensor: error"); // TODO all these panics should be server errors
|
||||
runner.register_tensor_data_id(new_id, data);
|
||||
}
|
||||
ProcessorTask::ExposeTensorRemote {
|
||||
tensor,
|
||||
transfer_id,
|
||||
count,
|
||||
} => {
|
||||
log::info!("Exposing tensor: (id: {transfer_id:?})");
|
||||
let data = runner.read_tensor_async(tensor).await;
|
||||
data_service
|
||||
.expose_data(data.unwrap(), count, transfer_id)
|
||||
.await;
|
||||
}
|
||||
ProcessorTask::ReadTensor(id, tensor, callback) => {
|
||||
let tensor = runner.read_tensor_async(tensor).await;
|
||||
callback
|
||||
.send(TaskResponse {
|
||||
content: TaskResponseContent::ReadTensor(tensor),
|
||||
id,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
ProcessorTask::Close => {
|
||||
let device = runner.device();
|
||||
runner.sync().unwrap();
|
||||
core::mem::drop(runner);
|
||||
B::sync(&device).unwrap();
|
||||
break;
|
||||
}
|
||||
ProcessorTask::Seed(seed) => runner.seed(seed),
|
||||
ProcessorTask::SupportsDType(id, dtype, callback) => {
|
||||
let _result = runner.dtype_usage(dtype);
|
||||
callback
|
||||
.send(TaskResponse {
|
||||
// content: TaskResponseContent::SupportsDType(result),
|
||||
// TODO: Update to result.
|
||||
content: TaskResponseContent::SupportsDType(()),
|
||||
id,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
task_sender
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
use burn_backend::tensor::Device;
|
||||
use burn_communication::{Protocol, data_service::TensorDataService};
|
||||
use burn_ir::BackendIr;
|
||||
use burn_router::Runner;
|
||||
use burn_std::id::StreamId;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::{
|
||||
Mutex,
|
||||
mpsc::{Receiver, Sender},
|
||||
};
|
||||
|
||||
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse};
|
||||
|
||||
use super::stream::Stream;
|
||||
|
||||
/// A session manager control the creation of sessions.
|
||||
///
|
||||
/// Each session manages its own stream, spawning one thread per stream to mimic the same behavior
|
||||
/// a native backend would have.
|
||||
pub struct SessionManager<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
runner: Runner<B>,
|
||||
sessions: Mutex<HashMap<SessionId, Session<B, P>>>,
|
||||
data_service: Arc<TensorDataService<B, P>>,
|
||||
}
|
||||
|
||||
struct Session<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
runner: Runner<B>,
|
||||
streams: HashMap<StreamId, Stream<B, P>>,
|
||||
sender: Sender<Receiver<TaskResponse>>,
|
||||
receiver: Option<Receiver<Receiver<TaskResponse>>>,
|
||||
data_service: Arc<TensorDataService<B, P>>,
|
||||
}
|
||||
|
||||
impl<B, P> SessionManager<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
pub fn new(device: Device<B>, data_service: Arc<TensorDataService<B, P>>) -> Self {
|
||||
Self {
|
||||
runner: Runner::new(device),
|
||||
sessions: Mutex::new(Default::default()),
|
||||
data_service,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new responder for the session. Only one responder can exist for a session for
|
||||
/// now.
|
||||
pub async fn register_responder(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
) -> Receiver<Receiver<TaskResponse>> {
|
||||
log::info!("Register responder for session {session_id}");
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
self.register_session(&mut sessions, session_id);
|
||||
|
||||
let session = sessions.get_mut(&session_id).unwrap();
|
||||
session.init_responder()
|
||||
}
|
||||
|
||||
/// Get the stream for the current session and task.
|
||||
pub async fn stream(
|
||||
&self,
|
||||
session_id: &mut Option<SessionId>,
|
||||
task: Task,
|
||||
) -> Option<(Stream<B, P>, ConnectionId, ComputeTask)> {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
|
||||
let session_id = match session_id {
|
||||
Some(id) => *id,
|
||||
None => match task {
|
||||
Task::Init(id) => {
|
||||
log::info!("Init requester for session {id}");
|
||||
*session_id = Some(id);
|
||||
self.register_session(&mut sessions, id);
|
||||
return None;
|
||||
}
|
||||
_ => panic!("The first message should initialize the session"),
|
||||
},
|
||||
};
|
||||
|
||||
match sessions.get_mut(&session_id) {
|
||||
Some(session) => {
|
||||
let (task, connection_id) = match task {
|
||||
Task::Compute(task, connection_id) => (task, connection_id),
|
||||
_ => panic!("Only support compute tasks."),
|
||||
};
|
||||
let stream = session.select(connection_id.stream_id).await;
|
||||
Some((stream, connection_id, task))
|
||||
}
|
||||
None => panic!("To be initialized"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Close the session with the given id.
|
||||
pub async fn close(&self, session_id: Option<SessionId>) {
|
||||
if let Some(id) = session_id {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = sessions.get_mut(&id) {
|
||||
session.close().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_session(&self, sessions: &mut HashMap<SessionId, Session<B, P>>, id: SessionId) {
|
||||
sessions.entry(id).or_insert_with(|| {
|
||||
log::info!("Creating a new session {id}");
|
||||
|
||||
Session::new(self.runner.clone(), self.data_service.clone())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, P> Session<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
fn new(runner: Runner<B>, data_service: Arc<TensorDataService<B, P>>) -> Self {
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
Self {
|
||||
runner,
|
||||
streams: Default::default(),
|
||||
sender,
|
||||
receiver: Some(receiver),
|
||||
data_service,
|
||||
}
|
||||
}
|
||||
|
||||
fn init_responder(&mut self) -> Receiver<Receiver<TaskResponse>> {
|
||||
let mut receiver = None;
|
||||
core::mem::swap(&mut receiver, &mut self.receiver);
|
||||
receiver.expect("Only one responder per session is possible.")
|
||||
}
|
||||
|
||||
/// Select the current [stream](Stream) based on the given task.
|
||||
async fn select(&mut self, stream_id: StreamId) -> Stream<B, P> {
|
||||
// We return the stream.
|
||||
match self.streams.get(&stream_id) {
|
||||
Some(stream) => stream.clone(),
|
||||
None => {
|
||||
let stream = Stream::<B, P>::new(
|
||||
self.runner.clone(),
|
||||
self.sender.clone(),
|
||||
self.data_service.clone(),
|
||||
)
|
||||
.await;
|
||||
self.streams.insert(stream_id, stream.clone());
|
||||
stream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close all streams created in the session.
|
||||
async fn close(&mut self) {
|
||||
for (id, stream) in self.streams.drain() {
|
||||
log::info!("Closing stream {id}");
|
||||
stream.close().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::shared::{ConnectionId, TaskResponse, TensorRemote};
|
||||
|
||||
use super::processor::{Processor, ProcessorTask};
|
||||
use burn_backend::TensorData;
|
||||
use burn_communication::{
|
||||
Protocol,
|
||||
data_service::{TensorDataService, TensorTransferId},
|
||||
};
|
||||
use burn_ir::{BackendIr, OperationIr, TensorId, TensorIr};
|
||||
use burn_router::Runner;
|
||||
use burn_std::DType;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
/// A stream makes sure all operations registered are executed in the order they were sent to the
|
||||
/// server, potentially waiting to reconstruct consistency.
|
||||
#[derive(Clone)]
|
||||
pub struct Stream<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
compute_sender: Sender<ProcessorTask>,
|
||||
writer_sender: Sender<Receiver<TaskResponse>>,
|
||||
_p: PhantomData<B>,
|
||||
_n: PhantomData<P>,
|
||||
}
|
||||
|
||||
impl<B, P> Stream<B, P>
|
||||
where
|
||||
B: BackendIr,
|
||||
P: Protocol,
|
||||
{
|
||||
pub async fn new(
|
||||
runner: Runner<B>,
|
||||
writer_sender: Sender<Receiver<TaskResponse>>,
|
||||
data_service: Arc<TensorDataService<B, P>>,
|
||||
) -> Self {
|
||||
let sender = Processor::<B, P>::start(runner, data_service).await;
|
||||
|
||||
Self {
|
||||
compute_sender: sender,
|
||||
writer_sender,
|
||||
_p: PhantomData,
|
||||
_n: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn register_operation(&self, op: Box<OperationIr>) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::RegisterOperation(op))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn register_tensor(&self, tensor_id: TensorId, data: TensorData) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::RegisterTensor(tensor_id, data))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn register_tensor_remote(&self, tensor: TensorRemote, new_id: TensorId) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::RegisterTensorRemote(tensor, new_id))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn expose_tensor_remote(
|
||||
&self,
|
||||
tensor: TensorIr,
|
||||
count: u32,
|
||||
transfer_id: TensorTransferId,
|
||||
) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::ExposeTensorRemote {
|
||||
tensor,
|
||||
count,
|
||||
transfer_id,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn read_tensor(&self, id: ConnectionId, desc: TensorIr) {
|
||||
let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::ReadTensor(id, desc, callback_sender))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
self.writer_sender.send(callback_rec).await.unwrap();
|
||||
}
|
||||
|
||||
pub async fn sync(&self, id: ConnectionId) {
|
||||
let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::Sync(id, callback_sender))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
self.writer_sender.send(callback_rec).await.unwrap();
|
||||
}
|
||||
|
||||
pub async fn close(&self) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::Close)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn seed(&self, seed: u64) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::Seed(seed))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn supports_dtype(&self, id: ConnectionId, dtype: DType) {
|
||||
let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::SupportsDType(id, dtype, callback_sender))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
self.writer_sender.send(callback_rec).await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod task;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub(crate) use task::*;
|
||||
|
||||
/// We define the communication protocol here
|
||||
pub(crate) type RemoteProtocol = burn_communication::websocket::WebSocket;
|
||||
@@ -0,0 +1,87 @@
|
||||
use burn_backend::{ExecutionError, TensorData};
|
||||
use burn_communication::{Address, data_service::TensorTransferId};
|
||||
use burn_ir::{OperationIr, TensorId, TensorIr};
|
||||
use burn_std::{
|
||||
DType,
|
||||
id::{IdGenerator, StreamId},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
|
||||
pub struct ConnectionId {
|
||||
pub position: u64,
|
||||
pub stream_id: StreamId,
|
||||
}
|
||||
|
||||
/// Unique identifier that can represent a session.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct SessionId {
|
||||
id: u64,
|
||||
}
|
||||
|
||||
impl Display for SessionId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "SessionId({})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionId {
|
||||
/// Create a new [session id](SessionId).
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: IdGenerator::generate(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum Task {
|
||||
Compute(ComputeTask, ConnectionId),
|
||||
Init(SessionId),
|
||||
Close(SessionId),
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct TensorRemote {
|
||||
pub transfer_id: TensorTransferId,
|
||||
pub address: Address,
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum ComputeTask {
|
||||
Seed(u64),
|
||||
RegisterOperation(Box<OperationIr>),
|
||||
RegisterTensor(TensorId, TensorData),
|
||||
RegisterTensorRemote(TensorRemote, TensorId),
|
||||
ExposeTensorRemote {
|
||||
tensor: TensorIr,
|
||||
count: u32,
|
||||
transfer_id: TensorTransferId,
|
||||
},
|
||||
ReadTensor(TensorIr),
|
||||
SyncBackend,
|
||||
SupportsDType(DType),
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct TaskResponse {
|
||||
pub content: TaskResponseContent,
|
||||
pub id: ConnectionId,
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum TaskResponseContent {
|
||||
ReadTensor(Result<TensorData, ExecutionError>),
|
||||
SyncBackend(Result<(), ExecutionError>),
|
||||
// SupportsDType(DTypeUsageSet),
|
||||
// TODO: Update to `DTypeUsageSet` when it implements `serde`.
|
||||
SupportsDType(()),
|
||||
}
|
||||
Reference in New Issue
Block a user