feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,123 @@
|
||||
pub use super::RemoteDevice;
|
||||
use super::worker::{ClientRequest, ClientWorker};
|
||||
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent};
|
||||
use async_channel::{RecvError, SendError, Sender};
|
||||
use burn_communication::ProtocolClient;
|
||||
use burn_ir::TensorId;
|
||||
use burn_std::id::StreamId;
|
||||
use std::{
|
||||
future::Future,
|
||||
sync::{Arc, atomic::AtomicU64},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RemoteClient {
|
||||
pub(crate) device: RemoteDevice,
|
||||
pub(crate) sender: Arc<RemoteSender>,
|
||||
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
|
||||
}
|
||||
|
||||
impl RemoteClient {
|
||||
pub fn init<C: ProtocolClient>(device: RemoteDevice) -> Self {
|
||||
ClientWorker::<C>::start(device)
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
device: RemoteDevice,
|
||||
sender: Sender<ClientRequest>,
|
||||
runtime: Arc<tokio::runtime::Runtime>,
|
||||
session_id: SessionId,
|
||||
) -> Self {
|
||||
Self {
|
||||
device,
|
||||
runtime,
|
||||
sender: Arc::new(RemoteSender {
|
||||
sender,
|
||||
position_counter: AtomicU64::new(0),
|
||||
tensor_id_counter: AtomicU64::new(0),
|
||||
session_id,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RemoteSender {
|
||||
sender: Sender<ClientRequest>,
|
||||
position_counter: AtomicU64,
|
||||
tensor_id_counter: AtomicU64,
|
||||
session_id: SessionId,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug)]
|
||||
pub enum RemoteSendError {
|
||||
SendError(SendError<ClientRequest>),
|
||||
RecvError(RecvError),
|
||||
}
|
||||
|
||||
impl RemoteSender {
|
||||
/// Generate a new unique (for this [`RemoteSender`] [`TensorId`].
|
||||
pub(crate) fn new_tensor_id(&self) -> TensorId {
|
||||
TensorId::new(
|
||||
self.tensor_id_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
)
|
||||
}
|
||||
|
||||
/// Give the next operation sequence number.
|
||||
fn next_position(&self) -> u64 {
|
||||
self.position_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub(crate) fn send(&self, task: ComputeTask) {
|
||||
self.sender
|
||||
.send_blocking(ClientRequest::WithoutCallback(Task::Compute(
|
||||
task,
|
||||
ConnectionId::new(self.next_position(), StreamId::current()),
|
||||
)))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub(crate) fn send_async(
|
||||
&self,
|
||||
task: ComputeTask,
|
||||
) -> impl Future<Output = Result<TaskResponseContent, RemoteSendError>> + Send + use<> {
|
||||
let stream_id = StreamId::current();
|
||||
let position = self.next_position();
|
||||
let sender = self.sender.clone();
|
||||
|
||||
async move {
|
||||
let (tx, rx) = async_channel::bounded(1);
|
||||
|
||||
if let Err(e) = sender
|
||||
.send(ClientRequest::WithSyncCallback(
|
||||
Task::Compute(task, ConnectionId::new(position, stream_id)),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
{
|
||||
return Err(RemoteSendError::SendError(e));
|
||||
}
|
||||
|
||||
match rx.recv().await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(e) => Err(RemoteSendError::RecvError(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn close(&mut self) {
|
||||
let sender = self.sender.clone();
|
||||
|
||||
let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id));
|
||||
|
||||
sender.send_blocking(close_task).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RemoteSender {
|
||||
fn drop(&mut self) {
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_backend::Shape;
|
||||
use burn_communication::ProtocolClient;
|
||||
use burn_ir::TensorIr;
|
||||
use burn_router::{RouterTensor, RunnerChannel, get_client};
|
||||
|
||||
use super::{
|
||||
RemoteClient,
|
||||
runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle},
|
||||
};
|
||||
|
||||
/// A local channel with direct connection to the backend runner clients.
|
||||
pub struct RemoteChannel<C: ProtocolClient> {
|
||||
_p: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> RunnerChannel for RemoteChannel<C> {
|
||||
type Device = RemoteDevice;
|
||||
type Bridge = RemoteBridge<C>;
|
||||
type Client = RemoteClient;
|
||||
|
||||
type FloatElem = f32;
|
||||
|
||||
type IntElem = i32;
|
||||
|
||||
type BoolElem = u32;
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
format!("remote-{device:?}")
|
||||
}
|
||||
|
||||
fn init_client(device: &Self::Device) -> Self::Client {
|
||||
RemoteClient::init::<C>(device.clone())
|
||||
}
|
||||
|
||||
fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle<C> {
|
||||
RemoteTensorHandle {
|
||||
client: client.clone(),
|
||||
tensor: tensor.clone(),
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_tensor(
|
||||
_client: &Self::Client,
|
||||
_handle: RemoteTensorHandle<C>,
|
||||
_shape: Shape,
|
||||
_dtype: burn_backend::DType,
|
||||
) -> RouterTensor<Self::Client> {
|
||||
// This function is normally only used to move a tensor from a device to another.
|
||||
//
|
||||
// In other words, to change the client.
|
||||
panic!("Can't register manually a tensor on a remote channel.");
|
||||
}
|
||||
|
||||
fn change_client_backend(
|
||||
tensor: RouterTensor<Self::Client>,
|
||||
target_device: &Self::Device, // target device
|
||||
) -> RouterTensor<Self::Client> {
|
||||
// Get tensor handle from current client
|
||||
let original_client = tensor.client.clone();
|
||||
let desc = tensor.into_ir();
|
||||
let handle = Self::get_tensor_handle(&desc, &original_client);
|
||||
|
||||
let handle = handle.change_backend(target_device);
|
||||
|
||||
let id = handle.tensor.id;
|
||||
|
||||
let target_client = get_client::<Self>(target_device);
|
||||
let router_tensor: RouterTensor<RemoteClient> =
|
||||
RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client);
|
||||
|
||||
router_tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> Clone for RemoteChannel<C> {
|
||||
fn clone(&self) -> Self {
|
||||
RemoteChannel { _p: PhantomData }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mod base;
|
||||
mod channel;
|
||||
mod runner;
|
||||
mod worker;
|
||||
|
||||
pub use base::*;
|
||||
pub use channel::*;
|
||||
pub use runner::RemoteDevice;
|
||||
@@ -0,0 +1,294 @@
|
||||
use super::{RemoteChannel, RemoteClient};
|
||||
use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
|
||||
use burn_backend::{DeviceId, DeviceOps, ExecutionError, Shape, TensorData};
|
||||
use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
|
||||
use burn_ir::TensorIr;
|
||||
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
|
||||
use burn_std::{backtrace::BackTrace, future::DynFut};
|
||||
use std::sync::OnceLock;
|
||||
use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
|
||||
|
||||
// TODO: we should work with the parsed structure of Address, not the string.
|
||||
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();
|
||||
|
||||
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
|
||||
ADDRESS_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Map a string network address to a (local runtime) global unique u32.
|
||||
///
|
||||
/// Globally stable over the lifetime of the process, shared between threads,
|
||||
/// If the address has never been seen, a new id will be created.
|
||||
/// If the address has been seen, the previous id will be returned.
|
||||
pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
|
||||
let registry = get_address_registry();
|
||||
let mut registry = registry.lock().unwrap();
|
||||
let next_id = registry.len() as u32;
|
||||
*registry
|
||||
.entry(address.as_ref().to_string())
|
||||
.or_insert_with(|| next_id)
|
||||
}
|
||||
|
||||
/// Look up an address by id.
|
||||
///
|
||||
/// Returns the same address given ids by [`address_to_id`].
|
||||
pub fn id_to_address(id: u32) -> Option<String> {
|
||||
let registry = get_address_registry();
|
||||
let registry = registry.lock().unwrap();
|
||||
for entry in registry.iter() {
|
||||
if entry.1 == &id {
|
||||
return Some(entry.0.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// It is very important to block on any request made with the sender, since ordering is crucial
|
||||
// when registering operation or creating tensors.
|
||||
//
|
||||
// The overhead is minimal, since we only wait for the task to be sent to the async
|
||||
// channel, but not sent to the server and even less processed by the server.
|
||||
impl RunnerClient for RemoteClient {
|
||||
type Device = RemoteDevice;
|
||||
|
||||
fn register_op(&self, op: burn_ir::OperationIr) {
|
||||
self.sender
|
||||
.send(ComputeTask::RegisterOperation(Box::new(op)));
|
||||
}
|
||||
|
||||
fn read_tensor_async(
|
||||
&self,
|
||||
tensor: burn_ir::TensorIr,
|
||||
) -> DynFut<Result<TensorData, ExecutionError>> {
|
||||
// Important for ordering to call the creation of the future sync.
|
||||
let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
|
||||
|
||||
Box::pin(async move {
|
||||
match fut.await {
|
||||
Ok(response) => match response {
|
||||
TaskResponseContent::ReadTensor(res) => res,
|
||||
_ => panic!("Invalid message type"),
|
||||
},
|
||||
Err(e) => Err(ExecutionError::Generic {
|
||||
reason: format!("Failed to read tensor: {:?}", e),
|
||||
backtrace: BackTrace::capture(),
|
||||
}),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
|
||||
let id = self.sender.new_tensor_id();
|
||||
let shape = data.shape.clone();
|
||||
let dtype = data.dtype;
|
||||
|
||||
self.sender.send(ComputeTask::RegisterTensor(id, data));
|
||||
|
||||
RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
|
||||
}
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
self.device.clone()
|
||||
}
|
||||
|
||||
fn sync(&self) -> Result<(), ExecutionError> {
|
||||
// Important for ordering to call the creation of the future sync.
|
||||
let fut = self.sender.send_async(ComputeTask::SyncBackend);
|
||||
|
||||
match self.runtime.block_on(fut) {
|
||||
Ok(response) => match response {
|
||||
TaskResponseContent::SyncBackend(res) => res,
|
||||
_ => panic!("Invalid message type"),
|
||||
},
|
||||
Err(e) => Err(ExecutionError::Generic {
|
||||
reason: format!("Failed to sync: {:?}", e),
|
||||
backtrace: BackTrace::capture(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn seed(&self, seed: u64) {
|
||||
self.sender.send(ComputeTask::Seed(seed));
|
||||
}
|
||||
|
||||
fn create_empty_handle(&self) -> burn_ir::TensorId {
|
||||
self.sender.new_tensor_id()
|
||||
}
|
||||
|
||||
fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
|
||||
let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype));
|
||||
|
||||
match self.runtime.block_on(fut) {
|
||||
Ok(_response) => panic!("Invalid message type"),
|
||||
Err(e) => panic!("Failed to check dtype support: {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
/// The device contains the connection information of the server.
|
||||
pub struct RemoteDevice {
|
||||
pub(crate) address: Address,
|
||||
/// The id of the device in the local registry, see [`address_to_id`].
|
||||
pub(crate) id: u32,
|
||||
}
|
||||
|
||||
impl RemoteDevice {
|
||||
/// Create a device from an url.
|
||||
pub fn new(address: &str) -> Self {
|
||||
let id = address_to_id(address);
|
||||
Self {
|
||||
address: Address::from_str(address).unwrap(),
|
||||
id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RemoteDevice {
|
||||
fn default() -> Self {
|
||||
let address = match std::env::var("BURN_REMOTE_ADDRESS") {
|
||||
Ok(address) => address,
|
||||
Err(_) => String::from("ws://127.0.0.1:3000"),
|
||||
};
|
||||
|
||||
Self::new(&address)
|
||||
}
|
||||
}
|
||||
|
||||
impl burn_std::device::Device for RemoteDevice {
|
||||
fn from_id(device_id: DeviceId) -> Self {
|
||||
if device_id.type_id != 0 {
|
||||
panic!("Invalid device id: {device_id} (expected type 0)");
|
||||
}
|
||||
let address = id_to_address(device_id.index_id)
|
||||
.unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
|
||||
Self::new(&address)
|
||||
}
|
||||
|
||||
fn to_id(&self) -> DeviceId {
|
||||
DeviceId {
|
||||
type_id: 0,
|
||||
index_id: self.id,
|
||||
}
|
||||
}
|
||||
|
||||
fn device_count(_type_id: u16) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for RemoteDevice {}
|
||||
|
||||
pub struct RemoteBridge<C: ProtocolClient> {
|
||||
_p: PhantomData<C>,
|
||||
}
|
||||
|
||||
pub struct RemoteTensorHandle<C: ProtocolClient> {
|
||||
pub(crate) client: RemoteClient,
|
||||
pub(crate) tensor: TensorIr,
|
||||
pub(crate) _p: PhantomData<C>,
|
||||
}
|
||||
|
||||
static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
|
||||
|
||||
fn get_next_transfer_id() -> TensorTransferId {
|
||||
let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
|
||||
if transfer_counter.is_none() {
|
||||
*transfer_counter = Some(0.into());
|
||||
|
||||
transfer_counter.unwrap()
|
||||
} else {
|
||||
let mut transfer_counter = transfer_counter.unwrap();
|
||||
transfer_counter.next();
|
||||
|
||||
transfer_counter
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> RemoteTensorHandle<C> {
|
||||
/// Changes the backend of the tensor via a dWebSocket.
|
||||
/// We ask the original server to expose the tensor, then ask the target server to fetch
|
||||
/// the tensor. The target server will open a new network connection to the original server
|
||||
/// to download the data.
|
||||
/// This way the client never sees the tensor's data, and we avoid a bottleneck.
|
||||
pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
|
||||
let transfer_id = get_next_transfer_id();
|
||||
self.client.sender.send(ComputeTask::ExposeTensorRemote {
|
||||
tensor: self.tensor.clone(),
|
||||
count: 1,
|
||||
transfer_id,
|
||||
});
|
||||
|
||||
let target_client = get_client::<RemoteChannel<C>>(target_device);
|
||||
|
||||
let new_id = target_client.sender.new_tensor_id();
|
||||
|
||||
let remote_tensor = TensorRemote {
|
||||
transfer_id,
|
||||
address: self.client.device.address.clone(),
|
||||
};
|
||||
target_client
|
||||
.sender
|
||||
.send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
|
||||
|
||||
self.tensor.id = new_id;
|
||||
self.client = target_client;
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
|
||||
type TensorHandle = RemoteTensorHandle<C>;
|
||||
type Device = RemoteDevice;
|
||||
|
||||
fn change_backend_float(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_backend::Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor.change_backend(target_device)
|
||||
}
|
||||
|
||||
fn change_backend_int(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_backend::Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor.change_backend(target_device)
|
||||
}
|
||||
|
||||
fn change_backend_bool(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_backend::Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor.change_backend(target_device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_address_to_id() {
|
||||
let address1 = "ws://127.0.0.1:3000";
|
||||
let address2 = "ws://127.0.0.1:3001";
|
||||
|
||||
let id1 = address_to_id(address1);
|
||||
let id2 = address_to_id(address2);
|
||||
|
||||
assert_ne!(id1, id2);
|
||||
|
||||
assert_eq!(address_to_id(address1), id1);
|
||||
assert_eq!(id_to_address(id1), Some(address1.to_string()));
|
||||
|
||||
assert_eq!(address_to_id(address2), id2);
|
||||
assert_eq!(id_to_address(id2), Some(address2.to_string()));
|
||||
|
||||
let unused_id = u32::MAX;
|
||||
|
||||
assert_eq!(id_to_address(unused_id), None);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
use super::{RemoteClient, runner::RemoteDevice};
|
||||
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
|
||||
use burn_communication::{CommunicationChannel, Message, ProtocolClient};
|
||||
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
|
||||
|
||||
pub type CallbackSender = async_channel::Sender<TaskResponseContent>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ClientRequest {
|
||||
WithSyncCallback(Task, CallbackSender),
|
||||
WithoutCallback(Task),
|
||||
}
|
||||
|
||||
pub(crate) struct ClientWorker<C: ProtocolClient> {
|
||||
requests: HashMap<ConnectionId, CallbackSender>,
|
||||
_p: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> ClientWorker<C> {
|
||||
async fn on_response(&mut self, response: TaskResponse) {
|
||||
match self.requests.remove(&response.id) {
|
||||
Some(request) => {
|
||||
request.send(response.content).await.unwrap();
|
||||
}
|
||||
None => {
|
||||
panic!("Can't ignore message from the server.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
|
||||
self.requests.insert(id, callback);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> ClientWorker<C> {
|
||||
pub fn start(device: RemoteDevice) -> RemoteClient {
|
||||
let runtime = Arc::new(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_io()
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let (sender, rec) = async_channel::bounded(10);
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let address = device.address.clone();
|
||||
|
||||
#[allow(deprecated)]
|
||||
runtime.spawn(async move {
|
||||
log::info!("Connecting to {} ...", address.clone());
|
||||
let mut stream_request = C::connect(address.clone(), "request")
|
||||
.await
|
||||
.expect("Server to be accessible");
|
||||
let mut stream_response = C::connect(address, "response")
|
||||
.await
|
||||
.expect("Server to be accessible");
|
||||
|
||||
let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::<C>::default()));
|
||||
|
||||
// Init the connection.
|
||||
let bytes: bytes::Bytes = rmp_serde::to_vec(&Task::Init(session_id))
|
||||
.expect("Can serialize tasks to bytes.")
|
||||
.into();
|
||||
stream_request
|
||||
.send(Message::new(bytes.clone()))
|
||||
.await
|
||||
.expect("Can send the message over the comms channel.");
|
||||
stream_response
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Can send the message on the websocket.");
|
||||
|
||||
// Async worker loading callbacks from the server.
|
||||
let state_ws = state.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(msg) = stream_response.recv().await {
|
||||
let msg = match msg {
|
||||
Some(msg) => msg,
|
||||
None => {
|
||||
log::warn!("Closed connection");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let response: TaskResponse = rmp_serde::from_slice(&msg.data)
|
||||
.expect("Can deserialize messages from the websocket.");
|
||||
let mut state = state_ws.lock().await;
|
||||
state.on_response(response).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Channel async worker sending operations to the server.
|
||||
tokio::spawn(async move {
|
||||
while let Ok(req) = rec.recv().await {
|
||||
let task = match req {
|
||||
ClientRequest::WithSyncCallback(task, callback) => {
|
||||
if let Task::Compute(_content, id) = &task {
|
||||
let mut state = state.lock().await;
|
||||
state.register_callback(*id, callback);
|
||||
}
|
||||
task
|
||||
}
|
||||
ClientRequest::WithoutCallback(task) => task,
|
||||
};
|
||||
let bytes = rmp_serde::to_vec(&task)
|
||||
.expect("Can serialize tasks to bytes.")
|
||||
.into();
|
||||
stream_request
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Can send the message on the websocket.");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
RemoteClient::new(device, sender, runtime, session_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> Default for ClientWorker<C> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
requests: Default::default(),
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user