feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,73 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Backend extension for collective calculations."
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "collective"]
license.workspace = true
name = "burn-collective"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-collective"
documentation = "https://docs.rs/burn-collective"
version.workspace = true
[lints]
workspace = true
[features]
default = []
doc = []
tracing = [
"dep:tracing",
"burn-std/tracing",
"burn-tensor/tracing",
"burn-communication/tracing",
"burn-ndarray?/tracing",
"burn-wgpu?/tracing",
"burn-cuda?/tracing",
]
orchestrator = ["burn-communication/websocket"]
# Backends for testing
test-ndarray = ["burn-ndarray"]
test-wgpu = ["burn-wgpu", "burn-wgpu/webgpu"]
test-metal = ["burn-wgpu", "burn-wgpu/metal"]
test-vulkan = ["burn-wgpu", "burn-wgpu/vulkan"]
test-cuda = ["burn-cuda"]
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = true }
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true }
log = { workspace = true }
burn-communication = { path = "../burn-communication", version = "=0.21.0-pre.2", features = [
"data-service",
"websocket",
] }
tokio = { workspace = true, features = [
"rt-multi-thread",
"sync",
"signal",
"time",
"tracing",
] }
serde = { workspace = true, features = ["derive"] }
rmp-serde = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true, optional = true }
# Tests
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true }
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true }
[dev-dependencies]
serial_test = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1,139 @@
# burn-collective
Collective operations on tensors
The following collective operation are implemented:
- `all-reduce`
Aggregates a tensor between all peers, and distributes the result to all peers.
Different strategies can be used on the local and global levels. The result can only be
returned when all peers have called the all-reduce.
- `reduce`
Aggregates a tensor from all peers onto one peer, called the "root"
- `broadcast`
Copies a tensor from one peer to all other peers in the collective.
Peers must call `register` before calling any other operation.
The total number of devices on the node, or nodes in the collective, must be known ahead of time.
In many libraries like NCCL and PyTorch, participating units are called "ranks".
This name is confusing in the context of tensors, so in burn-collective the participating units
are called "peers".
*`reduce` and `broadcast` are not yet implemented for multi-node contexts*
## Local and Global
Internally, there are two levels to the collective operations: local and global. Operations are done on the local level, then optionally on the global level.
| Local | Global |
|-----------------------------------------------|-----------------------------------------------|
| Intra-node (typically within one machine) | Inter-node (typically across machies) |
| Participants are threads (one per peer/GPU) | Participants are processes (one per node) |
| Communication depends on backend | Network peer-to-peer communication |
| Local server is launched automatically | Global coordinator must be launched manually |
| Local server does the aggregation | Nodes do the operations themselves |
For global operations (ie. with multiple nodes), there must be a global orchestrator available.
Start one easily with `burn_collective::start_global_orchestrator()`.
On the global level, nodes use the `burn_communication::data_service::TensorDataService` to
expose and download tensors in a peer-to-peer manner, in order to be independent.
## Components
The following are the important pieces of the collective operations system.
| Term | One per... | Meaning
|--------------------------------|---------------|----------------------------------------------------------
| Local Collective Client | Peer/thread | Requests operations to the Local Collective Server
| Local Collective Server | Node/process | Does local-level ops for threads in this process. In the case of global operations, passes operations on to the Global Collective Client.
| Global Collective Client | Node/process | Does global-level ops for this node. Registers and requests strategies from the Global Collective Orchestrator.
| Global Collective Orchestrator | Collective | Responds to the Global Collective Client from each node. Responsible for aggregation strategies.
## Strategies
Different strategies can be used on the local and global level.
### Centralized
An arbitrary peer is designated as the "root", and all others are transferred to the root's device.
The operation is done on that device.
The resulting tensor then sent to each peer.
### Tree
Tensors in groups of N are aggregated together. This is done recursively until only one tensor
remains. The strategy tries to put devices of the same type closer in the tree.
When N=2, this is like a binary tree reduce.
The resulting tensor then sent to each peer
### Ring
See this good explanation: <https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model>
The tensors are sliced into N parts, where N is the number of tensors to aggregate.
Then, the slices are sent around in a series of cycles and aggregated until every tensor's slices
is a sum of the other corresponding slices.
In the case where the tensors are too small to split into N slices, a fallback algorithm is used.
For now, the fallback is a binary tree.
(p=3, n=3)
o->o o
o o->o
o o o->
o 1->o
o o 1->
1->o o
o 1 2->
2->o 1
1 2->o
3 1 2
2 3 1
1 2 3
(This is essentially a reduce-scatter)
3->x x
x 3->x
x x 3->
3 3->x
x 3 3->
3->x 3
3 3 3->
3->3 3
3 3->3
3 3 3
3 3 3
3 3 3
(This is essentially an all-gather)
This is done so that every peer is both sending and receiving data at any moment.
This is an important part of this strategy's advantages.
The ring strategy takes full advantage of the bandwidth available. The latency scales with the
number of peers.
So when the tensors are very small, or when the number of peers is very large, the latency is more
important in the ring strategy, and a tree algorithm is better. Otherwise, the ring algorithm is
the better.
In multi-node contexts, use of the Ring strategy in the local level may be less
advantageous. With multiple nodes, the global all-reduce step is enabled, and its result
is redistributed to all devices.
The Ring strategy inherently distributes the result, which in this context would not be necessary.
It is recommended to use the Ring strategy at the global level
### Double binary tree
<https://developer.nvidia.com/blog/massively-scale-deep-learning-training-nccl-2-4/>

View File

@@ -0,0 +1,34 @@
[package]
name = "burn-collective-multinode-tests"
version.workspace = true
edition.workspace = true
license.workspace = true
[features]
default = ["ndarray"]
ndarray = ["burn/ndarray"]
[dependencies]
burn = { path = "../../burn", default-features = false, features = ["std"] }
burn-std = { path = "../../burn-std", default-features = false }
burn-collective = { path = "..", features = ["orchestrator"] }
burn-communication = { path = "../../burn-communication" }
tokio = { workspace = true, features = ["rt-multi-thread", "process"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
interprocess = "2.3.1"
rmp-serde = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }
tokio-serde = { version = "0.9.0", features = ["messagepack"] }
futures = { workspace = true }
[[bin]]
name = "global"
path = "src/bin/global.rs"
[[bin]]
name = "node"
path = "src/bin/node.rs"

View File

@@ -0,0 +1,32 @@
# Integration test for burn collective operations with multiple nodes and devices.
Run `cargo run --bin test_launcher`
There are 3 binaries:
## node.rs
Launches `n` threads each simulating a different device. Currently the backend is NdArray,
so everything is CPU. The program takes a file with configurations and input data.
## global.rs
Runs the global orchestrator, who is responsible for responding to global collective operation
requests. In the case of an all-reduce, the orchestrator responds with a strategy for reducing,
and the node can do the reduction independently.
## test_launcher.rs
Generates input data, calculates the expected results, and launches the nodes each with their
own inputs in a separate file.
The topology is [4, 4, 4, 4]. This means 4 nodes are launched,
each with 4 threads (for each device).
The global orchestrator (`global.rs`) is also launched.
## Output
The outputs and inputs for each node and the orchestrator are written to the `target/test_files` folder
If the nodes or orchestrator stall, there is a timeout.

View File

@@ -0,0 +1,19 @@
//! Global orchestrator
//!
//! Launches the orchestrator that responds to global collective operations for nodes for the
//! integration test
//!
//! This is necessary for any node who needs global collective operations
use std::env;
#[tokio::main]
/// Start the global orchestrator on the port given as first arg
pub async fn main() {
let args: Vec<String> = env::args().collect();
let port = args[1].parse::<u16>().expect("invalid port");
// Launch the global orchestrator, which will listen and respond to global collective op
// requests from nodes
burn_collective::start_global_orchestrator(port).await;
}

View File

@@ -0,0 +1,157 @@
use burn::{
backend::NdArray,
prelude::Backend,
tensor::{Tensor, TensorPrimitive, Tolerance},
};
use burn_collective::{
CollectiveConfig, PeerId, ReduceOperation, all_reduce, finish_collective, register,
reset_collective,
};
use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};
use std::{
env,
sync::mpsc::SyncSender,
time::{Duration, Instant},
};
use tokio::net::TcpStream;
use futures::{SinkExt, StreamExt};
use std::thread::JoinHandle;
use tokio_serde::formats::MessagePack;
use tokio_util::codec::LengthDelimitedCodec;
type TestBackend = NdArray;
/// Framed TCP connection channel
type TestChannel = tokio_serde::Framed<
tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,
NodeTest,
NodeTestResult,
MessagePack<NodeTest, NodeTestResult>,
>;
/// Start a node that will test all-reduce
/// Args are the following:
/// - launcher endpoint
#[tokio::main]
pub async fn main() {
let args: Vec<String> = env::args().collect();
let launcher_addr = args[1].clone();
let socket = TcpStream::connect(launcher_addr).await.unwrap();
let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());
let mut socket: TestChannel = tokio_serde::Framed::new(
length_delimited,
MessagePack::<NodeTest, NodeTestResult>::default(),
);
// Loop: receive, do test, send result
while let Some(Ok(test)) = socket.next().await {
println!("Received test: {test:?}");
let result = run_test::<NdArray>(&test);
// send the result back
socket.send(result).await.expect("failed to send Result");
}
println!("Server closed connection; exiting.");
}
/// Runs a test for one node
fn run_test<B: Backend>(test_input: &NodeTest) -> NodeTestResult {
reset_collective::<TestBackend>();
// Channel for results
let (result_send, result_recv) = std::sync::mpsc::sync_channel(32);
// Launch a thread for each "device"
let handles = launch_threads::<B>(test_input.clone(), result_send);
// Receive results
let mut durations = vec![];
let tol: Tolerance<f32> = Tolerance::balanced();
for _ in 0..test_input.device_count {
// Assert all results are equal to each other as well as expected result
let (tensor, duration) = result_recv.recv().unwrap();
test_input.expected.assert_approx_eq(&tensor.to_data(), tol);
durations.push(duration);
}
// Threads finish
for handle in handles {
let _ = handle.join();
}
NodeTestResult {
success: true,
durations,
}
}
/// Launch a thread for each device, and run the all-reduce
fn launch_threads<B: Backend>(
test_input: NodeTest,
result_send: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,
) -> Vec<JoinHandle<()>> {
let mut handles = vec![];
for id in 0..test_input.device_count {
// Launch a thread to test
// Put all the parameters in the config
let config = CollectiveConfig::default()
.with_num_devices(test_input.device_count)
.with_global_address(test_input.global_address.clone())
.with_node_address(test_input.node_address.clone())
.with_data_service_port(test_input.data_service_port)
.with_num_nodes(test_input.node_count)
.with_global_all_reduce_strategy(test_input.global_strategy)
.with_local_all_reduce_strategy(test_input.local_strategy);
// Inputs and outputs for the test
let tensor_data = test_input.inputs[id].clone();
let tensor = Tensor::<B, TENSOR_RANK>::from_data(tensor_data, &B::Device::default());
let result_send = result_send.clone();
let handle = std::thread::spawn(move || {
run_peer::<B>(
id.into(),
config,
tensor,
result_send,
test_input.all_reduce_op,
)
});
handles.push(handle);
}
handles
}
/// Runs a thread in the all-reduce test.
pub fn run_peer<B: Backend>(
id: PeerId,
config: CollectiveConfig,
input: Tensor<B, TENSOR_RANK>,
output: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,
all_reduce_op: ReduceOperation,
) {
// Register the device
register::<B>(id, input.device(), config).unwrap();
let start = Instant::now();
// All-reduce
let input = input.into_primitive().tensor();
let tensor = all_reduce::<B>(id, input, all_reduce_op).unwrap();
let tensor = Tensor::<B, TENSOR_RANK>::from_primitive(TensorPrimitive::Float(tensor));
let duration = start.elapsed();
// Send result
output.send((tensor, duration)).unwrap();
finish_collective::<B>(id).unwrap();
}

View File

@@ -0,0 +1,354 @@
use burn::tensor::TensorData;
use burn_communication::Address;
use futures::{SinkExt, StreamExt};
use std::{
fmt::Display,
fs::{self, File},
str::FromStr,
time::{Duration, Instant},
vec,
};
use tokio::net::TcpListener;
use tokio_serde::formats::MessagePack;
use tokio_util::codec::LengthDelimitedCodec;
use burn::{backend::NdArray, prelude::Backend, tensor::Tensor};
use burn_collective::{AllReduceStrategy, ReduceOperation};
use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};
use burn_std::rand::{SeedableRng, StdRng};
use tokio::process::{Child, Command};
#[derive(Clone)]
struct AllReduceTest {
shape: [usize; TENSOR_RANK],
op: ReduceOperation,
local_strategy: AllReduceStrategy,
global_strategy: AllReduceStrategy,
}
impl Display for AllReduceTest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let op_str = match self.op {
ReduceOperation::Sum => "sum",
ReduceOperation::Mean => "mean",
};
let local_strategy_str = match self.local_strategy {
AllReduceStrategy::Centralized => "local_centralized",
AllReduceStrategy::Tree(n) => &format!("local_tree_{n}"),
AllReduceStrategy::Ring => "local_ring",
};
let global_strategy_str = match self.global_strategy {
AllReduceStrategy::Centralized => "global_centralized",
AllReduceStrategy::Tree(n) => &format!("global_tree_{n}"),
AllReduceStrategy::Ring => "global_ring",
};
write!(f, "{op_str}_{local_strategy_str}_{global_strategy_str}")
}
}
/// Framed TCP connection for sending tests and receiving results
type TestChannel = tokio_serde::Framed<
tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,
NodeTestResult,
NodeTest,
MessagePack<NodeTestResult, NodeTest>,
>;
/// Handle for each node process
struct NodeProcessHandle {
process: Child,
channel: TestChannel,
}
/// Main function to run the multi-node all-reduce test.
/// Launches a orchestrator and multiple nodes based on the provided topology.
#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main() {
let all_reduce_tests = vec![
AllReduceTest {
shape: [4, 64, 512],
op: ReduceOperation::Mean,
local_strategy: AllReduceStrategy::Tree(2),
global_strategy: AllReduceStrategy::Tree(2),
},
AllReduceTest {
shape: [4, 64, 512],
op: ReduceOperation::Mean,
local_strategy: AllReduceStrategy::Tree(2),
global_strategy: AllReduceStrategy::Ring,
},
AllReduceTest {
shape: [4, 64, 512],
op: ReduceOperation::Mean,
local_strategy: AllReduceStrategy::Centralized,
global_strategy: AllReduceStrategy::Centralized,
},
];
let test_files_dir = "target/test_files";
fs::create_dir_all(test_files_dir).expect("Couldn't create test_files directory");
let topology: Vec<usize> = vec![4; 4];
let mut orchestrator = launch_orchestrator(test_files_dir);
let launcher_endpoint = "127.0.0.1:4000";
// Build and run node processes
let mut all_tests_durations = vec![];
if let Ok(mut nodes) = launch_nodes(&topology, launcher_endpoint).await {
// Run one test
for test in all_reduce_tests.clone() {
let test_name = test.to_string();
let time =
test_all_reduce_centralized_no_collective::<NdArray>(&topology, test.clone());
println!(
"{test_name}: Benchmark (no collective, centralized, single-threaded): {} secs",
time.as_secs_f32()
);
match test_all_reduce(&topology, test, &mut nodes).await {
Err(node_idx) => {
println!("{test_name}: Node with index {node_idx} failed!");
// Kill other node processes
for mut node in nodes.drain(..) {
node.process.kill().await.unwrap();
node.process.wait().await.unwrap();
}
break;
}
Ok(durations) => {
all_tests_durations.append(&mut durations.clone());
let avg = durations.iter().map(|dur| dur.as_secs_f32()).sum::<f32>()
/ durations.len() as f32;
println!("{test_name}: Success in {avg} secs");
}
}
}
}
if !all_tests_durations.is_empty() {
let avg = all_tests_durations
.iter()
.map(|dur| dur.as_secs_f32())
.sum::<f32>()
/ all_tests_durations.len() as f32;
println!("Average for all tests: {avg} secs");
}
// Shutdown orchestrator
orchestrator.kill().await.unwrap();
orchestrator.wait().await.unwrap();
}
/// Launch a global orchestrator with an output file in the given directory.
/// Necessary for global collective operations
///
/// Server listens on localhost port 3000
fn launch_orchestrator(test_files_dir: &str) -> Child {
let out_path = format!("{test_files_dir}/orchestrator_out.txt");
let out = File::create(out_path).expect("Could't create orchestrator output file");
Command::new("cargo")
.args(["run", "--bin", "global", "--", "3000"])
.stdout(out.try_clone().unwrap())
.stderr(out)
.spawn()
.expect("failed to launch orchestrator")
}
/// Launch nodes for a all_reduce test
/// Each node will connect to the global orchestrator and run an all-reduce operation.
/// The topology is a vector where each element represents the number of devices in that node.
async fn launch_nodes(
topology: &[usize],
launcher_endpoint: &str,
) -> Result<Vec<NodeProcessHandle>, ()> {
println!(
"Launching {} nodes with topology: {:?}",
topology.len(),
topology
);
// Listen for node connections
let listener = TcpListener::bind(launcher_endpoint).await.unwrap();
println!("Server listening on {launcher_endpoint}");
let mut nodes = vec![];
for node_idx in 0..topology.len() {
// Create log file
let output_filename = format!("target/test_files/node_{}_log.txt", node_idx + 1);
let out = File::create(output_filename).expect("Could't open node log file");
// Start a process for each node. Pass on our feature flags
let node_process: Child = Command::new("cargo")
.args([
"run",
"--release",
"--features",
#[cfg(feature = "ndarray")]
"ndarray",
"--bin",
"node",
"--",
launcher_endpoint,
&node_idx.to_string(),
])
.stdout(out.try_clone().unwrap())
.stderr(out)
.spawn()
.expect("node failed");
// Wait for child to connect for io
let (socket, _peer_addr) = listener.accept().await.unwrap();
let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());
let channel: TestChannel = tokio_serde::Framed::new(
length_delimited,
MessagePack::<NodeTestResult, NodeTest>::default(),
);
nodes.push(NodeProcessHandle {
process: node_process,
channel,
});
}
Ok(nodes)
}
async fn test_all_reduce(
topology: &[usize],
test: AllReduceTest,
nodes: &mut [NodeProcessHandle],
) -> Result<Vec<Duration>, usize> {
dispatch_all_reduce_test(topology, test, nodes).await;
let mut all_durations = vec![];
for (idx, handle) in nodes.iter_mut().enumerate() {
match handle.channel.next().await {
Some(Ok(mut result)) => {
if !result.success {
return Err(idx);
}
all_durations.append(&mut result.durations);
}
_ => {
return Err(idx);
}
}
}
Ok(all_durations)
}
async fn dispatch_all_reduce_test(
topology: &[usize],
test: AllReduceTest,
nodes: &mut [NodeProcessHandle],
) {
let total_device_count: usize = topology.iter().sum();
let (mut all_inputs, expected) =
generate_random_input(test.shape, test.op, total_device_count, 42);
// URL for the global orchestrator on port 3000
let global_url = "ws://localhost:3000";
let global_address = Address::from_str(global_url).unwrap();
for (node_idx, &device_count) in topology.iter().enumerate() {
// Construct URL for node
// Ports 3001... are for each node
let data_service_port = node_idx as u16 + 3001;
let node_url = format!("ws://localhost:{data_service_port}");
let node_address = Address::from_str(&node_url).unwrap();
// take input tensors for each device
let inputs = all_inputs[0..device_count].to_vec();
all_inputs = all_inputs[device_count..].to_vec();
let test = NodeTest {
device_count,
node_id: node_idx.into(),
node_count: topology.len() as u32,
global_address: global_address.clone(),
node_address,
data_service_port,
all_reduce_op: test.op,
global_strategy: test.global_strategy,
local_strategy: test.local_strategy,
inputs,
expected: expected.clone(),
};
let handle = &mut nodes[node_idx];
handle.channel.send(test).await.unwrap();
}
assert!(
all_inputs.is_empty(),
"Not all inputs have been sent to tests"
);
}
/// Run the test sequentially with no collective operations to get the optimal single-threaded speed
fn test_all_reduce_centralized_no_collective<B: Backend>(
topology: &[usize],
test: AllReduceTest,
) -> Duration {
let total_device_count: usize = topology.iter().sum();
let (all_inputs, _expected) =
generate_random_input(test.shape, test.op, total_device_count, 42);
let mut all_inputs = all_inputs
.into_iter()
.map(|data| Tensor::<B, 3>::from_data(data, &B::Device::default()));
let start = Instant::now();
// Sequential test
let mut result = all_inputs.next().unwrap();
for other in all_inputs {
result = result.add(other);
}
if test.op == ReduceOperation::Mean {
result.div_scalar(total_device_count as u32);
}
start.elapsed()
}
/// Generates random input tensors and expected output based on the provided shape and reduce kind.
fn generate_random_input(
shape: [usize; 3],
reduce_kind: ReduceOperation,
input_count: usize,
seed: u64,
) -> (Vec<TensorData>, TensorData) {
let mut rng = StdRng::seed_from_u64(seed);
// A random tensor for each device
let input: Vec<TensorData> = (0..input_count)
.map(|_| {
TensorData::random::<f32, _, _>(shape, burn::tensor::Distribution::Default, &mut rng)
})
.collect();
// Sum up the inputs
let device = <NdArray as Backend>::Device::default();
let mut expected_tensor = Tensor::<NdArray, TENSOR_RANK>::zeros(shape, &device);
for item in input.iter().take(input_count) {
let input_tensor = Tensor::<NdArray, TENSOR_RANK>::from_data(item.clone(), &device);
expected_tensor = expected_tensor.add(input_tensor);
}
if reduce_kind == ReduceOperation::Mean {
expected_tensor = expected_tensor.div_scalar(input_count as u32);
}
// All-Reduce results should have this value
let expected = expected_tensor.to_data();
(input, expected)
}

View File

@@ -0,0 +1 @@
pub mod shared;

View File

@@ -0,0 +1,43 @@
use std::time::Duration;
use burn::tensor::TensorData;
use burn_collective::{AllReduceStrategy, NodeId, ReduceOperation};
use burn_communication::Address;
use serde::{Deserialize, Serialize};
/// Ranks of inputs and outputs for all testing
pub const TENSOR_RANK: usize = 3;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTest {
/// How many threads to start on this node
pub device_count: usize,
/// ID for this node
pub node_id: NodeId,
/// How many nodes in the cluster
pub node_count: u32,
/// Global server address
pub global_address: Address,
/// Node address
pub node_address: Address,
/// Node's data service port, for initializing the p2p tensor data service
pub data_service_port: u16,
/// What kind of all-reduce
pub all_reduce_op: ReduceOperation,
/// Node's data service port, for initializing the p2p tensor data service
pub global_strategy: AllReduceStrategy,
/// What kind of aggregation
pub local_strategy: AllReduceStrategy,
/// Input data for test: all tensors are D=3
pub inputs: Vec<TensorData>,
/// Expected output for test
pub expected: TensorData,
}
/// Result sent back from each node for each test
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTestResult {
pub success: bool,
pub durations: Vec<Duration>,
}

View File

@@ -0,0 +1,121 @@
use burn_tensor::backend::Backend;
use crate::{
CollectiveConfig, PeerId, ReduceOperation, global::shared::GlobalCollectiveError,
local::server::get_collective_client,
};
/// Errors from collective operations
#[allow(unused)]
#[derive(Debug, Clone)]
pub enum CollectiveError {
/// The [config](CollectiveConfig) was invalid.
/// Usually happens if only some global parameters have been defined
InvalidConfig,
/// Cannot un-register a node twice
MultipleUnregister,
/// Cannot register a node twice
MultipleRegister,
/// Trying to register a different way than is currently being done
RegisterParamsMismatch,
/// Trying to all-reduce tensors of different shapes: shape must match
AllReduceShapeMismatch,
/// Trying to all-reduce a different way than is currently being done: op must match
AllReduceOperationMismatch,
/// Trying to reduce tensors of different shapes: shape must match
ReduceShapeMismatch,
/// Trying to reduce a different way than is currently being done: op must match
ReduceOperationMismatch,
/// Trying to reduce with different roots
ReduceRootMismatch,
/// Trying to broadcast with different roots
BroadcastRootMismatch,
/// Trying to broadcast but no peer sent a tensor
BroadcastNoTensor,
/// Trying to broadcast but multiple peers sent a tensor
BroadcastMultipleTensors,
/// Local collective server couldn't respond
LocalServerMissing,
/// Another operation was called before Register
RegisterNotFirstOperation,
/// The global orchestrator had an error
Global(GlobalCollectiveError),
#[allow(unused)]
Other(String),
}
/// Registers a device. `num_devices` must be the same for every register,
/// and `device_id` must be unique.
///
/// * `id` - The peer id of the caller
///
/// With auto-diff backends, make sure to use the inner backend.
pub fn register<B: Backend>(
id: PeerId,
device: B::Device,
config: CollectiveConfig,
) -> Result<(), CollectiveError> {
log::info!("Registering peer {id} with config: {config}");
let mut client = get_collective_client::<B>();
client.register(id, device, config)
}
/// Calls for an all-reduce operation with the given parameters, and returns the result.
/// The `params` must be the same as the parameters passed by the other nodes.
///
/// * `id` - The peer id of the caller
/// * `tensor` - The input tensor to reduce with the peers' tensors
/// * `config` - Config of the collective operation, must be coherent with the other calls
pub fn all_reduce<B: Backend>(
id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
) -> Result<B::FloatTensorPrimitive, CollectiveError> {
let client = get_collective_client::<B>();
client.all_reduce(id, tensor, op)
}
/// Broadcasts, or receives a broadcasted tensor.
///
/// * `id` - The peer id of the caller
/// * `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive
/// the broadcasted tensor.
///
/// Returns the broadcasted tensor.
pub fn broadcast<B: Backend>(
id: PeerId,
tensor: Option<B::FloatTensorPrimitive>,
) -> Result<B::FloatTensorPrimitive, CollectiveError> {
let client = get_collective_client::<B>();
client.broadcast(id, tensor)
}
/// Reduces a tensor onto one device.
///
/// * `id` - The peer id of the caller
/// * `tensor` - The tensor to send as input
/// * `root` - The ID of the peer that will receive the result.
///
/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.
pub fn reduce<B: Backend>(
id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
root: PeerId,
) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {
let client = get_collective_client::<B>();
client.reduce(id, tensor, op, root)
}
/// Closes the collective session, unregistering the device
pub fn finish_collective<B: Backend>(id: PeerId) -> Result<(), CollectiveError> {
let client = get_collective_client::<B>();
client.finish(id)
}
/// Resets the local collective server. All registered callers and ongoing operations are forgotten
pub fn reset_collective<B: Backend>() {
let client = get_collective_client::<B>();
client.reset();
}

View File

@@ -0,0 +1,337 @@
use std::fmt::Display;
use burn_communication::Address;
use serde::{Deserialize, Serialize};
/// Parameter struct for setting up and getting parameters for collective operations.
/// Used in most collective api calls.
/// This config is per-node. It is passed to [reduce](crate::register).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CollectiveConfig {
pub(crate) num_devices: usize,
pub(crate) local_all_reduce_strategy: AllReduceStrategy,
pub(crate) local_reduce_strategy: ReduceStrategy,
pub(crate) local_broadcast_strategy: BroadcastStrategy,
// Global parameters (all are optional, but if one is defined they should all be)
pub(crate) num_nodes: Option<u32>,
pub(crate) global_address: Option<Address>,
pub(crate) node_address: Option<Address>,
pub(crate) data_service_port: Option<u16>,
// These strategies may be defined when no other global params are defined
pub(crate) global_all_reduce_strategy: Option<AllReduceStrategy>,
pub(crate) global_reduce_strategy: Option<ReduceStrategy>,
pub(crate) global_broadcast_strategy: Option<BroadcastStrategy>,
}
impl Default for CollectiveConfig {
fn default() -> Self {
Self::new()
}
}
impl Display for CollectiveConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let num_devices = self.num_devices;
let local_all_reduce_strategy = self.local_all_reduce_strategy;
let local_reduce_strategy = self.local_reduce_strategy;
let local_broadcast_strategy = self.local_broadcast_strategy;
let num_nodes = self.num_nodes;
let global_address = &self.global_address;
let node_address = &self.node_address;
let data_service_port = self.data_service_port;
let global_all_reduce_strategy = self.global_all_reduce_strategy;
let global_reduce_strategy = self.global_reduce_strategy;
let global_broadcast_strategy = self.global_broadcast_strategy;
write!(
f,
r#"
CollectiveConfig {{
num_devices: {num_devices:?},
local_all_reduce_strategy: {local_all_reduce_strategy:?},
local_reduce_strategy: {local_reduce_strategy:?},
local_broadcast_strategy: {local_broadcast_strategy:?},
num_nodes: {num_nodes:?},
global_address: {global_address:?},
node_address: {node_address:?},
data_service_port: {data_service_port:?},
global_all_reduce_strategy: {global_all_reduce_strategy:?},
global_reduce_strategy: {global_reduce_strategy:?},
global_broadcast_strategy: {global_broadcast_strategy:?},
}}
"#
)
}
}
impl CollectiveConfig {
fn new() -> Self {
Self {
num_devices: 1,
local_all_reduce_strategy: AllReduceStrategy::Tree(2),
local_reduce_strategy: ReduceStrategy::Tree(2),
local_broadcast_strategy: BroadcastStrategy::Tree(2),
num_nodes: None,
global_address: None,
node_address: None,
data_service_port: None,
global_all_reduce_strategy: Some(AllReduceStrategy::Ring),
global_reduce_strategy: Some(ReduceStrategy::Tree(2)),
global_broadcast_strategy: Some(BroadcastStrategy::Tree(2)),
}
}
/// Selects the number of devices (local peers) on the current node
pub fn with_num_devices(mut self, num: usize) -> Self {
self.num_devices = num;
self
}
/// Selects an all-reduce strategy to use on the local level.
///
/// In multi-node contexts, use of the Ring strategy in the local level may be less
/// advantageous. With multiple nodes, the global all-reduce step is enabled, and its result
/// is redistributed to all devices.
/// The Ring strategy inherently distributes the result, which in this context would not be
/// necessary.
///
/// It is recommended to use a tree strategy locally, and a ring strategy globally.
pub fn with_local_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
self.local_all_reduce_strategy = strategy;
self
}
/// Selects a reduce strategy to use on the local level.
pub fn with_local_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
self.local_reduce_strategy = strategy;
self
}
/// Selects a broadcast strategy to use on the local level.
pub fn with_local_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
self.local_broadcast_strategy = strategy;
self
}
/// Set the number of nodes in the collective
///
/// This parameter is a global parameter and should only be set in multi-node contexts
pub fn with_num_nodes(mut self, n: u32) -> Self {
self.num_nodes = Some(n);
self
}
/// Set the network address of the Global Collective Orchestrator
///
/// This parameter is a global parameter and should only be set in multi-node contexts
pub fn with_global_address(mut self, addr: Address) -> Self {
self.global_address = Some(addr);
self
}
/// Define the address for this node
///
/// This parameter is a global parameter and should only be set in multi-node contexts
pub fn with_node_address(mut self, addr: Address) -> Self {
self.node_address = Some(addr);
self
}
/// Selects the network port on which to expose the tensor data service
/// used for peer-to-peer tensor downloading.
///
/// This parameter is a global parameter and should only be set in multi-node contexts
pub fn with_data_service_port(mut self, port: u16) -> Self {
self.data_service_port = Some(port);
self
}
/// Selects an all-reduce strategy to use on the global level.
///
/// This parameter is a global parameter and should only be set in multi-node contexts.
/// See [the local strategy](Self::with_local_all_reduce_strategy)
pub fn with_global_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
self.global_all_reduce_strategy = Some(strategy);
self
}
/// Selects an reduce strategy to use on the global level.
///
/// This parameter is a global parameter and should only be set in multi-node contexts.
/// See [the local strategy](Self::with_local_reduce_strategy)
pub fn with_global_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
self.global_reduce_strategy = Some(strategy);
self
}
/// Selects an broadcst strategy to use on the global level.
///
/// This parameter is a global parameter and should only be set in multi-node contexts.
/// See [the local strategy](Self::with_local_broadcast_strategy)
pub fn with_global_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
self.global_broadcast_strategy = Some(strategy);
self
}
/// Returns whether the config is valid. If only some required global-level parameters are
/// defined and others are not, the config is invalid.
pub fn is_valid(&self) -> bool {
match (
self.num_nodes,
&self.global_address,
&self.node_address,
self.data_service_port,
) {
(None, None, None, None) => true,
(Some(_), Some(_), Some(_), Some(_)) => true,
// Global parameters have only been partially defined!
_ => false,
}
}
/// Return the global parameters for registering in a multi-node context.
///
/// If only some global parameters are defined, returns None. Use [is_valid](Self::is_valid) to check for
/// validity in this case.
pub(crate) fn global_register_params(&self) -> Option<GlobalRegisterParams> {
match (
self.num_nodes,
&self.global_address,
&self.node_address,
self.data_service_port,
) {
// Only local collective
(None, None, None, None) => None,
// Local + global collective
(Some(num_nodes), Some(global_addr), Some(node_addr), Some(data_service_port)) => {
Some(GlobalRegisterParams {
num_nodes,
global_address: global_addr.clone(),
node_address: node_addr.clone(),
data_service_port,
})
}
// Config is invalid!
_ => None,
}
}
}
/// Helper struct for parameters in a multi-node register operation. Either they are all defined,
/// or all not defined. Passed to the global client for registering on the global level and
/// opening the p2p tensor service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalRegisterParams {
/// The address for the connection to the global orchestrator.
pub global_address: Address,
/// The address for the connection to this node.
pub node_address: Address,
/// The port on which to open the tensor data service for peer-to-peer tensor transfers with
/// other nodes. Should match the port given in the node url.
pub data_service_port: u16,
/// The number of nodes globally. Should be the same between different nodes
pub num_nodes: u32,
}
/// Parameters for an all-reduce that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedAllReduceParams {
pub op: ReduceOperation,
pub local_strategy: AllReduceStrategy,
pub global_strategy: Option<AllReduceStrategy>,
}
/// Parameters for a reduce that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedReduceParams {}
/// Parameters for a broadcast that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedBroadcastParams {
pub op: ReduceOperation,
pub local_strategy: BroadcastStrategy,
pub global_strategy: Option<BroadcastStrategy>,
}
/// Reduce can be done different ways
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceOperation {
Sum,
Mean,
}
/// All reduce can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AllReduceStrategy {
/// One device is the "central". The other devices, "peripherals", send their tensors to the
/// central. The central does the reduction, and sends the result back to each peripheral.
Centralized,
/// Devices are organized in a tree structure (with a given arity). Each node reduces its
/// children's tensors with its own, and sends the result to its parent. Leaf nodes will
/// simply send their tensors to their parents.
/// When the root node calculates the result, it is propagated down the tree.
Tree(u32),
/// Devices are organized in a ring. The tensors are split into N slices, where N is the
/// number of devices participating. The slices are progressively sent around the ring until
/// every device has one fully reduced slice of the tensor. Then, the resulting slices are sent
/// around until every device has the full result.
/// See `ring.rs` for details.
Ring,
}
/// Reduce can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceStrategy {
/// See [all-reduce](AllReduceStrategy::Centralized)
Centralized,
/// See [all-reduce](AllReduceStrategy::Tree)
Tree(u32),
}
/// Broadcast can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum BroadcastStrategy {
/// See [all-reduce](AllReduceStrategy::Centralized)
Centralized,
/// See [all-reduce](AllReduceStrategy::Tree)
Tree(u32),
}
/// A unique identifier for a peer in the context of collective operations.
/// They must be unique, even in multi-node contexts.
///
/// This is like the rank in NCCL
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId(u32);
impl Display for PeerId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "PeerId({})", self.0)
}
}
impl From<u32> for PeerId {
fn from(value: u32) -> Self {
Self(value)
}
}
impl From<i32> for PeerId {
fn from(value: i32) -> Self {
Self(value as u32)
}
}
impl From<usize> for PeerId {
fn from(value: usize) -> Self {
Self(value as u32)
}
}

View File

@@ -0,0 +1,23 @@
use serde::{Deserialize, Serialize};
/// Unique identifier for any node in the global collective.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct NodeId(u32);
impl From<u32> for NodeId {
fn from(value: u32) -> Self {
Self(value)
}
}
impl From<usize> for NodeId {
fn from(value: usize) -> Self {
Self(value as u32)
}
}
impl From<i32> for NodeId {
fn from(value: i32) -> Self {
Self(value as u32)
}
}

View File

@@ -0,0 +1,10 @@
pub(crate) mod node;
pub(crate) mod shared;
#[cfg(feature = "orchestrator")]
pub mod orchestrator;
#[cfg(feature = "orchestrator")]
pub use orchestrator::*;
mod base;
pub use base::*;

View File

@@ -0,0 +1,203 @@
use burn_communication::Protocol;
use burn_communication::data_service::TensorDataServer;
use burn_communication::{Address, ProtocolServer, data_service::TensorDataService};
use burn_tensor::backend::Backend;
use std::collections::HashMap;
use std::{marker::PhantomData, sync::Arc};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::node::sync::SyncService;
use crate::{
AllReduceStrategy, BroadcastStrategy, GlobalRegisterParams, NodeId, PeerId, ReduceStrategy,
};
use crate::{
ReduceOperation,
global::{
node::{
centralized::centralized_all_reduce_sum, ring::ring_all_reduce_sum,
tree::tree_all_reduce_sum, worker::GlobalClientWorker,
},
shared::{GlobalCollectiveError, RemoteRequest, RemoteResponse},
},
local::server::get_collective_server_runtime,
};
/// Must be synchronized between all nodes for collective operations to work
pub(crate) struct NodeState {
pub node_id: NodeId,
pub nodes: HashMap<NodeId, Address>,
pub num_global_devices: u32,
}
/// A node talks to the global orchestrator as well as other nodes with a peer-to-peer service
pub(crate) struct Node<B, P>
where
B: Backend,
P: Protocol,
{
// State is written during `register` and read during other operations,
// sometimes by multiple threads (ex. syncing during an all-reduce)
state: Arc<RwLock<Option<NodeState>>>,
data_service: Arc<TensorDataService<B, P>>,
sync_service: Arc<SyncService<P>>,
worker: GlobalClientWorker<P::Client>,
_n: PhantomData<P>,
}
impl<B, P> Node<B, P>
where
B: Backend,
P: Protocol,
{
pub fn new(global_address: &Address, comms_server: P::Server) -> Self {
let state = Arc::new(tokio::sync::RwLock::new(None));
let cancel_token = CancellationToken::new();
let data_service = Arc::new(TensorDataService::new(cancel_token.clone()));
let sync_service = Arc::new(SyncService::new(state.clone()));
let runtime = get_collective_server_runtime();
let server = comms_server
.route_tensor_data_service(data_service.clone())
.route("/sync", {
let sync_service = sync_service.clone();
async move |channel: <P::Server as ProtocolServer>::Channel| {
sync_service.handle_sync_connection(channel).await;
}
})
.serve({
let cancel_token = cancel_token.clone();
async move { cancel_token.cancelled().await }
});
runtime.spawn(server);
let worker = GlobalClientWorker::new(&runtime, cancel_token.clone(), global_address);
Self {
state,
data_service,
sync_service,
worker,
_n: PhantomData,
}
}
pub async fn register(
&mut self,
peers: Vec<PeerId>,
global_params: GlobalRegisterParams,
) -> Result<(), GlobalCollectiveError> {
let req = RemoteRequest::Register {
node_addr: global_params.node_address,
num_nodes: global_params.num_nodes,
peers,
};
match self.worker.request(req).await {
RemoteResponse::Register {
node_id,
nodes,
num_global_devices,
} => {
let mut state = self.state.write().await;
*state = Some(NodeState {
node_id,
nodes,
num_global_devices,
});
}
RemoteResponse::Error(err) => {
return Err(err);
}
resp => {
log::error!("Response to a register request should be an ack, not {resp:?}");
return Err(GlobalCollectiveError::WrongOrchestratorResponse);
}
}
Ok(())
}
/// Performs an all-reduce
///
/// Reads the NodeState
pub async fn all_reduce(
&self,
tensor: B::FloatTensorPrimitive,
strategy: AllReduceStrategy,
op: ReduceOperation,
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError> {
let state = self.state.read().await;
let Some(ref state) = *state else {
return Err(GlobalCollectiveError::AllReduceBeforeRegister);
};
let node = state.node_id;
let nodes = &state.nodes;
let mut result = match strategy {
AllReduceStrategy::Centralized => {
centralized_all_reduce_sum(
node,
nodes,
&self.data_service,
self.sync_service.clone(),
tensor,
)
.await?
}
AllReduceStrategy::Tree(arity) => {
tree_all_reduce_sum(
node,
nodes,
self.data_service.clone(),
self.sync_service.clone(),
tensor,
arity,
)
.await?
}
AllReduceStrategy::Ring => {
ring_all_reduce_sum(
node,
nodes,
self.data_service.clone(),
self.sync_service.clone(),
tensor,
)
.await?
}
};
if op == ReduceOperation::Mean {
result = B::float_div_scalar(result, (state.num_global_devices as f32).into());
}
Ok(result)
}
pub async fn reduce(
&self,
_tensor: B::FloatTensorPrimitive,
_strategy: ReduceStrategy,
_root: PeerId,
_op: ReduceOperation,
) -> Result<Option<B::FloatTensorPrimitive>, GlobalCollectiveError> {
unimplemented!("Global reduce unimplemented");
}
pub async fn broadcast(
&self,
_tensor: Option<B::FloatTensorPrimitive>,
_strategy: BroadcastStrategy,
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError> {
unimplemented!("Global broadcast unimplemented");
}
pub async fn finish(&mut self) {
let res = self.worker.close_connection().await;
if let Err(err) = res {
log::error!("Global collective client error: {err:?}");
}
self.data_service.close().await;
}
}

View File

@@ -0,0 +1,96 @@
use std::{collections::HashMap, sync::Arc};
use crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService};
use burn_communication::data_service::TensorDataService;
use burn_communication::{Address, Protocol};
use burn_tensor::TensorMetadata;
use burn_tensor::backend::Backend;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
/// Global all-reduce, using a centralized strategy.
///
/// Returns the resulting tensor on the same device as the input tensor
pub(crate) async fn centralized_all_reduce_sum<B, P>(
node: NodeId,
nodes: &HashMap<NodeId, Address>,
data_service: &Arc<TensorDataService<B, P>>,
sync_service: Arc<SyncService<P>>,
tensor: B::FloatTensorPrimitive,
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>
where
B: Backend,
P: Protocol,
{
let ids = nodes.keys().cloned().collect::<Vec<_>>();
let central = get_central_node(ids.clone());
let shape = tensor.shape();
let device = &B::float_device(&tensor);
let res = if central == node {
// Transfer 1: download tensors from other nodes
let mut futures = ids
.iter()
.filter(|id| **id != central) // Only non-central nodes
.map(|id| {
let address = nodes.get(id).unwrap();
let device = device.clone();
let data_service = data_service.clone();
async move {
let data = data_service
.download_tensor((*address).clone(), 0.into())
.await
.expect("Couldn't find the tensor for transfer id 0");
B::float_from_data(data, &device)
}
})
.collect::<FuturesUnordered<_>>();
// Sum all downloads async
let mut sum = tensor;
while let Some(res) = futures.next().await {
if shape != res.shape() {
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
}
sum = B::float_add(sum, res);
}
// Transfer 2: Expose result
let other_nodes_count = ids.len() as u32 - 1;
data_service
.expose(sum.clone(), other_nodes_count, 1.into())
.await;
sum
} else {
// Transfer 1: Expose input
data_service.expose(tensor, 1, 0.into()).await;
// Transfer 2: Download result
let central_addr = nodes.get(&central).unwrap().clone();
let data = data_service
.download_tensor(central_addr, 1.into())
.await
.expect("Couldn't find the tensor for transfer id 1");
let res = B::float_from_data(data, device);
if shape != res.shape() {
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
}
res
};
// Wait for all nodes to finish
sync_service.sync().await;
Ok(res)
}
/// Get the central node for a centralized all-reduce
pub(crate) fn get_central_node(mut nodes: Vec<NodeId>) -> NodeId {
nodes.sort();
*nodes.first().unwrap()
}

View File

@@ -0,0 +1,6 @@
pub mod base;
pub mod centralized;
pub mod ring;
pub mod sync;
pub mod tree;
pub mod worker;

View File

@@ -0,0 +1,216 @@
//! Implements the collective ring all-reduce algorithm on the global level
use core::ops::Range;
use std::{collections::HashMap, sync::Arc};
use crate::{
NodeId,
global::shared::GlobalCollectiveError,
local::{get_ring_reduce_slice_ranges, get_slice_dim},
node::sync::SyncService,
};
use burn_communication::{Address, Protocol, data_service::TensorDataService};
use burn_tensor::{Slice, TensorMetadata, backend::Backend};
// https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model
// Example: tensors=3, slices=3
// phase 1
// o->o o
// o o->o
//>o o o->
// o 1->o
//>o o 1->
// 1->o o
// o 1 2
// 2 o 1
// 1 2 o
// phase 2
//>o 1 2->
// 2->o 1
// 1 2->o
// 2->1 2
// 2 2->1
//>1 2 2->
// 2 2 2
// 2 2 2
// 2 2 2
/// Ring all-reduce algorithm with summation
///
/// * `node` - The id of the current node
/// * `nodes` - Map of all nodes in the operation
/// * `data_service` - The data service handles peer-to-peer tensor transfers
/// * `sync_service` - The sync service handles syncing with peers
/// * `tensor` - The tensor to reduce. At least one dimension size must be greater than the number
/// of nodes
pub(crate) async fn ring_all_reduce_sum<B, P>(
node: NodeId,
nodes: &HashMap<NodeId, Address>,
data_service: Arc<TensorDataService<B, P>>,
sync_service: Arc<SyncService<P>>,
tensor: B::FloatTensorPrimitive,
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>
where
B: Backend,
P: Protocol,
{
let shape = tensor.shape();
let device = &B::float_device(&tensor);
// Slice tensors in N parts, N is node count
let slice_dim = get_slice_dim(&shape);
if shape[slice_dim] < nodes.len() {
return Err(GlobalCollectiveError::RingReduceImpossible);
}
let ring = get_ring_topology(nodes.keys().cloned().collect::<Vec<_>>());
let slice_ranges = get_ring_reduce_slice_ranges(shape[slice_dim], ring.len());
let mut slices = slice_tensor::<B>(tensor, slice_dim, slice_ranges);
let mut send_slice_idx = ring
.iter()
.position(|id| *id == node)
.expect("Node is in ring");
let prev_node_idx = (send_slice_idx + ring.len() - 1) % ring.len(); // +ring.len for overflow
let prev_node = nodes.get(&ring[prev_node_idx]).unwrap();
let mut transfer_counter: u64 = 0;
// Phase 1: add
do_cycles::<B, P>(
&mut slices,
&mut transfer_counter,
&mut send_slice_idx,
true,
prev_node.clone(),
&data_service,
device,
)
.await?;
// Phase 2: replace
do_cycles::<B, P>(
&mut slices,
&mut transfer_counter,
&mut send_slice_idx,
false,
prev_node.clone(),
&data_service,
device,
)
.await?;
// Wait for all nodes to finish
sync_service.sync().await;
// merge slices
Ok(B::float_cat(slices, slice_dim))
}
/// Do N-1 cycles of ring-reduce
///
/// * `slices` - Slices of the original tensor, len equal to node count
/// * `transfer_counter` - counter for each step (one send one receive)
/// * `send_slice_idx` - counter for the index of each slice to send
/// * `is_phase_one` - In phase 1, the tensors are aggregated. Otherwise, they are overridden
/// * `data_service` - TensorDataService for peer-to-peer tensor transfers
/// * `device` - The device on which all local tensors are stored. Should match `slices`
async fn do_cycles<B, P>(
slices: &mut [B::FloatTensorPrimitive],
transfer_counter: &mut u64,
send_slice_idx: &mut usize,
is_phase_one: bool,
prev_node: Address,
data_service: &Arc<TensorDataService<B, P>>,
device: &B::Device,
) -> Result<(), GlobalCollectiveError>
where
B: Backend,
P: Protocol,
{
let slice_count = slices.len();
for _ in 0..(slice_count - 1) {
let transfer_id = (*transfer_counter).into();
// +slice_count to avoid overflow
let recv_slice_idx = (*send_slice_idx + slice_count - 1) % slice_count;
let slice_send = slices[*send_slice_idx].clone();
let upload = {
let data_service = data_service.clone();
tokio::spawn(async move {
data_service
.expose(slice_send.clone(), 1, transfer_id)
.await
})
};
let download = {
let data_client = data_service.clone();
let next_node = prev_node.clone();
tokio::spawn(async move { data_client.download_tensor(next_node, transfer_id).await })
};
upload.await.unwrap();
let download = download.await.unwrap();
if is_phase_one {
let download = download.expect("Peer closed download connection");
let tensor = B::float_from_data(download, device);
slices[recv_slice_idx] = B::float_add(slices[recv_slice_idx].clone(), tensor);
} else {
let tensor = B::float_from_data(download.unwrap(), device);
let old_shape = slices[recv_slice_idx].shape();
if old_shape != tensor.shape() {
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
}
slices[recv_slice_idx] = tensor;
}
// Move slice index
*send_slice_idx = recv_slice_idx;
*transfer_counter += 1;
}
Ok(())
}
/// But a tensor into even slices across a dimension
///
/// * `tensor` - the tensor to slice
/// * `slice_dim` - the dimension to slice across
/// * `slice_ranges` - The ranges of indices on `slice_dim` to use when slicing the tensor
fn slice_tensor<B: Backend>(
tensor: B::FloatTensorPrimitive,
slice_dim: usize,
slice_ranges: Vec<Range<usize>>,
) -> Vec<B::FloatTensorPrimitive> {
let shape = tensor.shape();
// full range across all dims as Slice
let full_range = shape
.iter()
.map(|dim| Slice::from(0..*dim))
.collect::<Vec<Slice>>();
// Slice tensors
let mut slices = vec![];
for range in &slice_ranges {
let mut all_ranges = full_range.clone();
all_ranges[slice_dim] = Slice::from(range.clone());
let slice = B::float_slice(tensor.clone(), &all_ranges);
slices.push(slice);
}
slices
}
/// Get the ring topology
fn get_ring_topology(mut nodes: Vec<NodeId>) -> Vec<NodeId> {
// This ordering could be more sophisticated, using node proximities etc
nodes.sort();
nodes
}

View File

@@ -0,0 +1,100 @@
use std::{
marker::PhantomData,
sync::{Arc, Mutex},
vec,
};
use burn_communication::{CommunicationChannel, Message, Protocol, ProtocolClient};
use serde::{Deserialize, Serialize};
use tokio::sync::{Notify, RwLock};
use crate::{NodeId, node::base::NodeState};
/// Handles the status of sync requests from other nodes
pub(crate) struct SyncService<P: Protocol> {
/// Current node's state, shared with the thread that does aggregations
node_state: Arc<RwLock<Option<NodeState>>>,
/// The number of peers that have requested to sync with us since the last successful sync.
syncing_peers: Mutex<Vec<NodeId>>,
/// Notification on each incoming sync request
sync_notif: Notify,
_p: PhantomData<P>,
}
#[derive(Debug, Serialize, Deserialize)]
struct SyncRequest(NodeId);
impl<P: Protocol> SyncService<P> {
pub fn new(node_state: Arc<RwLock<Option<NodeState>>>) -> Self {
Self {
node_state,
syncing_peers: Mutex::new(vec![]),
sync_notif: Notify::new(),
_p: PhantomData,
}
}
fn add_syncing_peer(&self, peer: NodeId) {
let mut syncing_peers = self.syncing_peers.lock().unwrap();
syncing_peers.push(peer);
}
/// Sync with all peers.
pub async fn sync(&self) {
// we can't sync while we register
let node_state = self.node_state.read().await;
let node_state = node_state
.as_ref()
.expect("Trying to sync a node before having registered to the orchestrator");
// this peer is syncing
self.add_syncing_peer(node_state.node_id);
for (id, addr) in &node_state.nodes {
if *id == node_state.node_id {
continue;
}
let mut connection = P::Client::connect(addr.clone(), "sync")
.await
.expect("Couldn't connect to peer for sync");
let msg = SyncRequest(node_state.node_id);
let sync_bytes = rmp_serde::to_vec(&msg).unwrap();
connection
.send(Message::new(sync_bytes.into()))
.await
.expect("Peer closed connection unexpectedly");
}
loop {
{
// compare currently synced peers with list of all nodes
let mut syncing_peers = self.syncing_peers.lock().unwrap().to_vec();
syncing_peers.sort();
let mut all_node_ids = node_state.nodes.keys().cloned().collect::<Vec<_>>();
all_node_ids.sort();
if syncing_peers == all_node_ids {
// all nodes have synced
syncing_peers.clear();
return;
}
}
// Wait for the next sync to come in
self.sync_notif.notified().await
}
}
pub async fn handle_sync_connection<C: CommunicationChannel>(&self, mut channel: C) {
let msg = channel.recv().await.unwrap();
let Some(msg) = msg else {
return;
};
let msg = rmp_serde::from_slice::<SyncRequest>(&msg.data).unwrap();
self.add_syncing_peer(msg.0);
self.sync_notif.notify_waiters();
}
}

View File

@@ -0,0 +1,198 @@
use std::{collections::HashMap, sync::Arc};
use crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService};
use burn_communication::{Address, Protocol, data_service::TensorDataService};
use burn_tensor::{TensorMetadata, backend::Backend};
use futures::{StreamExt, stream::FuturesUnordered};
struct TreeTopology {
parents: HashMap<NodeId, NodeId>,
children: HashMap<NodeId, Vec<NodeId>>,
}
/// Global all-reduce, using a b-tree strategy.
///
/// Returns the resulting tensor on the same device as the input tensor
pub(crate) async fn tree_all_reduce_sum<B, P>(
node: NodeId,
nodes: &HashMap<NodeId, Address>,
data_service: Arc<TensorDataService<B, P>>,
sync_service: Arc<SyncService<P>>,
tensor: B::FloatTensorPrimitive,
arity: u32,
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>
where
B: Backend,
P: Protocol,
{
let shape = tensor.shape();
let device = &B::float_device(&tensor);
// Topology could be cached based on (nodes.keys().cloned(), arity)
let strategy = get_tree_topology(nodes.keys().cloned().collect::<Vec<_>>(), arity);
// Transfer 1: Download and sum tensors from children
let mut result = tensor;
if let Some(children) = strategy.children.get(&node) {
let mut downloads = children
.iter()
.map(|child| {
let child_addr = nodes.get(child).unwrap().clone();
let data_service = data_service.clone();
async move {
let data = data_service
.download_tensor(child_addr.clone(), 0.into())
.await
.ok_or(GlobalCollectiveError::PeerLost(*child))?;
Ok::<B::FloatTensorPrimitive, GlobalCollectiveError>(B::float_from_data(
data, device,
))
}
})
.collect::<FuturesUnordered<_>>();
for _ in children {
let res = downloads.next().await.unwrap().unwrap();
if res.shape() != shape {
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
}
result = B::float_add(result, res);
}
}
// Transfer 2: Expose result to parent and download final result if not root
if let Some(parent) = strategy.parents.get(&node) {
data_service.expose(result.clone(), 1, 0.into()).await;
let parent_addr = nodes.get(parent).unwrap().clone();
let data = data_service
.download_tensor(parent_addr.clone(), 1.into())
.await
.ok_or(GlobalCollectiveError::PeerLost(*parent))?;
let parent_tensor = B::float_from_data(data, device);
if parent_tensor.shape() != shape {
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
}
result = parent_tensor;
}
// Transfer 3: Expose final result to children (if any)
if let Some(children) = strategy.children.get(&node)
&& !children.is_empty()
{
data_service
.expose(result.clone(), children.len() as u32, 1.into())
.await;
}
// Final barrier
sync_service.sync().await;
Ok(result)
}
/// Get the tree topology.
///
/// * `nodes` - List of node ids. Order doesn't matter. Nodes must be unique.
fn get_tree_topology(mut nodes: Vec<NodeId>, arity: u32) -> TreeTopology {
assert!(arity >= 1, "Arity must be ≥ 1");
nodes.sort(); // Sort
let n = nodes.len();
let k = arity as usize;
let mut parents: HashMap<_, _> = HashMap::with_capacity(n);
let mut children: HashMap<_, _> = HashMap::with_capacity(n);
for (i, &parent_id) in nodes.iter().enumerate() {
// compute the window [first_child, last_child)
let first = i * k + 1;
if first < n {
let last = usize::min(first + k, n);
let mut ch = Vec::with_capacity(last - first);
for &child_id in &nodes[first..last] {
parents.insert(child_id, parent_id);
ch.push(child_id);
}
children.insert(parent_id, ch);
} else {
// leafnode: no children
children.insert(parent_id, Vec::new());
}
}
TreeTopology { parents, children }
}
#[cfg(test)]
mod tests {
use super::*;
/// Test the tree topology algorithm with arity 2 and 7 nodes
#[test]
fn test_get_tree_topology_arity2_size7() {
let mut nodes = vec![];
for i in 0..7 {
nodes.push(i.into());
}
let topology = get_tree_topology(nodes, 2);
// Root is 0, so it should have no parent
assert!(!topology.parents.contains_key(&0.into()));
// Parents:
// Node 1 and 2 → parent 0
// Node 3 and 4 → parent 1
// Node 5 and 6 → parent 2
let expected_parents = [
(1.into(), 0.into()),
(2.into(), 0.into()),
(3.into(), 1.into()),
(4.into(), 1.into()),
(5.into(), 2.into()),
(6.into(), 2.into()),
];
for (child, parent) in &expected_parents {
assert_eq!(
topology.parents.get(child),
Some(parent),
"wrong parent for {child:?}"
);
}
// There should be exactly 6 entries in parents
assert_eq!(topology.parents.len(), expected_parents.len());
// Children:
// 0 → [1, 2]
// 1 → [3, 4]
// 2 → [5, 6]
// 3,4,5,6 → []
assert_eq!(
topology.children.get(&0.into()),
Some(&vec![1.into(), 2.into()])
);
assert_eq!(
topology.children.get(&1.into()),
Some(&vec![3.into(), 4.into()])
);
assert_eq!(
topology.children.get(&2.into()),
Some(&vec![5.into(), 6.into()])
);
// Leaves
for leaf in 3..7 {
assert_eq!(
topology.children.get(&leaf.into()),
Some(&Vec::new()),
"leaf {leaf:?} should have no children"
);
}
// Ensure we have exactly 7 entries in children
assert_eq!(topology.children.len(), 7);
}
}

View File

@@ -0,0 +1,297 @@
use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration};
use burn_communication::{Address, CommunicationChannel, Message, ProtocolClient};
use tokio::{
runtime::Runtime,
sync::{
Mutex,
mpsc::{Receiver, Sender},
},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
use crate::global::shared::{
CollectiveMessage, CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest,
RemoteResponse, RequestId, SessionId,
};
/// Worker that handles communication with the orchestrator for global collective operations.
pub(crate) struct GlobalClientWorker<P: ProtocolClient> {
handle: Option<JoinHandle<Result<(), GlobalCollectiveError>>>,
cancel_token: CancellationToken,
request_sender: Sender<ClientRequest>,
_phantom_data: PhantomData<P>,
}
// Rename
struct GlobalClientWorkerState {
requests: HashMap<RequestId, Sender<RemoteResponse>>,
}
impl GlobalClientWorkerState {
fn new() -> Self {
Self {
requests: HashMap::new(),
}
}
}
#[derive(Debug)]
pub(crate) struct ClientRequest {
pub request: RemoteRequest,
pub callback: Sender<RemoteResponse>,
}
impl ClientRequest {
pub(crate) fn new(request: RemoteRequest, callback: Sender<RemoteResponse>) -> Self {
Self { request, callback }
}
}
impl<C: ProtocolClient> GlobalClientWorker<C> {
/// Create a new global client worker and start the tasks.
pub(crate) fn new(
runtime: &Runtime,
cancel_token: CancellationToken,
global_address: &Address,
) -> Self {
let (request_sender, request_recv) = tokio::sync::mpsc::channel::<ClientRequest>(10);
let state = Arc::new(Mutex::new(GlobalClientWorkerState::new()));
let handle = runtime.spawn(Self::start(
state,
cancel_token.clone(),
global_address.clone(),
request_recv,
));
Self {
handle: Some(handle),
cancel_token,
request_sender,
_phantom_data: PhantomData,
}
}
/// Start the global client tasks
async fn start(
state: Arc<Mutex<GlobalClientWorkerState>>,
cancel_token: CancellationToken,
global_address: Address,
request_recv: Receiver<ClientRequest>,
) -> Result<(), GlobalCollectiveError> {
// Init the connection.
let (request, response) = Self::init_connection(&global_address).await?;
// Websocket async worker loading responses from the server.
let response_handle = tokio::spawn(Self::response_loader(
state.clone(),
response,
cancel_token.clone(),
));
// Channel async worker sending operations to the server.
let request_handle = tokio::spawn(Self::request_sender(
request_recv,
state,
request,
cancel_token.clone(),
));
if let Err(e) = response_handle.await {
log::error!("Response handler failed: {e:?}");
}
if let Err(e) = request_handle.await {
log::error!("Request handler failed: {e:?}");
}
Ok(())
}
async fn init_connection(
address: &Address,
) -> Result<(C::Channel, C::Channel), GlobalCollectiveError> {
let session_id = SessionId::new();
let stream_request = tokio::spawn(Self::connect_with_retry(
address.clone(),
"request",
std::time::Duration::from_secs(1),
None,
session_id,
));
let stream_response = tokio::spawn(Self::connect_with_retry(
address.clone(),
"response",
std::time::Duration::from_secs(1),
None,
session_id,
));
let Ok(Some(request)) = stream_request.await else {
return Err(GlobalCollectiveError::OrchestratorUnreachable);
};
let Ok(Some(response)) = stream_response.await else {
return Err(GlobalCollectiveError::OrchestratorUnreachable);
};
Ok((request, response))
}
/// Connect with websocket with retries.
async fn connect_with_retry(
address: Address,
route: &str,
retry_pause: Duration,
retry_max: Option<u32>,
session_id: SessionId,
) -> Option<C::Channel> {
let mut retries = 0;
loop {
if let Some(max) = retry_max
&& retries >= max
{
log::warn!("Failed to connect to {address} after {max} retries.");
return None;
}
// Try to connect to the request address.
println!("Connecting to {address} ...");
let result = C::connect(address.clone(), route).await;
if let Some(mut stream) = result {
let init_msg = CollectiveMessage::Init(session_id);
let bytes: bytes::Bytes = rmp_serde::to_vec(&init_msg).unwrap().into();
stream
.send(Message::new(bytes))
.await
.expect("Can send the init message on the websocket.");
return Some(stream);
}
println!("Failed to connect to {address}, retrying... Attempt #{retries}");
tokio::time::sleep(retry_pause).await;
retries += 1;
}
}
/// Unregister the worker and close the connection.
pub(crate) async fn close_connection(&mut self) -> Result<(), GlobalCollectiveError> {
if let Some(handle) = self.handle.take() {
// Un-register from server
let req = RemoteRequest::Finish;
let resp = self.request(req).await;
if resp != RemoteResponse::FinishAck {
log::error!("Requested to finish, did not get FinishAck; got {resp:?}");
return Err(GlobalCollectiveError::WrongOrchestratorResponse);
}
self.cancel_token.cancel();
if let Err(e) = handle.await.unwrap() {
log::error!("Connection error {e:?}");
}
}
Ok(())
}
async fn response_loader(
state: Arc<Mutex<GlobalClientWorkerState>>,
mut stream_response: C::Channel,
cancel_token: CancellationToken,
) {
loop {
tokio::select! {
// Check if the cancel token is cancelled
_ = cancel_token.cancelled() => {
break;
}
// .. Or get a message from the websocket
response = stream_response.recv() => {
match response {
Err(err) => {
log::error!("Error receiving message from websocket: {err:?}");
break;
}
Ok(response) => {
let Some(response) = response else {
log::warn!("Closed connection");
break;
};
let response: CollectiveMessageResponse = rmp_serde::from_slice(&response.data)
.expect("Can deserialize messages from the websocket.");
let state_resp = state.lock().await;
let response_callback = state_resp
.requests
.get(&response.request_id)
.expect("Got a response to an unknown request");
response_callback.send(response.content).await.unwrap();
}
}
}
}
}
log::info!("Worker closing connection");
stream_response
.close()
.await
.expect("Can close the websocket stream.");
}
async fn request_sender(
mut request_recv: Receiver<ClientRequest>,
worker: Arc<Mutex<GlobalClientWorkerState>>,
mut stream_request: C::Channel,
cancel_token: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
break;
},
request = request_recv.recv() => {
let Some(request) = request else {
continue;
};
let id = RequestId::new();
// Register the callback if there is one
{
let mut state = worker.lock().await;
state.requests.insert(id, request.callback);
}
let request = CollectiveMessage::Request(id, request.request);
let bytes = rmp_serde::to_vec::<CollectiveMessage>(&request)
.expect("Can serialize tasks to bytes.")
.into();
stream_request
.send(Message::new(bytes))
.await
.expect("Can send the message on the websocket.");
}
}
}
log::info!("Worker closing connection");
stream_request
.close()
.await
.expect("Can send the close message on the websocket.");
}
pub(crate) async fn request(&self, req: RemoteRequest) -> RemoteResponse {
let (callback, mut response_recv) = tokio::sync::mpsc::channel::<RemoteResponse>(10);
let client_req = ClientRequest::new(req, callback);
self.request_sender.send(client_req).await.unwrap();
response_recv.recv().await.unwrap()
}
}

View File

@@ -0,0 +1,138 @@
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::global::{
orchestrator::state::GlobalCollectiveState,
shared::{CollectiveMessage, GlobalCollectiveError},
};
use burn_communication::{
CommunicationChannel, Message, ProtocolServer, util::os_shutdown_signal, websocket::WsServer,
};
/// The global collective state manages collective operations on the global level
#[derive(Clone)]
pub(crate) struct GlobalOrchestrator {
state: Arc<Mutex<GlobalCollectiveState>>,
}
impl GlobalOrchestrator {
/// Starts the comms server with two routes: "/request" and "/response"
pub(crate) async fn start<F, S: ProtocolServer + Debug>(
shutdown_signal: F,
comms_server: S,
) -> Result<(), GlobalCollectiveError>
where
F: Future<Output = ()> + Send + 'static,
{
let state = GlobalCollectiveState::new();
let server = Self {
state: Arc::new(tokio::sync::Mutex::new(state)),
};
comms_server
.route("/response", {
let server = server.clone();
async move |socket| {
if let Err(err) = server.handle_socket_response::<S>(socket).await {
log::error!("[Response Handler] Error: {err:?}")
}
}
})
.route("/request", {
let server = server.clone();
async move |socket| {
if let Err(err) = server.handle_socket_request::<S>(socket).await {
log::error!("[Request Handler] Error: {err:?}")
}
}
})
.serve(shutdown_signal)
.await
.map_err(|err| GlobalCollectiveError::Server(format!("{err:?}")))?;
Ok(())
}
async fn handle_socket_response<S: ProtocolServer>(
self,
mut stream: S::Channel,
) -> Result<(), GlobalCollectiveError> {
log::info!("[Response Handler] On new connection.");
let msg = stream
.recv()
.await
.map_err(|err| GlobalCollectiveError::Server(format!("{err:?}")))?;
let Some(msg) = msg else {
log::warn!("Response socket closed early!");
return Ok(());
};
let msg = rmp_serde::from_slice::<CollectiveMessage>(&msg.data)
.map_err(|_| GlobalCollectiveError::InvalidMessage)?;
let CollectiveMessage::Init(id) = msg else {
return Err(GlobalCollectiveError::FirstMsgNotInit);
};
let mut receiver = {
let mut state = self.state.lock().await;
state.get_session_responder(id)
};
while let Some(response) = receiver.recv().await {
let bytes = rmp_serde::to_vec(&response).unwrap();
stream.send(Message::new(bytes.into())).await?;
}
log::info!("[Response Handler] Closing connection.");
Ok(())
}
async fn handle_socket_request<S: ProtocolServer>(
self,
mut stream: S::Channel,
) -> Result<(), GlobalCollectiveError> {
log::info!("[Request Handler] On new connection.");
let mut session_id = None;
loop {
let packet = stream.recv().await?;
let Some(msg) = packet else {
log::info!("Peer closed the connection");
break;
};
let mut state = self.state.lock().await;
let msg = rmp_serde::from_slice::<CollectiveMessage>(&msg.data)
.map_err(|_| GlobalCollectiveError::InvalidMessage)?;
match msg {
CollectiveMessage::Init(id) => {
state.init_session(id);
session_id = Some(id);
}
CollectiveMessage::Request(request_id, remote_request) => {
let session_id = session_id.ok_or(GlobalCollectiveError::FirstMsgNotInit)?;
state
.process_request(session_id, request_id, remote_request)
.await;
}
}
}
Ok(())
}
}
/// Start a global orchestrator with WebSocket on the given port
pub async fn start_global_orchestrator(port: u16) {
let server = WsServer::new(port);
let res = GlobalOrchestrator::start(os_shutdown_signal(), server).await;
if let Err(err) = res {
log::error!("Global Collective Orchestrator error: {err:?}");
}
}

View File

@@ -0,0 +1,4 @@
pub(crate) mod base;
pub(crate) mod state;
pub use base::start_global_orchestrator;

View File

@@ -0,0 +1,219 @@
use crate::{
PeerId,
global::{
NodeId,
shared::{
CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest, RemoteResponse,
RequestId, SessionId,
},
},
};
use burn_communication::Address;
use std::collections::HashMap;
use tokio::sync::mpsc::{Receiver, Sender};
pub(crate) struct Session {
response_sender: Sender<CollectiveMessageResponse>,
response_receiver: Option<Receiver<CollectiveMessageResponse>>,
}
impl Session {
fn new() -> Self {
let (response_sender, recv) = tokio::sync::mpsc::channel::<CollectiveMessageResponse>(1);
Self {
response_sender,
response_receiver: Some(recv),
}
}
async fn respond(&mut self, response: CollectiveMessageResponse) {
self.response_sender.send(response).await.unwrap();
}
}
pub(crate) struct GlobalCollectiveState {
/// The ids passed to each register so far, and their addresses
registered_nodes: HashMap<SessionId, NodeId>,
/// Address for each node
node_addresses: HashMap<NodeId, Address>,
/// Peer on each node
node_peers: HashMap<NodeId, Vec<PeerId>>,
/// How many total nodes for the current register operation, as defined by the first caller
cur_num_nodes: Option<u32>,
/// How many peers have registered total
num_global_peers: u32,
register_requests: Vec<(SessionId, RequestId, NodeId)>,
sessions: HashMap<SessionId, Session>,
}
impl GlobalCollectiveState {
pub fn new() -> Self {
Self {
registered_nodes: HashMap::new(),
node_addresses: HashMap::new(),
node_peers: HashMap::new(),
cur_num_nodes: None,
num_global_peers: 0,
register_requests: Vec::new(),
sessions: HashMap::new(),
}
}
pub(crate) fn init_session(&mut self, id: SessionId) {
if self.sessions.contains_key(&id) {
return;
}
self.sessions.insert(id, Session::new());
}
/// Create the session with given id if necessary, and get the response receiver
pub(crate) fn get_session_responder(
&mut self,
id: SessionId,
) -> Receiver<CollectiveMessageResponse> {
self.init_session(id);
let session = self.sessions.get_mut(&id).unwrap();
let response_recv = session.response_receiver.take();
response_recv.unwrap()
}
pub(crate) async fn respond(
&mut self,
session_id: SessionId,
response: CollectiveMessageResponse,
) {
let session = self.sessions.get_mut(&session_id).unwrap();
session.respond(response).await;
}
/// Process an incoming node's request
pub(crate) async fn process_request(
&mut self,
session_id: SessionId,
request_id: RequestId,
request: RemoteRequest,
) {
if let Err(err) = match request {
RemoteRequest::Register {
node_addr,
num_nodes,
peers,
} => {
self.register(session_id, request_id, node_addr, num_nodes, peers)
.await
}
RemoteRequest::Finish => self.finish(session_id, request_id).await,
} {
// Error occurred, send it as response
let content = RemoteResponse::Error(err);
self.respond(
session_id,
CollectiveMessageResponse {
request_id,
content,
},
)
.await;
}
}
/// Un-register a node. Any pending requests will be cancelled, returning error responses.
async fn finish(
&mut self,
session_id: SessionId,
request_id: RequestId,
) -> Result<(), GlobalCollectiveError> {
let node_id = self
.registered_nodes
.remove(&session_id)
.ok_or(GlobalCollectiveError::NotRegisteredOnFinish)?;
self.node_addresses.remove(&node_id);
self.node_peers.remove(&node_id);
self.num_global_peers = 0;
let mut register_requests = vec![];
core::mem::swap(&mut register_requests, &mut self.register_requests);
for (session, req, node_id) in register_requests {
if session == session_id {
// Send a response if we are finishing a session with a pending register request
let content = RemoteResponse::Error(GlobalCollectiveError::PendingRegisterOnFinish);
let response = CollectiveMessageResponse {
request_id: req,
content,
};
self.respond(session_id, response).await;
} else {
// keep the register request
self.register_requests.push((session, req, node_id));
}
}
self.respond(
session_id,
CollectiveMessageResponse {
request_id,
content: RemoteResponse::FinishAck,
},
)
.await;
Ok(())
}
async fn register(
&mut self,
session_id: SessionId,
request_id: RequestId,
node_addr: Address,
num_nodes: u32,
peers: Vec<PeerId>,
) -> Result<(), GlobalCollectiveError> {
match &self.cur_num_nodes {
Some(cur_num_nodes) => {
if *cur_num_nodes != num_nodes {
return Err(GlobalCollectiveError::RegisterParamsMismatch);
}
}
None => {
self.cur_num_nodes = Some(num_nodes);
}
}
self.num_global_peers += peers.len() as u32;
let node_id: NodeId = self.registered_nodes.len().into();
self.registered_nodes.insert(session_id, node_id);
if self.node_addresses.values().any(|addr| node_addr == *addr) {
return Err(GlobalCollectiveError::DoubleRegister);
}
self.node_addresses.insert(node_id, node_addr);
self.node_peers.insert(node_id, peers);
self.register_requests
.push((session_id, request_id, node_id));
if self.registered_nodes.len() == num_nodes as usize {
let mut callbacks = vec![];
core::mem::swap(&mut callbacks, &mut self.register_requests);
for (session, request, node_id) in callbacks {
let content = RemoteResponse::Register {
node_id,
nodes: self.node_addresses.clone(),
num_global_devices: self.num_global_peers,
};
let resp = CollectiveMessageResponse {
request_id: request,
content,
};
self.respond(session, resp).await;
}
}
Ok(())
}
}

View File

@@ -0,0 +1,132 @@
use std::{collections::HashMap, sync::atomic::AtomicU32};
use crate::{NodeId, PeerId};
use burn_communication::{Address, CommunicationError};
use burn_std::id::IdGenerator;
use serde::{Deserialize, Serialize};
/// A unique identifier for each request made to a global orchestrator
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(crate) struct RequestId(u32);
static REQ_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
impl RequestId {
pub(crate) fn new() -> Self {
let id = REQ_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Self(id)
}
}
impl Default for RequestId {
fn default() -> Self {
Self::new()
}
}
/// Unique identifier that can represent a session between a node and the orchestrator.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub(crate) struct SessionId {
id: u64,
}
impl SessionId {
/// Create a new [session id](SessionId).
pub(crate) fn new() -> Self {
Self {
id: IdGenerator::generate(),
}
}
}
/// Requests sent from the client
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum CollectiveMessage {
Init(SessionId),
Request(RequestId, RemoteRequest),
}
/// Responses sent to the client
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CollectiveMessageResponse {
pub request_id: RequestId,
pub content: RemoteResponse,
}
/// Requests made from a client to a server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum RemoteRequest {
// Register a node
Register {
/// Endpoint for this node
node_addr: Address,
/// Number of total nodes
num_nodes: u32,
/// List of peers on this node
peers: Vec<PeerId>,
},
/// Unregister node
Finish,
}
/// Responses for each server request
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) enum RemoteResponse {
/// Response to a register request
Register {
/// The orchestrator gives the node its id
node_id: NodeId,
/// All the nodes in the collective: including self
nodes: HashMap<NodeId, Address>,
/// How many devices exist globally? For averaging values
num_global_devices: u32,
},
// Finish
FinishAck,
// There was a server-side error
Error(GlobalCollectiveError),
}
/// Errors that occur during collective operations on the global level
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GlobalCollectiveError {
/// Operations that can't be done before registering
AllReduceBeforeRegister,
/// Ring all-reduce can't be done if all tensor dimensions are smaller than the number of nodes.
RingReduceImpossible,
/// Either a node has unregistered twice, or a Finish has been called before a Register
NotRegisteredOnFinish,
/// Finish has been called before a Register operation was finished
PendingRegisterOnFinish,
/// Trying to register a different way than is currently being done
RegisterParamsMismatch,
/// Trying to register while already registered
DoubleRegister,
/// Trying to aggregate a different way than is currently being done
AllReduceParamsMismatch,
/// First message on socket should be Message::Init
FirstMsgNotInit,
/// Messages should be rmp_serde serialized `Message` types
InvalidMessage,
/// A peer behaved unexpectedly
PeerSentIncoherentTensor,
/// Tried to download from a peer, but the peer closed or lost the connection
PeerLost(NodeId),
/// Error from the coordinator
Server(String),
/// The node received an invalid response
WrongOrchestratorResponse,
/// Node couldn't connect to coordinator
OrchestratorUnreachable,
}
impl<E: CommunicationError> From<E> for GlobalCollectiveError {
fn from(err: E) -> Self {
Self::Server(format!("{err:?}"))
}
}

View File

@@ -0,0 +1,21 @@
mod global;
pub use global::*;
mod config;
pub use config::*;
mod api;
pub use api::*;
mod local;
#[cfg(all(
test,
any(
feature = "test-ndarray",
feature = "test-wgpu",
feature = "test-cuda",
feature = "test-metal"
)
))]
mod tests;

View File

@@ -0,0 +1,118 @@
use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices};
use crate::{
AllReduceStrategy, CollectiveConfig, CollectiveError, ReduceOperation,
local::{
all_reduce_sum_centralized, all_reduce_sum_ring, all_reduce_sum_tree,
broadcast_centralized, broadcast_tree, reduce_sum_centralized, reduce_sum_tree,
},
node::base::Node,
};
use burn_communication::Protocol;
use burn_tensor::backend::Backend;
#[cfg(feature = "tracing")]
use tracing::Instrument;
/// Perform an all-reduce with no multi-node operations (global ops)
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors, config))
)]
pub(crate) async fn all_reduce_local_only<B: Backend>(
tensors: CollectiveTensorMap<B>,
op: ReduceOperation,
config: &CollectiveConfig,
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
let local_strategy = &config.local_all_reduce_strategy;
let mut reduced_tensors = match local_strategy {
AllReduceStrategy::Centralized => all_reduce_sum_centralized::<B>(tensors),
AllReduceStrategy::Tree(arity) => all_reduce_sum_tree::<B>(tensors, *arity),
AllReduceStrategy::Ring => all_reduce_sum_ring::<B>(tensors),
};
if op == ReduceOperation::Mean {
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("mean_reduction").entered();
// Apply mean division
let div = (reduced_tensors.len() as f32).into();
reduced_tensors = reduced_tensors
.into_iter()
.map(|(id, t)| (id, B::float_div_scalar(t, div)))
.collect();
}
Ok(reduced_tensors)
}
/// Do an all-reduce in a multi-node context
///
/// With Tree and Centralized strategies, the all-reduce is split between a
/// reduce (all tensors are reduced to one device), and a broadcast (the result is sent to all
/// other devices). The all-reduce on the global level is done between both steps.
/// Due to the nature of the Ring strategy, this separation can't be done.
///
/// For the Ring strategy, this isn't possible, because it is more like a
/// reduce-scatter plus an all-gather, so using a Ring strategy locally in a multi-node
/// setup may be unadvantageous.
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors, config, global_client))
)]
pub(crate) async fn all_reduce_with_global<B: Backend, P: Protocol>(
tensors: CollectiveTensorMap<B>,
op: ReduceOperation,
config: &CollectiveConfig,
global_client: &mut Node<B, P>,
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
let peer_devices = get_peer_devices::<B>(&tensors);
// For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later
let main_device = *tensors.keys().next().unwrap();
let mut main_tensor = match config.local_all_reduce_strategy {
AllReduceStrategy::Centralized => reduce_sum_centralized::<B>(tensors, &main_device),
AllReduceStrategy::Tree(arity) => reduce_sum_tree::<B>(tensors, &main_device, arity),
AllReduceStrategy::Ring => all_reduce_sum_ring::<B>(tensors)
.remove(&main_device)
.unwrap(),
};
// Do aggregation on global level with the main tensor
main_tensor = {
let fut = async {
let global_strategy = config
.global_all_reduce_strategy
.expect("global_all_reduce_strategy must be set");
global_client
.all_reduce(main_tensor, global_strategy, op)
.await
};
#[cfg(feature = "tracing")]
{
fut.instrument(tracing::info_span!("global_all_reduce"))
}
#[cfg(not(feature = "tracing"))]
{
fut
}
}
.await
.map_err(CollectiveError::Global)?;
// Broadcast result to all devices
let tensors = match config.local_all_reduce_strategy {
AllReduceStrategy::Tree(arity) => {
broadcast_tree::<B>(peer_devices, main_device, main_tensor, arity)
}
// If we chose the ring strategy and we must still broadcast the global result,
// we use the centralized strategy for broadcasting, but the tree may be better.
AllReduceStrategy::Centralized | AllReduceStrategy::Ring => {
broadcast_centralized::<B>(peer_devices, main_device, main_tensor)
}
};
Ok(tensors)
}

View File

@@ -0,0 +1,26 @@
use burn_tensor::backend::Backend;
use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices};
use crate::local::{broadcast_centralized, reduce_sum_centralized};
/// Perform an all-reduce operation by reducing all tensors on one device, and broadcasting the
/// result to all other devices
///
/// Internally, this is just a call to `reduce` followed by a `broadcast`
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors))
)]
pub(crate) fn all_reduce_sum_centralized<B: Backend>(
tensors: CollectiveTensorMap<B>,
) -> CollectiveTensorMap<B> {
// Get corresponding devices for each peer
let peer_devices = get_peer_devices::<B>(&tensors);
let central_device = *tensors.keys().next().unwrap();
// Reduce to central device
let central_tensor = reduce_sum_centralized::<B>(tensors, &central_device);
// Broadcast result to all
broadcast_centralized::<B>(peer_devices, central_device, central_tensor)
}

View File

@@ -0,0 +1,11 @@
mod base;
mod centralized;
mod op;
mod ring;
mod tree;
pub(crate) use base::*;
pub(crate) use centralized::*;
pub(crate) use op::*;
pub(crate) use ring::*;
pub(crate) use tree::*;

View File

@@ -0,0 +1,141 @@
use crate::global::node::base::Node;
use crate::local::tensor_map::CollectiveTensorMap;
use crate::{CollectiveConfig, CollectiveError, PeerId, ReduceOperation, local};
use burn_communication::Protocol;
use burn_std::Shape;
use burn_tensor::TensorMetadata;
use burn_tensor::backend::Backend;
use std::sync::mpsc::SyncSender;
/// An on-going all-reduce operation
#[derive(Debug)]
pub struct AllReduceOp<B: Backend> {
/// all-reduce calls, one for each calling device
calls: Vec<AllReduceOpCall<B>>,
/// The reduce operation of the current all-reduce, as defined by the first caller
op: ReduceOperation,
/// The shape of the current all-reduce, as defined by the first caller
shape: Shape,
}
/// Struct for each device that calls an all-reduce operation
#[derive(Debug)]
pub struct AllReduceOpCall<B: Backend> {
/// Id of the caller for this operation
caller: PeerId,
/// The tensor primitive passed as input
input: B::FloatTensorPrimitive,
/// Callback for the result of the all-reduce
result_sender: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
}
/// Type sent to the collective client upon completion of a all-reduce aggregation
pub(crate) type AllReduceResult<T> = Result<T, CollectiveError>;
impl<B: Backend> AllReduceOp<B> {
pub fn new(shape: Shape, reduce_op: ReduceOperation) -> Self {
Self {
calls: vec![],
op: reduce_op,
shape,
}
}
/// Get a list of the peers.
fn peers(&self) -> Vec<PeerId> {
self.calls.iter().map(|c| c.caller).collect()
}
/// Register a call to all-reduce in this operation.
///
/// # Returns
///
/// `true` if enough peers have registered, and the all-reduce is ready
pub fn register_call(
&mut self,
caller: PeerId,
input: B::FloatTensorPrimitive,
result_sender: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
op: ReduceOperation,
peer_count: usize,
) -> Result<bool, CollectiveError> {
if self.shape != input.shape() {
return Err(CollectiveError::AllReduceShapeMismatch);
}
if self.op != op {
return Err(CollectiveError::AllReduceOperationMismatch);
}
self.calls.push(AllReduceOpCall {
caller,
input,
result_sender,
});
Ok(self.calls.len() == peer_count)
}
/// Runs the all-reduce if the operation is ready. Otherwise, do nothing
#[cfg_attr(feature = "tracing", tracing::instrument(
level = "trace",
skip(self, config, global_client),
fields(
?self.op,
?self.shape,
self.peers = ?self.peers(),
)
))]
pub async fn execute<P: Protocol>(
mut self,
config: &CollectiveConfig,
global_client: &mut Option<Node<B, P>>,
) {
// all registered callers have sent a tensor to aggregate
match self.all_reduce(config, global_client).await {
Ok(mut tensors) => {
// Return resulting tensors
self.calls.iter().for_each(|call| {
let result = tensors
.remove(&call.caller)
.expect("tensor/peer internal mismatch.");
call.result_sender.send(Ok(result)).unwrap();
});
assert_eq!(tensors.len(), 0, "tensor/peer internal mismatch.");
}
Err(err) => {
// Send error to all subscribers
self.fail(err);
}
}
}
/// Perform an all-reduce operation.
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(self, config, global_client))
)]
async fn all_reduce<P: Protocol>(
&mut self,
config: &CollectiveConfig,
global_client: &mut Option<Node<B, P>>,
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
let tensors = self
.calls
.iter()
.map(|call| (call.caller, call.input.clone()))
.collect();
if let Some(global_client) = global_client.as_mut() {
local::all_reduce_with_global(tensors, self.op, config, global_client).await
} else {
local::all_reduce_local_only::<B>(tensors, self.op, config).await
}
}
/// Send a collective error as result to operation caller
pub fn fail(self, err: CollectiveError) {
self.calls.iter().for_each(|op| {
op.result_sender.send(Err(err.clone())).unwrap();
});
}
}

View File

@@ -0,0 +1,194 @@
use super::tree::all_reduce_sum_tree;
use crate::PeerId;
use crate::local::tensor_map;
use crate::local::tensor_map::CollectiveTensorMap;
use burn_tensor::{Shape, Slice, TensorMetadata, backend::Backend};
use std::{collections::HashMap, ops::Range};
/// Ring implementation of All-Reduce (Ring-Reduce)
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors))
)]
pub(crate) fn all_reduce_sum_ring<B: Backend>(
tensors: CollectiveTensorMap<B>,
) -> CollectiveTensorMap<B> {
// https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model
// Example: tensors=3, slices=3
// phase 1
// o->o o
// o o->oå
// o o o->
// o 1->o
// o o 1->
// 1->o o
// o 1 2
// 2 o 1
// 1 2 o
// phase 2
// o 1 2->
// 2->o 1
// 1 2->o
// 2->1 2
// 2 2->1
// 1 2 2->
// 2 2 2
// 2 2 2
// 2 2 2
// Verify all shapes are the same
let shape = tensor_map::get_common_shape::<B>(&tensors)
.expect("Cannot aggregate tensors with different sizes");
// Chose an axis
let slice_dim = get_slice_dim(&shape);
let slice_dim_size = shape[slice_dim];
let tensor_count = tensors.len();
if slice_dim_size < tensor_count {
// Tensor cannot be split into N slices! Use a fallback algorithm: binary tree
return all_reduce_sum_tree::<B>(tensors, 2);
}
// Split tensors into slices
let mut sliced_tensors = slice_tensors::<B>(tensors, shape, slice_dim);
// phase 1: aggregate in ring N-1 times (Reduce-Scatter)
ring_cycles::<B>(&mut sliced_tensors, true);
// phase 2: share (overwrite) in a ring N-1 times (All-Gather)
ring_cycles::<B>(&mut sliced_tensors, false);
// merge slices and put back in result
sliced_tensors
.into_iter()
.map(|(id, slices)| (id, B::float_cat(slices, slice_dim)))
.collect()
}
/// Get the dimension to slice across: the largest dimension of the shape
pub(crate) fn get_slice_dim(shape: &Shape) -> usize {
// get dimension with the greatest size.
shape
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.cmp(b))
.map(|(index, _)| index)
.unwrap()
}
/// With a ring of N tensors, send the tensors N-1 times, either for the first of second phase.
/// During the first phase, the tensor slices are summed.
/// During the second, the slices are replaced.
fn ring_cycles<B: Backend>(
sliced_tensors: &mut [(PeerId, Vec<B::FloatTensorPrimitive>)],
is_phase_one: bool,
) {
let tensor_count = sliced_tensors.len();
for cycle in 0..(tensor_count - 1) {
for i in 0..tensor_count {
let src_tensor_idx = i;
let dest_tensor_idx = (i + 1) % tensor_count;
let slice_idx = if is_phase_one {
(i + (tensor_count - 1) * cycle) % tensor_count
} else {
// in phase 2, the starting slice is different (see diagrams)
(i + 1 + (tensor_count - 1) * cycle) % tensor_count
};
let src_slice = sliced_tensors[src_tensor_idx].1.remove(slice_idx);
let mut dest_slice = sliced_tensors[dest_tensor_idx].1.remove(slice_idx);
let dest_device = B::float_device(&dest_slice);
let src_slice_on_dest = B::float_to_device(src_slice.clone(), &dest_device);
if is_phase_one {
dest_slice = B::float_add(dest_slice, src_slice_on_dest);
} else {
let slices: Vec<Slice> = dest_slice
.shape()
.iter()
.map(|&d| Slice::new(0, Some(d as isize), 1))
.collect();
// in phase 2, we don't sum the two slices, we replace with the new one.
dest_slice =
B::float_slice_assign(dest_slice, slices.as_slice(), src_slice_on_dest);
}
sliced_tensors[src_tensor_idx]
.1
.insert(slice_idx, src_slice);
sliced_tensors[dest_tensor_idx]
.1
.insert(slice_idx, dest_slice);
}
}
}
/// Slice a list of tensors the same way, evenly across a given dimension.
/// The given `shape` should be the same for every tensor.
fn slice_tensors<B: Backend>(
mut tensors: HashMap<PeerId, B::FloatTensorPrimitive>,
shape: Shape,
slice_dim: usize,
) -> Vec<(PeerId, Vec<<B as Backend>::FloatTensorPrimitive>)> {
// Get slice index ranges
let ranges = get_ring_reduce_slice_ranges(shape[slice_dim], tensors.len());
// Slice tensors
let mut sliced_tensors = vec![];
for (id, tensor) in tensors.drain() {
let mut slices = vec![];
for range in &ranges {
let full_range = shape
.iter()
.enumerate()
.map(|(dim_idx, dim)| {
if dim_idx == slice_dim {
Slice::from(range.clone())
} else {
Slice::from(0..*dim)
}
})
.collect::<Vec<_>>();
let slice = B::float_slice(tensor.clone(), &full_range);
slices.push(slice);
}
sliced_tensors.push((id, slices));
}
sliced_tensors
}
/// Get the index ranges for the slices to split a tensor evently across a given axis.
///
/// * `slice_dim_size` - The size of the dim to slice on
/// * `slice_count` - The number of slices
///
/// Returns a vector of index ranges for each slice.
pub(crate) fn get_ring_reduce_slice_ranges(
slice_dim_size: usize,
slice_count: usize,
) -> Vec<Range<usize>> {
let mut ranges: Vec<Range<usize>> = vec![];
let slice_size = slice_dim_size.div_ceil(slice_count);
for i in 0..slice_count {
let start = i * slice_size;
let end = start + slice_size;
ranges.push(Range { start, end });
}
ranges.last_mut().unwrap().end = slice_dim_size;
ranges
}

View File

@@ -0,0 +1,89 @@
use crate::PeerId;
use crate::local::tensor_map::CollectiveTensorMap;
use burn_tensor::backend::{Backend, DeviceOps};
use std::collections::HashMap;
/// Performs an all-reduce on the provided tensors in a b-tree structure with `arity`.
/// Similar to [reduce_sum_tree](reduce_sum_tree), but this function broadcasts the result with
/// the same tree algorithm.
/// The returned tensors are on the same devices as the corresponding inputs
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors))
)]
pub(crate) fn all_reduce_sum_tree<B: Backend>(
tensors: CollectiveTensorMap<B>,
arity: u32,
) -> CollectiveTensorMap<B> {
let mut input = tensors.into_iter().collect::<Vec<_>>();
// Sort to put devices of the same type together
input.sort_by(|a, b| {
let dev_a = B::float_device(&a.1);
let dev_b = B::float_device(&b.1);
dev_a.id().cmp(&dev_b.id())
});
// Recursive all-reduce
let out = all_reduce_sum_tree_inner::<B>(input, arity);
let mut tensors = HashMap::new();
for (id, tensor) in out {
tensors.insert(id, tensor);
}
tensors
}
/// Recursive function that sums `tensors` and redistributes the result to the host devices
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors))
)]
fn all_reduce_sum_tree_inner<B: Backend>(
mut tensors: Vec<(PeerId, B::FloatTensorPrimitive)>,
arity: u32,
) -> Vec<(PeerId, B::FloatTensorPrimitive)> {
let mut parent_tensors = vec![];
let mut children_groups = vec![];
// Phase 1: Sum tensors in groups of `arity` + 1
while !tensors.is_empty() {
// Maps ids to devices for each child of this parent
let mut children = vec![];
let (parent, mut parent_tensor) = tensors.remove(0);
let parent_device = B::float_device(&parent_tensor);
for _ in 0..arity {
if tensors.is_empty() {
break;
}
let (child, mut child_tensor) = tensors.remove(0);
let child_device = B::float_device(&child_tensor);
children.push((child, child_device));
child_tensor = B::float_to_device(child_tensor, &parent_device);
parent_tensor = B::float_add(parent_tensor, child_tensor);
}
parent_tensors.push((parent, parent_tensor));
children_groups.push(children);
}
if parent_tensors.len() > 1 {
// Parents are not yet at the root, do the upper part of the tree
parent_tensors = all_reduce_sum_tree_inner::<B>(parent_tensors, arity);
}
// Phase 2: Redistribute result from each parent to the respective devices
for (parent, parent_tensor) in parent_tensors {
let children = children_groups.remove(0);
for (child, child_device) in children {
// replace child tensors with result
tensors.push((
child,
B::float_to_device(parent_tensor.clone(), &child_device),
));
}
tensors.push((parent, parent_tensor));
}
tensors
}

View File

@@ -0,0 +1,29 @@
use std::collections::HashMap;
use crate::PeerId;
use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};
use burn_tensor::backend::Backend;
/// Broadcasts the tensor from one device in a map to all the others
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(devices, tensor))
)]
pub(crate) fn broadcast_centralized<B: Backend>(
mut devices: PeerDeviceMap<B>,
central: PeerId,
tensor: B::FloatTensorPrimitive,
) -> CollectiveTensorMap<B> {
let mut output = HashMap::new();
devices
.remove(&central)
.expect("Central device id is in `devices`");
for (dest, dest_device) in devices {
let tensor = B::float_to_device(tensor.clone(), &dest_device);
output.insert(dest, tensor);
}
output.insert(central, tensor);
output
}

View File

@@ -0,0 +1,7 @@
mod centralized;
mod op;
mod tree;
pub(crate) use centralized::*;
pub(crate) use op::*;
pub(crate) use tree::*;

View File

@@ -0,0 +1,172 @@
use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};
use crate::{
BroadcastStrategy, CollectiveConfig, CollectiveError, PeerId,
local::{broadcast_centralized, broadcast_tree},
node::base::Node,
};
use burn_communication::Protocol;
#[allow(unused_imports)] // TensorMetadata is used by tracing::instrument.
use burn_tensor::TensorMetadata;
use burn_tensor::backend::Backend;
use std::sync::mpsc::SyncSender;
/// An on-going broadcast operation
pub struct BroadcastOp<B: Backend> {
/// broadcast calls, one for each calling device
calls: Vec<BroadcastOpCall<B>>,
/// The tensor to broadcast, as defined by the root. Should be defined before all
/// peers call the operation.
tensor: Option<B::FloatTensorPrimitive>,
/// ID of the root (or use the first call's peer).
root: Option<PeerId>,
}
/// Struct for each device that calls an broadcast operation
pub struct BroadcastOpCall<B: Backend> {
/// Id of the caller of the operation
caller: PeerId,
/// Device of the calling peer
device: B::Device,
/// Callback for the result of the broadcast
result_sender: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
}
/// Type sent to the collective client upon completion of a broadcast op
pub(crate) type BroadcastResult<T> = Result<T, CollectiveError>;
impl<B: Backend> BroadcastOp<B> {
pub fn new() -> Self {
Self {
calls: vec![],
tensor: None,
root: None,
}
}
/// Get the effective root of the broadcast operation.
/// If the root is set, return it. Otherwise, return the first caller's peer.
pub fn effective_root(&self) -> PeerId {
self.root.unwrap_or(self.calls.first().unwrap().caller)
}
pub fn peers(&self) -> Vec<PeerId> {
self.calls.iter().map(|c| c.caller).collect()
}
fn peer_devices(&self) -> PeerDeviceMap<B> {
self.calls
.iter()
.map(|call| (call.caller, call.device.clone()))
.collect()
}
/// Register a call to reduce in this operation.
/// When the last caller registers a reduce, the operation is executed.
pub fn register_call(
&mut self,
caller: PeerId,
input: Option<B::FloatTensorPrimitive>,
result_sender: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
device: B::Device,
peer_count: usize,
) -> Result<bool, CollectiveError> {
if input.is_some() {
if self.tensor.is_some() {
return Err(CollectiveError::BroadcastMultipleTensors);
}
self.tensor = input;
}
self.calls.push(BroadcastOpCall {
caller,
device,
result_sender,
});
Ok(self.calls.len() == peer_count)
}
/// Runs the broadcast if the operation is ready. Otherwise, do nothing
#[cfg_attr(feature = "tracing", tracing::instrument(
level="trace",
skip(self, config, global_client),
fields(
self.peers = ?self.peers(),
self.shape = ?self.tensor.as_ref().map(|t| t.shape()),
self.dtype = ?self.tensor.as_ref().map(|t| t.dtype()),
)
))]
pub async fn execute<P: Protocol>(
mut self,
config: &CollectiveConfig,
global_client: &mut Option<Node<B, P>>,
) {
// all registered callers have sent a tensor to aggregate
match self.broadcast(config, global_client).await {
Ok(mut tensors) => {
// Return resulting tensors
self.calls.iter().for_each(|call| {
let result = tensors
.remove(&call.caller)
.expect("tensor/peer internal mismatch.");
call.result_sender.send(Ok(result)).unwrap();
});
assert_eq!(tensors.len(), 0, "tensor/peer internal mismatch.");
}
Err(err) => {
// Send error to all subscribers
self.fail(err);
}
}
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(self, config, global_client))
)]
async fn broadcast<P: Protocol>(
&mut self,
config: &CollectiveConfig,
global_client: &mut Option<Node<B, P>>,
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
// Do broadcast on global level with the main tensor
if let Some(global_client) = &global_client {
let strategy = config
.global_broadcast_strategy
.expect("global_broadcast_strategy not defined");
self.tensor = Some(
global_client
.broadcast(self.tensor.clone(), strategy)
.await
.map_err(CollectiveError::Global)?,
)
}
// At this point tensor must be defined
let Some(tensor) = self.tensor.take() else {
return Err(CollectiveError::BroadcastNoTensor);
};
let root = self.effective_root();
let peer_devices = self.peer_devices();
// Broadcast locally
Ok(match config.local_broadcast_strategy {
BroadcastStrategy::Tree(arity) => {
broadcast_tree::<B>(peer_devices, root, tensor, arity)
}
BroadcastStrategy::Centralized => {
broadcast_centralized::<B>(peer_devices, root, tensor)
}
})
}
/// Send a collective error as result to operation caller
pub fn fail(self, err: CollectiveError) {
self.calls.iter().for_each(|call| {
call.result_sender.send(Err(err.clone())).unwrap();
});
}
}

View File

@@ -0,0 +1,98 @@
use burn_tensor::backend::{Backend, DeviceOps};
use std::collections::HashMap;
use crate::PeerId;
use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};
/// Performs a broadcast on the provided tensors in a b-tree structure with `arity`.
///
/// Tensor must be on the device in the `devices` map corresponding to the `root` key.
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(devices, tensor))
)]
pub(crate) fn broadcast_tree<B: Backend>(
mut devices: PeerDeviceMap<B>,
root: PeerId,
tensor: B::FloatTensorPrimitive,
arity: u32,
) -> CollectiveTensorMap<B> {
// Convert hash map to vector of key-value pairs because order matters
let mut devices_vec = vec![];
let root_device = devices.remove(&root).unwrap();
for (id, tensor) in devices.drain() {
devices_vec.push((id, tensor));
}
// Sort to put devices of the same type together
devices_vec.sort_by(|a, b| {
let dev_a = &a.1;
let dev_b = &b.1;
dev_a.id().cmp(&dev_b.id())
});
// put the root first
devices_vec.insert(0, (root, root_device));
// Recursive broadcast
let out = broadcast_tree_inner::<B>(tensor, devices_vec, arity);
// put results in a hash map
let mut tensors = HashMap::new();
for (id, tensor) in out {
tensors.insert(id, tensor);
}
tensors
}
/// Recursive function that broadcasts tensor across the other devices. Tensor should be on the
/// first device of the list
///
/// Broadcasts the tensor across the devices in the tree in a pre-order traversal.
fn broadcast_tree_inner<B: Backend>(
tensor: B::FloatTensorPrimitive,
mut all_devices: Vec<(PeerId, B::Device)>,
arity: u32,
) -> Vec<(PeerId, B::FloatTensorPrimitive)> {
let mut parents = vec![];
let mut children_groups = vec![];
// Put devices in groups of `arity` + the parent
while !all_devices.is_empty() {
let mut children = vec![];
let parent = all_devices.remove(0);
for _ in 0..arity {
if all_devices.is_empty() {
break;
}
children.push(all_devices.remove(0));
}
parents.push(parent);
children_groups.push(children);
}
let mut parents = if parents.len() > 1 {
broadcast_tree_inner::<B>(tensor, parents, arity)
} else {
let root = parents.first().unwrap();
// `tensor` should already be on the root's device, no need to call B::float_to_device
vec![(root.0, tensor)]
};
// Redistribute result from each parent to the respective devices
let mut tensors = vec![];
for children in children_groups {
let parent = parents.remove(0);
for (child_id, child_device) in children {
// replace child's tensor with parent's
let child_tensor = B::float_to_device(parent.1.clone(), &child_device);
tensors.push((child_id, child_tensor));
}
tensors.push(parent);
}
tensors
}

View File

@@ -0,0 +1,271 @@
use crate::local::all_reduce::AllReduceResult;
use crate::{
CollectiveConfig, CollectiveError, PeerId, ReduceOperation,
local::{
BroadcastResult, ReduceResult,
server::{FinishResult, Message, RegisterResult},
},
};
use burn_tensor::backend::Backend;
use std::sync::mpsc::{Receiver, SyncSender};
/// Local client to communicate with the local server. Each thread has a client.
#[derive(Clone)]
pub(crate) struct LocalCollectiveClient<B: Backend> {
pub channel: SyncSender<Message<B>>,
}
/// A pending operation that can be waited on.
pub(crate) struct PendingCollectiveOperation<T> {
rx: Receiver<Result<T, CollectiveError>>,
}
impl<T> From<PendingCollectiveOperation<T>> for Receiver<Result<T, CollectiveError>> {
fn from(value: PendingCollectiveOperation<T>) -> Self {
value.rx
}
}
impl<T> PendingCollectiveOperation<T> {
/// Wait on the operation.
///
/// Given a `Receiver<Result<T, CollectiveError>>`, this function will wait:
/// - Unwraps `Ok(Result<T, CollectiveError>)` into `Result<T, CollectiveError>`;
/// - maps `Err(RecvError)` to `Err(CollectiveError::LocalServerMissing)`.
pub(crate) fn wait(self) -> Result<T, CollectiveError> {
let tensor = self
.rx
.recv()
.unwrap_or(Err(CollectiveError::LocalServerMissing))?;
Ok(tensor)
}
}
impl<B: Backend> LocalCollectiveClient<B> {
/// Common logic for starting a collective operation.
///
/// - Allocates `(callback, recv)` channels,
/// - Passes the `callback` to the `Message<B>` builder,
/// - Sends the message through the collective channel,
/// - Returns the `recv`.
pub(crate) fn start_operation<T, F>(&self, builder: F) -> PendingCollectiveOperation<T>
where
F: FnOnce(SyncSender<Result<T, CollectiveError>>) -> Message<B>,
{
let (tx, rx) = std::sync::mpsc::sync_channel(1);
self.channel.send((builder)(tx)).unwrap();
PendingCollectiveOperation { rx }
}
/// Common logic for starting a collective operation, with validation.
///
/// When `valid` is `Err`, this function returns a `Receiver<Result<T, CollectiveError>>` that
/// immediately returns `Err(valid)`;
/// otherwise, it behaves like [`LocalCollectiveClient::start_operation`].
pub(crate) fn start_valid_operation<T, F>(
&self,
valid: Result<(), CollectiveError>,
builder: F,
) -> PendingCollectiveOperation<T>
where
F: FnOnce(SyncSender<Result<T, CollectiveError>>) -> Message<B>,
{
match valid {
Err(e) => {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
tx.send(Err(e)).unwrap();
PendingCollectiveOperation { rx }
}
_ => self.start_operation(builder),
}
}
pub(crate) fn reset(&self) {
self.channel.send(Message::Reset).unwrap();
}
pub(crate) fn register(
&mut self,
id: PeerId,
device: B::Device,
config: CollectiveConfig,
) -> RegisterResult {
self.register_start(id, device, config).wait()
}
pub(crate) fn register_start(
&mut self,
id: PeerId,
device: B::Device,
config: CollectiveConfig,
) -> PendingCollectiveOperation<()> {
self.start_valid_operation(
match config.is_valid() {
true => Ok(()),
false => Err(CollectiveError::InvalidConfig),
},
|callback| Message::Register {
device_id: id,
device,
config,
callback,
},
)
}
/// Calls for an all-reduce operation with the given parameters and returns the result.
/// The `params` must be the same as the parameters passed by the other nodes.
///
/// # Arguments
/// * `id` - The peer id of the caller
/// * `tensor` - The input tensor to reduce with the peers' tensors
/// * `config` - Config of the collective operation. Must be coherent with the other calls.
///
/// # Result
/// - `Ok(tensor)` if the operation was successful
/// - `Err(CollectiveError)` on error.
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(self, tensor))
)]
pub fn all_reduce(
&self,
id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
) -> AllReduceResult<B::FloatTensorPrimitive> {
self.all_reduce_start(id, tensor, op).wait()
}
/// Starts an all-reduce operation with the given parameters.
///
/// The `params` must be the same as the parameters passed by the other nodes.
///
/// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].
///
/// # Arguments
/// * `id` - The peer id of the caller
/// * `tensor` - The input tensor to reduce with the peers' tensors
/// * `config` - Config of the collective operation. Must be coherent with the other calls.
///
/// # Result
///
/// A `Receiver<>` that will yield:
/// - `Ok(AllReduceResult<B::FloatTensorPrimitive>)` if the operation was successful
/// - `Err(SendError)` if the channel was dropped.
pub(crate) fn all_reduce_start(
&self,
id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
) -> PendingCollectiveOperation<B::FloatTensorPrimitive> {
self.start_operation(|callback| Message::AllReduce {
device_id: id,
tensor,
op,
callback,
})
}
/// Reduces a tensor onto one device.
///
/// # Arguments
/// - `id` - The peer id of the caller.
/// - `tensor` - The tensor to send as input.
/// - `op` - The reduce operation to apply.
/// - `root` - The ID of the peer that will receive the result.
///
/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.
pub fn reduce(
&self,
id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
root: PeerId,
) -> ReduceResult<B::FloatTensorPrimitive> {
self.reduce_start(id, tensor, op, root).wait()
}
/// Starts a reduce operation on a tensor onto one device.
///
/// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].
///
/// # Arguments
/// - `id` - The peer id of the caller.
/// - `tensor` - The tensor to send as input.
/// - `op` - The reduce operation to apply.
/// - `root` - The ID of the peer that will receive the result.
///
/// # Result
///
/// A `Receiver<>` that will yield:
/// - `Ok(ReduceResult<B::FloatTensorPrimitive>)` if the operation was successful
/// - `Err(SendError)` if the channel was dropped.
pub(crate) fn reduce_start(
&self,
id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
root: PeerId,
) -> PendingCollectiveOperation<Option<B::FloatTensorPrimitive>> {
self.start_operation(|callback| Message::Reduce {
device_id: id,
tensor,
op,
root,
callback,
})
}
/// Broadcasts, or receives a broadcasted tensor.
///
/// # Arguments
/// - `id` - The peer id of the caller
/// - `tensor` - If defined, this tensor will be broadcasted.
/// Otherwise, this call will receive the broadcasted tensor.
///
/// # Result
/// Synchronously waits on the broadcasted tensor.
pub fn broadcast(
&self,
id: PeerId,
tensor: Option<B::FloatTensorPrimitive>,
) -> BroadcastResult<B::FloatTensorPrimitive> {
self.broadcast_start(id, tensor).wait()
}
/// Starts a Broadcast, or receives a broadcasted tensor.
///
/// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].
///
/// # Arguments
/// - `id` - The peer id of the caller
/// - `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive
/// the broadcasted tensor.
///
/// # Result
///
/// A `Receiver<>` that will yield:
/// - `Ok(BroadcastResult<B::FloatTensorPrimitive>)` if the operation was successful
/// - `Err(SendError)` if the channel was dropped.
pub(crate) fn broadcast_start(
&self,
id: PeerId,
tensor: Option<B::FloatTensorPrimitive>,
) -> PendingCollectiveOperation<B::FloatTensorPrimitive> {
self.start_operation(|callback| Message::Broadcast {
device_id: id,
tensor,
callback,
})
}
pub(crate) fn finish(&self, id: PeerId) -> FinishResult {
self.finish_start(id).wait()
}
pub(crate) fn finish_start(&self, id: PeerId) -> PendingCollectiveOperation<()> {
self.start_operation(|callback| Message::Finish { id, callback })
}
}

View File

@@ -0,0 +1,12 @@
mod all_reduce;
mod broadcast;
mod reduce;
pub(crate) mod tensor_map;
pub(crate) use all_reduce::*;
pub(crate) use broadcast::*;
pub(crate) use reduce::*;
pub(crate) mod client;
pub(crate) mod server;

View File

@@ -0,0 +1,30 @@
use burn_tensor::backend::Backend;
use crate::PeerId;
use crate::local::tensor_map::CollectiveTensorMap;
#[cfg(feature = "tracing")]
use crate::local::tensor_map::get_common_shape;
/// Sums the tensors on one device and returns the result
#[cfg_attr(feature = "tracing", tracing::instrument(
level="trace",
skip(tensors),
fields(shape = ?get_common_shape::<B>(&tensors).unwrap())
))]
pub(crate) fn reduce_sum_centralized<B: Backend>(
mut tensors: CollectiveTensorMap<B>,
central: &PeerId,
) -> B::FloatTensorPrimitive {
let mut central_tensor = tensors
.remove(central)
.expect("Source device id is in the map");
let central_device = B::float_device(&central_tensor);
for (_, tensor) in tensors {
let rhs = B::float_to_device(tensor.clone(), &central_device);
central_tensor = B::float_add(central_tensor, rhs);
}
central_tensor
}

View File

@@ -0,0 +1,7 @@
mod centralized;
mod op;
mod tree;
pub(crate) use centralized::*;
pub(crate) use op::*;
pub(crate) use tree::*;

View File

@@ -0,0 +1,163 @@
use burn_communication::Protocol;
use burn_tensor::{Shape, TensorMetadata, backend::Backend};
use std::sync::mpsc::SyncSender;
use crate::{
CollectiveConfig, CollectiveError, PeerId, ReduceOperation, ReduceStrategy,
local::{reduce_sum_centralized, reduce_sum_tree},
node::base::Node,
};
/// An on-going reduce operation
pub struct ReduceOp<B: Backend> {
/// reduce calls, one for each calling device
calls: Vec<ReduceOpCall<B>>,
/// The reduce operation, as defined by the first caller
op: ReduceOperation,
/// The peer that receives the reduce result, as defined by the first caller
root: PeerId,
/// The shape of the tensor to reduce, as defined by the first caller
shape: Shape,
}
/// Struct for each device that calls an reduce operation
pub struct ReduceOpCall<B: Backend> {
/// Id of the caller of the operation
caller: PeerId,
/// The tensor primitive passed as input
input: B::FloatTensorPrimitive,
/// Callback for the result of the reduce
result_sender: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
}
/// Type sent to the collective client upon completion of a reduce aggregation
pub(crate) type ReduceResult<T> = Result<Option<T>, CollectiveError>;
impl<B: Backend> ReduceOp<B> {
pub fn new(shape: Shape, reduce_op: ReduceOperation, root: PeerId) -> Self {
Self {
calls: vec![],
op: reduce_op,
root,
shape,
}
}
fn peers(&self) -> Vec<PeerId> {
self.calls.iter().map(|c| c.caller).collect()
}
/// Register a call to reduce in this operation.
/// When the last caller registers a reduce, the operation is executed.
pub fn register_call(
&mut self,
caller: PeerId,
input: B::FloatTensorPrimitive,
result_sender: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
op: ReduceOperation,
root: PeerId,
peer_count: usize,
) -> Result<bool, CollectiveError> {
if self.shape != input.shape() {
return Err(CollectiveError::ReduceShapeMismatch);
}
if self.op != op {
return Err(CollectiveError::ReduceOperationMismatch);
}
if self.root != root {
return Err(CollectiveError::ReduceRootMismatch);
}
self.calls.push(ReduceOpCall {
caller,
input,
result_sender,
});
Ok(self.calls.len() == peer_count)
}
/// Runs the all-reduce if the operation is ready. Otherwise, do nothing
#[cfg_attr(feature = "tracing", tracing::instrument(
level="trace",
skip(self, config, global_client),
fields(
?self.op,
?self.shape,
self.peers = ?self.peers(),
)
))]
pub async fn execute<P: Protocol>(
mut self,
root: PeerId,
config: &CollectiveConfig,
global_client: &mut Option<Node<B, P>>,
) {
match self.reduce(config, global_client).await {
Ok(mut result) => {
// Return resulting tensor to root, None to others
self.calls.iter().for_each(|op| {
let msg = if op.caller == root {
Ok(result.take())
} else {
Ok(None)
};
op.result_sender.send(msg).unwrap();
});
}
Err(err) => {
self.fail(err);
}
}
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(self, config, global_client))
)]
async fn reduce<P: Protocol>(
&mut self,
config: &CollectiveConfig,
global_client: &mut Option<Node<B, P>>,
) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {
let tensors = self
.calls
.iter()
.map(|call| (call.caller, call.input.clone()))
.collect();
// For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later
let mut local_sum = match config.local_reduce_strategy {
ReduceStrategy::Centralized => reduce_sum_centralized::<B>(tensors, &self.root),
ReduceStrategy::Tree(arity) => reduce_sum_tree::<B>(tensors, &self.root, arity),
};
// Do aggregation on a global level with the main tensor
let result = if let Some(global_client) = global_client {
let strategy = config
.global_reduce_strategy
.expect("global_reduce_strategy not defined");
global_client
.reduce(local_sum, strategy, self.root, self.op)
.await
.map_err(CollectiveError::Global)?
} else {
// Mean division locally
if self.op == ReduceOperation::Mean {
let local_tensor_count = self.calls.len() as f32;
local_sum = B::float_div_scalar(local_sum, local_tensor_count.into())
}
Some(local_sum)
};
Ok(result)
}
/// Send a collective error as result to operation caller
pub fn fail(self, err: CollectiveError) {
self.calls.iter().for_each(|op| {
op.result_sender.send(Err(err.clone())).unwrap();
});
}
}

View File

@@ -0,0 +1,77 @@
use crate::PeerId;
use crate::local::tensor_map::CollectiveTensorMap;
use burn_tensor::backend::{Backend, DeviceOps};
/// Performs a reduce on the provided tensors in a b-tree structure with `arity`.
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors))
)]
pub(crate) fn reduce_sum_tree<B: Backend>(
mut tensors: CollectiveTensorMap<B>,
root: &PeerId,
arity: u32,
) -> B::FloatTensorPrimitive {
// Convert hash map to vector of key-value pairs because order matters
let mut input = vec![];
let root_tensor = tensors.remove(root).unwrap();
for (_, tensor) in tensors.drain() {
input.push(tensor);
}
// Sort to put devices of the same type together
input.sort_by(|a, b| {
let dev_a = B::float_device(a);
let dev_b = B::float_device(b);
dev_a.id().cmp(&dev_b.id())
});
// put the root first
input.insert(0, root_tensor);
reduce_sum_tree_inner::<B>(input, arity)
}
/// Recursive function that sums `tensors`
///
/// Traverses `tensors` and reduces in a post-order traversal. The first tensor in the list is
/// chosen as the root
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensors))
)]
fn reduce_sum_tree_inner<B: Backend>(
mut tensors: Vec<B::FloatTensorPrimitive>,
arity: u32,
) -> B::FloatTensorPrimitive {
let mut parents = vec![];
let mut children_groups = vec![];
// Sum tensors in groups of `arity` + 1
while !tensors.is_empty() {
let mut children = vec![];
let mut parent_tensor = tensors.remove(0);
let parent_device = B::float_device(&parent_tensor);
for _ in 0..arity {
if tensors.is_empty() {
break;
}
let child_tensor = tensors.remove(0);
children.push(B::float_device(&child_tensor));
let rhs = B::float_to_device(child_tensor, &parent_device);
parent_tensor = B::float_add(parent_tensor, rhs);
}
parents.push(parent_tensor);
children_groups.push(children);
}
if parents.len() > 1 {
// Parents are not yet at the root, do the upper part of the tree
reduce_sum_tree_inner::<B>(parents, arity)
} else {
// Root of tree
parents.remove(0)
}
}

View File

@@ -0,0 +1,495 @@
use crate::{
CollectiveConfig, CollectiveError, PeerId, ReduceOperation,
global::node::base::Node,
local::{
AllReduceOp, AllReduceResult, BroadcastOp, BroadcastResult, ReduceOp, ReduceResult,
client::LocalCollectiveClient,
},
};
use burn_communication::websocket::{WebSocket, WsServer};
use burn_tensor::{TensorMetadata, backend::Backend};
use std::sync::{MutexGuard, OnceLock};
use std::{
any::{Any, TypeId},
collections::HashMap,
fmt::Debug,
sync::{
Arc, Mutex,
mpsc::{Receiver, SyncSender},
},
};
use tokio::runtime::{Builder, Runtime};
/// Define the client/server communication on the network
type Network = WebSocket;
/// Type sent to the collective client upon completion of a register request
pub(crate) type RegisterResult = Result<(), CollectiveError>;
/// Type sent to the collective client upon completion of a finish request
pub(crate) type FinishResult = Result<(), CollectiveError>;
/// The local collective server that manages all the collective aggregation operations
/// (like all-reduce) between local threads.
/// This thread takes in messages from different clients. The clients must register, than they can
/// send an aggregate message. They must all use the same parameters for the same aggregate
/// operation.
pub(crate) struct LocalCollectiveServer<B: Backend> {
/// Channel receiver for messages from clients
message_rec: Receiver<Message<B>>,
/// The collective configuration. Must be the same by every peer when calling register
config: Option<CollectiveConfig>,
/// The ids passed to each register so far
peers: Vec<PeerId>,
/// Callbacks for when all registers are done
callbacks_register: Vec<SyncSender<RegisterResult>>,
/// Map of each peer's id and its device
devices: HashMap<PeerId, B::Device>,
/// Current uncompleted all-reduce operation
all_reduce_op: Option<AllReduceOp<B>>,
/// Current uncompleted reduce call
reduce_op: Option<ReduceOp<B>>,
/// Uncompleted broadcast calls, one for each calling device.
broadcast_op: Option<BroadcastOp<B>>,
/// Client for global collective operations
global_client: Option<Node<B, Network>>,
}
#[derive(Debug)]
pub(crate) enum Message<B: Backend> {
/// Register a new peer with the collective.
Register {
device_id: PeerId,
device: B::Device,
config: CollectiveConfig,
callback: SyncSender<RegisterResult>,
},
/// Perform an all-reduce operation.
AllReduce {
device_id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
callback: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
},
/// Perform a reduce operation.
Reduce {
device_id: PeerId,
tensor: B::FloatTensorPrimitive,
op: ReduceOperation,
root: PeerId,
callback: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
},
/// Perform a broadcast operation (one-sender, many-receiver).
Broadcast {
device_id: PeerId,
tensor: Option<B::FloatTensorPrimitive>,
callback: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
},
/// Reset the collective server.
Reset,
Finish {
id: PeerId,
callback: SyncSender<FinishResult>,
},
}
/// The type-erased box type for [`LocalCollectiveClient<B>`].
type LocalClientBox = Box<dyn Any + Send + Sync>;
/// Global state map from [`Backend`] to boxed [`LocalCollectiveClient<B>`].
static BACKEND_CLIENT_MAP: OnceLock<Mutex<HashMap<TypeId, LocalClientBox>>> = OnceLock::new();
/// Gets a locked mutable view of the `STATE_MAP`.
pub(crate) fn get_backend_client_map() -> MutexGuard<'static, HashMap<TypeId, LocalClientBox>> {
BACKEND_CLIENT_MAP
.get_or_init(Default::default)
.lock()
.unwrap()
}
/// Get a [`LocalCollectiveClient`] for the given [`Backend`].
///
/// Will start the local collective client/server pair if necessary.
pub(crate) fn get_collective_client<B: Backend>() -> LocalCollectiveClient<B> {
let typeid = TypeId::of::<B>();
let mut state_map = get_backend_client_map();
match state_map.get(&typeid) {
Some(val) => val.downcast_ref().cloned().unwrap(),
None => {
let client = LocalCollectiveServer::<B>::setup(LocalCollectiveClientConfig::default());
state_map.insert(typeid, Box::new(client.clone()));
client
}
}
}
/// Global runtime.
static SERVER_RUNTIME: OnceLock<Arc<Runtime>> = OnceLock::new();
/// Get the global [`Runtime`].
pub(crate) fn get_collective_server_runtime() -> Arc<Runtime> {
SERVER_RUNTIME
.get_or_init(|| {
Builder::new_multi_thread()
.enable_all()
.build()
.expect("Unable to initialize runtime")
.into()
})
.clone()
}
/// Configuration for the local collective client/server pair.
pub struct LocalCollectiveClientConfig {
/// Channel capacity for the messaging queue from client to server.
pub channel_capacity: usize,
}
impl Default for LocalCollectiveClientConfig {
fn default() -> Self {
Self {
channel_capacity: 50,
}
}
}
impl From<usize> for LocalCollectiveClientConfig {
fn from(capacity: usize) -> Self {
Self {
channel_capacity: capacity,
}
}
}
impl<B: Backend> LocalCollectiveServer<B> {
fn new(rec: Receiver<Message<B>>) -> Self {
Self {
message_rec: rec,
config: None,
peers: vec![],
devices: HashMap::new(),
all_reduce_op: None,
reduce_op: None,
broadcast_op: None,
callbacks_register: vec![],
global_client: None,
}
}
/// Setup a client/server pair with the given config.
pub(crate) fn setup<C>(cfg: C) -> LocalCollectiveClient<B>
where
C: Into<LocalCollectiveClientConfig>,
{
let cfg = cfg.into();
let (tx, rx) = std::sync::mpsc::sync_channel(cfg.channel_capacity);
get_collective_server_runtime().spawn(async {
let typeid = TypeId::of::<B>();
log::info!("Starting server for backend: {typeid:?}");
let mut server = LocalCollectiveServer::new(rx);
loop {
match server.message_rec.recv() {
Ok(message) => server.process_message(message).await,
Err(err) => {
log::error!(
"Error receiving message from local collective server: {err:?}"
);
break;
}
}
}
});
LocalCollectiveClient { channel: tx }
}
async fn process_message(&mut self, message: Message<B>) {
match message {
Message::Register {
device_id,
device,
config,
callback,
} => {
self.process_register_message(device_id, device, config, &callback)
.await
}
Message::AllReduce {
device_id,
tensor,
op,
callback,
} => {
self.process_all_reduce_message(device_id, tensor, op, callback)
.await
}
Message::Reduce {
device_id,
tensor,
op,
root,
callback,
} => {
self.process_reduce_message(device_id, tensor, op, root, callback)
.await
}
Message::Broadcast {
device_id,
tensor,
callback,
} => {
self.process_broadcast_message(device_id, tensor, callback)
.await
}
Message::Reset => self.reset(),
Message::Finish { id, callback } => self.process_finish_message(id, callback).await,
}
}
async fn process_register_message(
&mut self,
device_id: PeerId,
device: B::Device,
config: CollectiveConfig,
callback: &SyncSender<RegisterResult>,
) {
if !config.is_valid() {
callback.send(Err(CollectiveError::InvalidConfig)).unwrap();
return;
}
if self.peers.contains(&device_id) {
callback
.send(Err(CollectiveError::MultipleRegister))
.unwrap();
return;
}
if self.peers.is_empty() || self.config.is_none() {
self.config = Some(config);
} else if let Some(cfg) = &self.config
&& *cfg != config
{
callback
.send(Err(CollectiveError::RegisterParamsMismatch))
.unwrap();
return;
}
self.peers.push(device_id);
self.callbacks_register.push(callback.clone());
self.devices.insert(device_id, device);
let config = self.config.as_ref().unwrap();
let global_params = config.global_register_params();
if let Some(global_params) = &global_params
&& self.global_client.is_none()
{
let server = WsServer::new(global_params.data_service_port);
let client = Node::new(&global_params.global_address, server);
self.global_client = Some(client)
}
// All have registered, callback
if self.peers.len() == config.num_devices {
let mut register_result = Ok(());
// if an error occurs on the global register, it must be passed back to every local peer
if let Some(global_params) = global_params {
let client = self
.global_client
.as_mut()
.expect("Global client should be initialized");
register_result = client
.register(self.peers.clone(), global_params)
.await
.map_err(CollectiveError::Global);
};
// Send results to all callbacks.
self.callbacks_register
.drain(..)
.for_each(|tx| tx.send(register_result.clone()).unwrap());
}
}
/// Processes an Message::AllReduce.
async fn process_all_reduce_message(
&mut self,
peer_id: PeerId,
tensor: <B as Backend>::FloatTensorPrimitive,
op: ReduceOperation,
callback: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
) {
if !self.peers.contains(&peer_id) {
callback
.send(Err(CollectiveError::RegisterNotFirstOperation))
.unwrap();
return;
}
if self.all_reduce_op.is_none() {
// First call to all-reduce
self.all_reduce_op = Some(AllReduceOp::new(tensor.shape(), op));
}
// Take the operation, we'll put it back if we're not done
let mut all_reduce_op = self.all_reduce_op.take().unwrap();
// On the last caller, the all-reduce is done here
let res =
all_reduce_op.register_call(peer_id, tensor, callback.clone(), op, self.peers.len());
// Upon an error or the last call, the all_reduce_op is dropped
match res {
Ok(is_ready) => {
if is_ready {
all_reduce_op
.execute(self.config.as_ref().unwrap(), &mut self.global_client)
.await;
} else {
// Put operation back, we're waiting for more calls
self.all_reduce_op = Some(all_reduce_op)
}
}
Err(err) => all_reduce_op.fail(err),
}
}
/// Processes a Message::Reduce.
async fn process_reduce_message(
&mut self,
peer_id: PeerId,
tensor: <B as Backend>::FloatTensorPrimitive,
op: ReduceOperation,
root: PeerId,
callback: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
) {
if !self.peers.contains(&root) {
callback
.send(Err(CollectiveError::RegisterNotFirstOperation))
.unwrap();
return;
}
if self.reduce_op.is_none() {
// First call to reduce
self.reduce_op = Some(ReduceOp::new(tensor.shape(), op, root));
}
let mut reduce_op = self.reduce_op.take().unwrap();
// On the last caller, the all-reduce is done here
let res = reduce_op.register_call(
peer_id,
tensor,
callback.clone(),
op,
root,
self.peers.len(),
);
// Upon an error or the last call, the all_reduce_op is dropped
match res {
Ok(is_ready) => {
if is_ready {
reduce_op
.execute(root, self.config.as_ref().unwrap(), &mut self.global_client)
.await;
} else {
// Put operation back, we're waiting for more calls
self.reduce_op = Some(reduce_op)
}
}
Err(err) => reduce_op.fail(err),
}
}
/// Processes a Message::Broadcast.
async fn process_broadcast_message(
&mut self,
caller: PeerId,
tensor: Option<<B as Backend>::FloatTensorPrimitive>,
callback: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
) {
if self.config.is_none() {
callback
.send(Err(CollectiveError::RegisterNotFirstOperation))
.unwrap();
return;
}
if !self.peers.contains(&caller) {
callback
.send(Err(CollectiveError::RegisterNotFirstOperation))
.unwrap();
return;
}
if self.broadcast_op.is_none() {
// First call to broadcast
self.broadcast_op = Some(BroadcastOp::new());
}
let device = self.devices.get(&caller).unwrap().clone();
let mut broadcast_op = self.broadcast_op.take().unwrap();
// On the last caller, the all-reduce is done here
let res =
broadcast_op.register_call(caller, tensor, callback.clone(), device, self.peers.len());
// Upon an error or the last call, the all_reduce_op is dropped
match res {
Ok(is_ready) => {
if is_ready {
broadcast_op
.execute(self.config.as_ref().unwrap(), &mut self.global_client)
.await;
} else {
// Put operation back, we're waiting for more calls
self.broadcast_op = Some(broadcast_op)
}
}
Err(err) => broadcast_op.fail(err),
}
}
/// Reinitializes the collective server
fn reset(&mut self) {
self.peers.clear();
self.all_reduce_op = None;
self.reduce_op = None;
self.broadcast_op = None;
}
/// Processes a Message::Finish.
async fn process_finish_message(&mut self, id: PeerId, callback: SyncSender<RegisterResult>) {
if self.config.is_none() {
callback
.send(Err(CollectiveError::RegisterNotFirstOperation))
.unwrap();
return;
}
if !self.peers.contains(&id) {
callback
.send(Err(CollectiveError::MultipleUnregister))
.unwrap();
return;
}
// Remove registered with id
self.peers.retain(|x| *x != id);
if self.peers.is_empty()
&& let Some(mut global_client) = self.global_client.take()
{
global_client.finish().await;
}
callback.send(Ok(())).unwrap();
}
}

View File

@@ -0,0 +1,33 @@
//! # Common Tensor Map for Local Collective Operations
use crate::PeerId;
use burn_std::Shape;
use burn_tensor::TensorMetadata;
use burn_tensor::backend::Backend;
use std::collections::HashMap;
pub type CollectiveTensorMap<B> = HashMap<PeerId, <B as Backend>::FloatTensorPrimitive>;
pub type PeerDeviceMap<B> = HashMap<PeerId, <B as Backend>::Device>;
/// Get the shape of the tensors. They should all have the same shape, otherwise None is returned.
pub fn get_common_shape<B: Backend>(tensors: &CollectiveTensorMap<B>) -> Option<Shape> {
let mut it = tensors.values();
if let Some(first) = it.next() {
let shape = first.shape();
for tensor in it {
if tensor.shape() != shape {
return None;
}
}
return Some(shape);
}
None
}
/// Get the `{ peer_id -> device }` mapping for the given tensors.
pub fn get_peer_devices<B: Backend>(tensors: &CollectiveTensorMap<B>) -> PeerDeviceMap<B> {
tensors
.iter()
.map(|(id, tensor)| (*id, B::float_device(tensor)))
.collect()
}

View File

@@ -0,0 +1,174 @@
mod tests {
use std::sync::mpsc::SyncSender;
use burn_std::rand::get_seeded_rng;
use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};
use serial_test::serial;
#[cfg(feature = "test-ndarray")]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(feature = "test-cuda")]
pub type TestBackend = burn_cuda::Cuda<f32>;
#[cfg(feature = "test-wgpu")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
#[cfg(feature = "test-metal")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
#[cfg(feature = "test-vulkan")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
use crate::{
AllReduceStrategy, CollectiveConfig, PeerId, ReduceOperation, all_reduce, register,
reset_collective,
};
pub fn run_peer<B: Backend>(
id: PeerId,
config: CollectiveConfig,
input: TensorData,
op: ReduceOperation,
output: SyncSender<Tensor<B, 1>>,
) {
let device = B::Device::default();
register::<B>(id, device.clone(), config).unwrap();
let tensor = Tensor::<B, 1>::from_data(input, &device);
let tensor = Tensor::from_primitive(TensorPrimitive::Float(
all_reduce::<B>(id, tensor.into_primitive().tensor(), op).unwrap(),
));
output.send(tensor).unwrap();
}
fn generate_random_input(
shape: Shape,
op: ReduceOperation,
thread_count: usize,
) -> (Vec<TensorData>, TensorData) {
let input: Vec<TensorData> = (0..thread_count)
.map(|_| {
TensorData::random::<f32, _, _>(
shape.clone(),
burn_tensor::Distribution::Default,
&mut get_seeded_rng(),
)
})
.collect();
let device = <TestBackend as Backend>::Device::default();
let mut expected_tensor = Tensor::<TestBackend, 1>::zeros(shape, &device);
for item in input.iter().take(thread_count as usize) {
let input_tensor = Tensor::<TestBackend, 1>::from_data(item.clone(), &device);
expected_tensor = expected_tensor.add(input_tensor);
}
if op == ReduceOperation::Mean {
expected_tensor = expected_tensor.div_scalar(thread_count as u32);
}
let expected = expected_tensor.to_data();
(input, expected)
}
fn test_all_reduce<B: Backend>(
device_count: usize,
op: ReduceOperation,
strategy: AllReduceStrategy,
tensor_size: usize,
) {
reset_collective::<TestBackend>();
let (send, recv) = std::sync::mpsc::sync_channel(32);
let shape = Shape {
dims: vec![tensor_size],
};
let (input, expected) = generate_random_input(shape, op, device_count);
let config = CollectiveConfig::default()
.with_num_devices(device_count)
.with_local_all_reduce_strategy(strategy);
for id in 0..device_count {
let send = send.clone();
let input = input[id as usize].clone();
std::thread::spawn({
let config = config.clone();
move || run_peer::<B>(id.into(), config, input, op, send)
});
}
let first = recv.recv().unwrap().to_data();
for _ in 1..device_count {
let tensor = recv.recv().unwrap();
tensor.to_data().assert_eq(&first, true);
}
let tol: Tolerance<f32> = Tolerance::balanced();
expected.assert_approx_eq(&first, tol);
}
#[test]
#[serial]
pub fn test_all_reduce_centralized_sum() {
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Centralized, 4);
}
#[test]
#[serial]
pub fn test_all_reduce_centralized_mean() {
test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Centralized, 4);
}
#[test]
#[serial]
pub fn test_all_reduce_binary_tree_sum() {
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Tree(2), 4);
}
#[test]
#[serial]
pub fn test_all_reduce_binary_tree_mean() {
test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Tree(2), 4);
}
#[test]
#[serial]
pub fn test_all_reduce_5_tree_sum() {
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Tree(5), 4);
}
#[test]
#[serial]
pub fn test_all_reduce_5_tree_mean() {
test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Tree(5), 4);
}
#[test]
#[serial]
pub fn test_all_reduce_ring_sum() {
test_all_reduce::<TestBackend>(3, ReduceOperation::Sum, AllReduceStrategy::Ring, 3);
}
#[test]
#[serial]
pub fn test_all_reduce_ring_mean() {
test_all_reduce::<TestBackend>(3, ReduceOperation::Mean, AllReduceStrategy::Ring, 3);
}
#[test]
#[serial]
pub fn test_all_reduce_ring_irregular_sum() {
// this should trigger the fallback algorithm when the tensor is too small.
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Ring, 3);
}
}

View File

@@ -0,0 +1,126 @@
mod tests {
use std::sync::mpsc::SyncSender;
use burn_std::rand::get_seeded_rng;
use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};
use serial_test::serial;
#[cfg(feature = "test-ndarray")]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(feature = "test-cuda")]
pub type TestBackend = burn_cuda::Cuda<f32>;
#[cfg(feature = "test-wgpu")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
#[cfg(feature = "test-metal")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
#[cfg(feature = "test-vulkan")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
use crate::{
BroadcastStrategy, CollectiveConfig, PeerId, broadcast, register, reset_collective,
};
pub fn run_peer<B: Backend>(
id: PeerId,
config: CollectiveConfig,
input: Option<TensorData>,
output: SyncSender<Tensor<B, 1>>,
) {
let device = B::Device::default();
register::<B>(id, device.clone(), config).unwrap();
let tensor = input.map(|data| B::float_from_data(data, &device));
let tensor = broadcast::<B>(id, tensor).unwrap();
let tensor = Tensor::<B, 1>::from_primitive(TensorPrimitive::Float(tensor));
output.send(tensor).unwrap();
}
fn generate_random_input(shape: Shape) -> TensorData {
TensorData::random::<f32, _, _>(
shape.clone(),
burn_tensor::Distribution::Default,
&mut get_seeded_rng(),
)
}
fn test_broadcast<B: Backend>(
device_count: usize,
strategy: BroadcastStrategy,
tensor_size: usize,
) {
reset_collective::<TestBackend>();
let (send, recv) = std::sync::mpsc::sync_channel(32);
let shape = Shape {
dims: vec![tensor_size],
};
let input = generate_random_input(shape);
let config = CollectiveConfig::default()
.with_num_devices(device_count)
.with_local_broadcast_strategy(strategy);
for id in 0..device_count {
// The peer #0 is the root: it sends the tensor
let input = if id == 0 { Some(input.clone()) } else { None };
std::thread::spawn({
let config = config.clone();
let send = send.clone();
move || run_peer::<B>(id.into(), config, input, send)
});
}
// Expect all peers to receive the input tensor
let tol: Tolerance<f32> = Tolerance::balanced();
for _ in 0..device_count {
let tensor = recv.recv().unwrap().to_data();
input.assert_approx_eq(&tensor, tol);
}
}
#[test]
#[serial]
pub fn test_broadcast_centralized_sum() {
test_broadcast::<TestBackend>(4, BroadcastStrategy::Centralized, 4);
}
#[test]
#[serial]
pub fn test_broadcast_centralized_mean() {
test_broadcast::<TestBackend>(4, BroadcastStrategy::Centralized, 4);
}
#[test]
#[serial]
pub fn test_broadcast_binary_tree_sum() {
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(2), 4);
}
#[test]
#[serial]
pub fn test_broadcast_binary_tree_mean() {
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(2), 4);
}
#[test]
#[serial]
pub fn test_broadcast_5_tree_sum() {
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(5), 4);
}
#[test]
#[serial]
pub fn test_broadcast_5_tree_mean() {
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(5), 4);
}
}

View File

@@ -0,0 +1,3 @@
mod all_reduce;
mod broadcast;
mod reduce;

View File

@@ -0,0 +1,162 @@
mod tests {
use std::sync::mpsc::SyncSender;
use burn_std::rand::get_seeded_rng;
use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};
use serial_test::serial;
#[cfg(feature = "test-ndarray")]
pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(feature = "test-cuda")]
pub type TestBackend = burn_cuda::Cuda<f32>;
#[cfg(feature = "test-wgpu")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
#[cfg(feature = "test-metal")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
#[cfg(feature = "test-vulkan")]
pub type TestBackend = burn_wgpu::Wgpu<f32>;
use crate::{
CollectiveConfig, PeerId, ReduceOperation, ReduceStrategy, reduce, register,
reset_collective,
};
pub fn run_peer<B: Backend>(
id: PeerId,
config: CollectiveConfig,
input: TensorData,
op: ReduceOperation,
root: PeerId,
output: SyncSender<Option<Tensor<B, 1>>>,
) {
let device = B::Device::default();
register::<B>(id, device.clone(), config).unwrap();
let tensor = Tensor::<B, 1>::from_data(input, &device);
let tensor = tensor.into_primitive().tensor();
let tensor = reduce::<B>(id, tensor, op, root).unwrap();
let tensor = tensor.map(|t| Tensor::<B, 1>::from_primitive(TensorPrimitive::Float(t)));
output.send(tensor).unwrap();
}
fn generate_random_input(
shape: Shape,
op: ReduceOperation,
thread_count: usize,
) -> (Vec<TensorData>, TensorData) {
let input: Vec<TensorData> = (0..thread_count)
.map(|_| {
TensorData::random::<f32, _, _>(
shape.clone(),
burn_tensor::Distribution::Default,
&mut get_seeded_rng(),
)
})
.collect();
let device = <TestBackend as Backend>::Device::default();
let mut expected_tensor = Tensor::<TestBackend, 1>::zeros(shape, &device);
for item in input.iter().take(thread_count) {
let input_tensor = Tensor::<TestBackend, 1>::from_data(item.clone(), &device);
expected_tensor = expected_tensor.add(input_tensor);
}
if op == ReduceOperation::Mean {
expected_tensor = expected_tensor.div_scalar(thread_count as u32);
}
let expected = expected_tensor.to_data();
(input, expected)
}
fn test_reduce<B: Backend>(
device_count: usize,
op: ReduceOperation,
strategy: ReduceStrategy,
tensor_size: usize,
) {
reset_collective::<TestBackend>();
let (send, recv) = std::sync::mpsc::sync_channel(32);
let shape = Shape {
dims: vec![tensor_size],
};
let (input, expected) = generate_random_input(shape, op, device_count);
let config = CollectiveConfig::default()
.with_num_devices(device_count)
.with_local_reduce_strategy(strategy);
let root: PeerId = 0.into();
for id in 0..device_count {
let send = send.clone();
let input = input[id as usize].clone();
std::thread::spawn({
let config = config.clone();
move || run_peer::<B>(id.into(), config, input, op, root, send)
});
}
let mut result = None;
for _ in 0..device_count {
let tensor = recv.recv().unwrap();
if tensor.is_some() {
if result.is_some() {
panic!("Two peers received the result of an reduce!");
}
result = tensor.map(|t| t.to_data());
}
}
let tol: Tolerance<f32> = Tolerance::balanced();
expected.assert_approx_eq(&result.expect("One peer has received the result"), tol);
}
#[test]
#[serial]
pub fn test_reduce_centralized_sum() {
test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Centralized, 4);
}
#[test]
#[serial]
pub fn test_reduce_centralized_mean() {
test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Centralized, 4);
}
#[test]
#[serial]
pub fn test_reduce_binary_tree_sum() {
test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Tree(2), 4);
}
#[test]
#[serial]
pub fn test_reduce_binary_tree_mean() {
test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Tree(2), 4);
}
#[test]
#[serial]
pub fn test_reduce_5_tree_sum() {
test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Tree(5), 4);
}
#[test]
#[serial]
pub fn test_reduce_5_tree_mean() {
test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Tree(5), 4);
}
}