feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Backend extension for collective calculations."
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "collective"]
|
||||
license.workspace = true
|
||||
name = "burn-collective"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-collective"
|
||||
documentation = "https://docs.rs/burn-collective"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
doc = []
|
||||
tracing = [
|
||||
"dep:tracing",
|
||||
"burn-std/tracing",
|
||||
"burn-tensor/tracing",
|
||||
"burn-communication/tracing",
|
||||
"burn-ndarray?/tracing",
|
||||
"burn-wgpu?/tracing",
|
||||
"burn-cuda?/tracing",
|
||||
]
|
||||
orchestrator = ["burn-communication/websocket"]
|
||||
|
||||
# Backends for testing
|
||||
test-ndarray = ["burn-ndarray"]
|
||||
test-wgpu = ["burn-wgpu", "burn-wgpu/webgpu"]
|
||||
test-metal = ["burn-wgpu", "burn-wgpu/metal"]
|
||||
test-vulkan = ["burn-wgpu", "burn-wgpu/vulkan"]
|
||||
test-cuda = ["burn-cuda"]
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = true }
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true }
|
||||
|
||||
log = { workspace = true }
|
||||
|
||||
burn-communication = { path = "../burn-communication", version = "=0.21.0-pre.2", features = [
|
||||
"data-service",
|
||||
"websocket",
|
||||
] }
|
||||
tokio = { workspace = true, features = [
|
||||
"rt-multi-thread",
|
||||
"sync",
|
||||
"signal",
|
||||
"time",
|
||||
"tracing",
|
||||
] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
rmp-serde = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true, optional = true }
|
||||
|
||||
# Tests
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = { workspace = true }
|
||||
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1,139 @@
|
||||
# burn-collective
|
||||
|
||||
Collective operations on tensors
|
||||
|
||||
The following collective operations are implemented:
|
||||
|
||||
- `all-reduce`
|
||||
Aggregates a tensor between all peers, and distributes the result to all peers.
|
||||
Different strategies can be used on the local and global levels. The result can only be
|
||||
returned when all peers have called the all-reduce.
|
||||
- `reduce`
|
||||
Aggregates a tensor from all peers onto one peer, called the "root".
|
||||
- `broadcast`
|
||||
Copies a tensor from one peer to all other peers in the collective.
|
||||
|
||||
Peers must call `register` before calling any other operation.
|
||||
The total number of devices on the node, or nodes in the collective, must be known ahead of time.
|
||||
|
||||
In many libraries like NCCL and PyTorch, participating units are called "ranks".
|
||||
This name is confusing in the context of tensors, so in burn-collective the participating units
|
||||
are called "peers".
|
||||
|
||||
*`reduce` and `broadcast` are not yet implemented for multi-node contexts*
|
||||
|
||||
## Local and Global
|
||||
|
||||
Internally, there are two levels to the collective operations: local and global. Operations are done on the local level, then optionally on the global level.
|
||||
|
||||
| Local | Global |
|
||||
|-----------------------------------------------|-----------------------------------------------|
|
||||
| Intra-node (typically within one machine) | Inter-node (typically across machines) |
|
||||
| Participants are threads (one per peer/GPU) | Participants are processes (one per node) |
|
||||
| Communication depends on backend | Network peer-to-peer communication |
|
||||
| Local server is launched automatically | Global coordinator must be launched manually |
|
||||
| Local server does the aggregation | Nodes do the operations themselves |
|
||||
|
||||
For global operations (i.e. with multiple nodes), there must be a global orchestrator available.
|
||||
Start one easily with `burn_collective::start_global_orchestrator()`.
|
||||
|
||||
On the global level, nodes use the `burn_communication::data_service::TensorDataService` to
|
||||
expose and download tensors in a peer-to-peer manner, in order to be independent.
|
||||
|
||||
## Components
|
||||
|
||||
The following are the important pieces of the collective operations system.
|
||||
|
||||
| Term | One per... | Meaning
|
||||
|--------------------------------|---------------|----------------------------------------------------------
|
||||
| Local Collective Client | Peer/thread | Requests operations to the Local Collective Server
|
||||
| Local Collective Server | Node/process | Does local-level ops for threads in this process. In the case of global operations, passes operations on to the Global Collective Client.
|
||||
| Global Collective Client | Node/process | Does global-level ops for this node. Registers and requests strategies from the Global Collective Orchestrator.
|
||||
| Global Collective Orchestrator | Collective | Responds to the Global Collective Client from each node. Responsible for aggregation strategies.
|
||||
|
||||
## Strategies
|
||||
|
||||
Different strategies can be used on the local and global level.
|
||||
|
||||
### Centralized
|
||||
|
||||
An arbitrary peer is designated as the "root", and all others are transferred to the root's device.
|
||||
The operation is done on that device.
|
||||
The resulting tensor is then sent to each peer.
|
||||
|
||||
### Tree
|
||||
|
||||
Tensors in groups of N are aggregated together. This is done recursively until only one tensor
|
||||
remains. The strategy tries to put devices of the same type closer in the tree.
|
||||
When N=2, this is like a binary tree reduce.
|
||||
The resulting tensor is then sent to each peer.
|
||||
|
||||
### Ring
|
||||
|
||||
See this good explanation: <https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model>
|
||||
|
||||
The tensors are sliced into N parts, where N is the number of tensors to aggregate.
|
||||
Then, the slices are sent around in a series of cycles and aggregated until each slice of every tensor
|
||||
is a sum of the other corresponding slices.
|
||||
|
||||
In the case where the tensors are too small to split into N slices, a fallback algorithm is used.
|
||||
For now, the fallback is a binary tree.
|
||||
|
||||
(p=3, n=3)
|
||||
|
||||
o->o o
|
||||
o o->o
|
||||
o o o->
|
||||
|
||||
o 1->o
|
||||
o o 1->
|
||||
1->o o
|
||||
|
||||
o 1 2->
|
||||
2->o 1
|
||||
1 2->o
|
||||
|
||||
3 1 2
|
||||
2 3 1
|
||||
1 2 3
|
||||
|
||||
(This is essentially a reduce-scatter)
|
||||
|
||||
3->x x
|
||||
x 3->x
|
||||
x x 3->
|
||||
|
||||
3 3->x
|
||||
x 3 3->
|
||||
3->x 3
|
||||
|
||||
3 3 3->
|
||||
3->3 3
|
||||
3 3->3
|
||||
|
||||
3 3 3
|
||||
3 3 3
|
||||
3 3 3
|
||||
|
||||
(This is essentially an all-gather)
|
||||
|
||||
This is done so that every peer is both sending and receiving data at any moment.
|
||||
This is an important part of this strategy's advantages.
|
||||
|
||||
The ring strategy takes full advantage of the bandwidth available. The latency scales with the
|
||||
number of peers.
|
||||
|
||||
So when the tensors are very small, or when the number of peers is very large, the latency is more
|
||||
important in the ring strategy, and a tree algorithm is better. Otherwise, the ring algorithm is
|
||||
better.
|
||||
|
||||
In multi-node contexts, use of the Ring strategy at the local level may be less
|
||||
advantageous. With multiple nodes, the global all-reduce step is enabled, and its result
|
||||
is redistributed to all devices.
|
||||
The Ring strategy inherently distributes the result, which in this context would not be necessary.
|
||||
|
||||
It is recommended to use the Ring strategy at the global level.
|
||||
|
||||
### Double binary tree
|
||||
|
||||
<https://developer.nvidia.com/blog/massively-scale-deep-learning-training-nccl-2-4/>
|
||||
@@ -0,0 +1,34 @@
|
||||
[package]
|
||||
name = "burn-collective-multinode-tests"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["ndarray"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../burn", default-features = false, features = ["std"] }
|
||||
burn-std = { path = "../../burn-std", default-features = false }
|
||||
burn-collective = { path = "..", features = ["orchestrator"] }
|
||||
burn-communication = { path = "../../burn-communication" }
|
||||
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "process"] }
|
||||
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
interprocess = "2.3.1"
|
||||
rmp-serde = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
tokio-serde = { version = "0.9.0", features = ["messagepack"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
|
||||
[[bin]]
|
||||
name = "global"
|
||||
path = "src/bin/global.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "node"
|
||||
path = "src/bin/node.rs"
|
||||
@@ -0,0 +1,32 @@
|
||||
# Integration test for burn collective operations with multiple nodes and devices.
|
||||
|
||||
Run `cargo run --bin test_launcher`
|
||||
|
||||
There are 3 binaries:
|
||||
|
||||
## node.rs
|
||||
|
||||
Launches `n` threads each simulating a different device. Currently the backend is NdArray,
|
||||
so everything is CPU. The program takes a file with configurations and input data.
|
||||
|
||||
## global.rs
|
||||
|
||||
Runs the global orchestrator, which is responsible for responding to global collective operation
|
||||
requests. In the case of an all-reduce, the orchestrator responds with a strategy for reducing,
|
||||
and the node can do the reduction independently.
|
||||
|
||||
## test_launcher.rs
|
||||
|
||||
Generates input data, calculates the expected results, and launches the nodes each with their
|
||||
own inputs in a separate file.
|
||||
|
||||
The topology is [4, 4, 4, 4]. This means 4 nodes are launched,
|
||||
each with 4 threads (for each device).
|
||||
|
||||
The global orchestrator (`global.rs`) is also launched.
|
||||
|
||||
## Output
|
||||
|
||||
The outputs and inputs for each node and the orchestrator are written to the `target/test_files` folder
|
||||
|
||||
If the nodes or orchestrator stall, there is a timeout.
|
||||
@@ -0,0 +1,19 @@
|
||||
//! Global orchestrator
|
||||
//!
|
||||
//! Launches the orchestrator that responds to global collective operations for nodes for the
|
||||
//! integration test
|
||||
//!
|
||||
//! This is necessary for any node who needs global collective operations
|
||||
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
/// Start the global orchestrator on the port given as first arg
|
||||
pub async fn main() {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let port = args[1].parse::<u16>().expect("invalid port");
|
||||
|
||||
// Launch the global orchestrator, which will listen and respond to global collective op
|
||||
// requests from nodes
|
||||
burn_collective::start_global_orchestrator(port).await;
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
prelude::Backend,
|
||||
tensor::{Tensor, TensorPrimitive, Tolerance},
|
||||
};
|
||||
use burn_collective::{
|
||||
CollectiveConfig, PeerId, ReduceOperation, all_reduce, finish_collective, register,
|
||||
reset_collective,
|
||||
};
|
||||
use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};
|
||||
use std::{
|
||||
env,
|
||||
sync::mpsc::SyncSender,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::thread::JoinHandle;
|
||||
use tokio_serde::formats::MessagePack;
|
||||
use tokio_util::codec::LengthDelimitedCodec;
|
||||
|
||||
/// Concrete backend used by the node tests (NdArray).
type TestBackend = NdArray;
|
||||
|
||||
/// Framed TCP connection channel
|
||||
type TestChannel = tokio_serde::Framed<
|
||||
tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,
|
||||
NodeTest,
|
||||
NodeTestResult,
|
||||
MessagePack<NodeTest, NodeTestResult>,
|
||||
>;
|
||||
|
||||
/// Start a node that will test all-reduce
|
||||
/// Args are the following:
|
||||
/// - launcher endpoint
|
||||
#[tokio::main]
|
||||
pub async fn main() {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let launcher_addr = args[1].clone();
|
||||
|
||||
let socket = TcpStream::connect(launcher_addr).await.unwrap();
|
||||
let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());
|
||||
let mut socket: TestChannel = tokio_serde::Framed::new(
|
||||
length_delimited,
|
||||
MessagePack::<NodeTest, NodeTestResult>::default(),
|
||||
);
|
||||
|
||||
// Loop: receive, do test, send result
|
||||
while let Some(Ok(test)) = socket.next().await {
|
||||
println!("Received test: {test:?}");
|
||||
|
||||
let result = run_test::<NdArray>(&test);
|
||||
|
||||
// send the result back
|
||||
socket.send(result).await.expect("failed to send Result");
|
||||
}
|
||||
|
||||
println!("Server closed connection; exiting.");
|
||||
}
|
||||
|
||||
/// Runs a test for one node
|
||||
fn run_test<B: Backend>(test_input: &NodeTest) -> NodeTestResult {
|
||||
reset_collective::<TestBackend>();
|
||||
|
||||
// Channel for results
|
||||
let (result_send, result_recv) = std::sync::mpsc::sync_channel(32);
|
||||
|
||||
// Launch a thread for each "device"
|
||||
let handles = launch_threads::<B>(test_input.clone(), result_send);
|
||||
|
||||
// Receive results
|
||||
let mut durations = vec![];
|
||||
let tol: Tolerance<f32> = Tolerance::balanced();
|
||||
for _ in 0..test_input.device_count {
|
||||
// Assert all results are equal to each other as well as expected result
|
||||
let (tensor, duration) = result_recv.recv().unwrap();
|
||||
test_input.expected.assert_approx_eq(&tensor.to_data(), tol);
|
||||
|
||||
durations.push(duration);
|
||||
}
|
||||
|
||||
// Threads finish
|
||||
for handle in handles {
|
||||
let _ = handle.join();
|
||||
}
|
||||
|
||||
NodeTestResult {
|
||||
success: true,
|
||||
durations,
|
||||
}
|
||||
}
|
||||
|
||||
/// Launch a thread for each device, and run the all-reduce
|
||||
fn launch_threads<B: Backend>(
|
||||
test_input: NodeTest,
|
||||
result_send: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,
|
||||
) -> Vec<JoinHandle<()>> {
|
||||
let mut handles = vec![];
|
||||
for id in 0..test_input.device_count {
|
||||
// Launch a thread to test
|
||||
|
||||
// Put all the parameters in the config
|
||||
let config = CollectiveConfig::default()
|
||||
.with_num_devices(test_input.device_count)
|
||||
.with_global_address(test_input.global_address.clone())
|
||||
.with_node_address(test_input.node_address.clone())
|
||||
.with_data_service_port(test_input.data_service_port)
|
||||
.with_num_nodes(test_input.node_count)
|
||||
.with_global_all_reduce_strategy(test_input.global_strategy)
|
||||
.with_local_all_reduce_strategy(test_input.local_strategy);
|
||||
|
||||
// Inputs and outputs for the test
|
||||
let tensor_data = test_input.inputs[id].clone();
|
||||
let tensor = Tensor::<B, TENSOR_RANK>::from_data(tensor_data, &B::Device::default());
|
||||
let result_send = result_send.clone();
|
||||
|
||||
let handle = std::thread::spawn(move || {
|
||||
run_peer::<B>(
|
||||
id.into(),
|
||||
config,
|
||||
tensor,
|
||||
result_send,
|
||||
test_input.all_reduce_op,
|
||||
)
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
handles
|
||||
}
|
||||
|
||||
/// Runs a thread in the all-reduce test.
|
||||
pub fn run_peer<B: Backend>(
|
||||
id: PeerId,
|
||||
config: CollectiveConfig,
|
||||
input: Tensor<B, TENSOR_RANK>,
|
||||
output: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,
|
||||
all_reduce_op: ReduceOperation,
|
||||
) {
|
||||
// Register the device
|
||||
register::<B>(id, input.device(), config).unwrap();
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// All-reduce
|
||||
let input = input.into_primitive().tensor();
|
||||
let tensor = all_reduce::<B>(id, input, all_reduce_op).unwrap();
|
||||
let tensor = Tensor::<B, TENSOR_RANK>::from_primitive(TensorPrimitive::Float(tensor));
|
||||
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Send result
|
||||
output.send((tensor, duration)).unwrap();
|
||||
|
||||
finish_collective::<B>(id).unwrap();
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
use burn::tensor::TensorData;
|
||||
use burn_communication::Address;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
fs::{self, File},
|
||||
str::FromStr,
|
||||
time::{Duration, Instant},
|
||||
vec,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_serde::formats::MessagePack;
|
||||
use tokio_util::codec::LengthDelimitedCodec;
|
||||
|
||||
use burn::{backend::NdArray, prelude::Backend, tensor::Tensor};
|
||||
use burn_collective::{AllReduceStrategy, ReduceOperation};
|
||||
use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};
|
||||
use burn_std::rand::{SeedableRng, StdRng};
|
||||
use tokio::process::{Child, Command};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AllReduceTest {
|
||||
shape: [usize; TENSOR_RANK],
|
||||
op: ReduceOperation,
|
||||
local_strategy: AllReduceStrategy,
|
||||
global_strategy: AllReduceStrategy,
|
||||
}
|
||||
|
||||
impl Display for AllReduceTest {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let op_str = match self.op {
|
||||
ReduceOperation::Sum => "sum",
|
||||
ReduceOperation::Mean => "mean",
|
||||
};
|
||||
let local_strategy_str = match self.local_strategy {
|
||||
AllReduceStrategy::Centralized => "local_centralized",
|
||||
AllReduceStrategy::Tree(n) => &format!("local_tree_{n}"),
|
||||
AllReduceStrategy::Ring => "local_ring",
|
||||
};
|
||||
let global_strategy_str = match self.global_strategy {
|
||||
AllReduceStrategy::Centralized => "global_centralized",
|
||||
AllReduceStrategy::Tree(n) => &format!("global_tree_{n}"),
|
||||
AllReduceStrategy::Ring => "global_ring",
|
||||
};
|
||||
|
||||
write!(f, "{op_str}_{local_strategy_str}_{global_strategy_str}")
|
||||
}
|
||||
}
|
||||
|
||||
/// Framed TCP connection for sending tests and receiving results
|
||||
type TestChannel = tokio_serde::Framed<
|
||||
tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,
|
||||
NodeTestResult,
|
||||
NodeTest,
|
||||
MessagePack<NodeTestResult, NodeTest>,
|
||||
>;
|
||||
|
||||
/// Handle for each node process
|
||||
struct NodeProcessHandle {
|
||||
process: Child,
|
||||
channel: TestChannel,
|
||||
}
|
||||
|
||||
/// Main function to run the multi-node all-reduce test.
|
||||
/// Launches a orchestrator and multiple nodes based on the provided topology.
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
|
||||
async fn main() {
|
||||
let all_reduce_tests = vec![
|
||||
AllReduceTest {
|
||||
shape: [4, 64, 512],
|
||||
op: ReduceOperation::Mean,
|
||||
local_strategy: AllReduceStrategy::Tree(2),
|
||||
global_strategy: AllReduceStrategy::Tree(2),
|
||||
},
|
||||
AllReduceTest {
|
||||
shape: [4, 64, 512],
|
||||
op: ReduceOperation::Mean,
|
||||
local_strategy: AllReduceStrategy::Tree(2),
|
||||
global_strategy: AllReduceStrategy::Ring,
|
||||
},
|
||||
AllReduceTest {
|
||||
shape: [4, 64, 512],
|
||||
op: ReduceOperation::Mean,
|
||||
local_strategy: AllReduceStrategy::Centralized,
|
||||
global_strategy: AllReduceStrategy::Centralized,
|
||||
},
|
||||
];
|
||||
|
||||
let test_files_dir = "target/test_files";
|
||||
fs::create_dir_all(test_files_dir).expect("Couldn't create test_files directory");
|
||||
|
||||
let topology: Vec<usize> = vec![4; 4];
|
||||
|
||||
let mut orchestrator = launch_orchestrator(test_files_dir);
|
||||
|
||||
let launcher_endpoint = "127.0.0.1:4000";
|
||||
|
||||
// Build and run node processes
|
||||
let mut all_tests_durations = vec![];
|
||||
if let Ok(mut nodes) = launch_nodes(&topology, launcher_endpoint).await {
|
||||
// Run one test
|
||||
for test in all_reduce_tests.clone() {
|
||||
let test_name = test.to_string();
|
||||
|
||||
let time =
|
||||
test_all_reduce_centralized_no_collective::<NdArray>(&topology, test.clone());
|
||||
println!(
|
||||
"{test_name}: Benchmark (no collective, centralized, single-threaded): {} secs",
|
||||
time.as_secs_f32()
|
||||
);
|
||||
|
||||
match test_all_reduce(&topology, test, &mut nodes).await {
|
||||
Err(node_idx) => {
|
||||
println!("{test_name}: Node with index {node_idx} failed!");
|
||||
// Kill other node processes
|
||||
for mut node in nodes.drain(..) {
|
||||
node.process.kill().await.unwrap();
|
||||
node.process.wait().await.unwrap();
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(durations) => {
|
||||
all_tests_durations.append(&mut durations.clone());
|
||||
let avg = durations.iter().map(|dur| dur.as_secs_f32()).sum::<f32>()
|
||||
/ durations.len() as f32;
|
||||
println!("{test_name}: Success in {avg} secs");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !all_tests_durations.is_empty() {
|
||||
let avg = all_tests_durations
|
||||
.iter()
|
||||
.map(|dur| dur.as_secs_f32())
|
||||
.sum::<f32>()
|
||||
/ all_tests_durations.len() as f32;
|
||||
println!("Average for all tests: {avg} secs");
|
||||
}
|
||||
|
||||
// Shutdown orchestrator
|
||||
orchestrator.kill().await.unwrap();
|
||||
orchestrator.wait().await.unwrap();
|
||||
}
|
||||
|
||||
/// Launch a global orchestrator with an output file in the given directory.
|
||||
/// Necessary for global collective operations
|
||||
///
|
||||
/// Server listens on localhost port 3000
|
||||
fn launch_orchestrator(test_files_dir: &str) -> Child {
|
||||
let out_path = format!("{test_files_dir}/orchestrator_out.txt");
|
||||
let out = File::create(out_path).expect("Could't create orchestrator output file");
|
||||
|
||||
Command::new("cargo")
|
||||
.args(["run", "--bin", "global", "--", "3000"])
|
||||
.stdout(out.try_clone().unwrap())
|
||||
.stderr(out)
|
||||
.spawn()
|
||||
.expect("failed to launch orchestrator")
|
||||
}
|
||||
|
||||
/// Launch nodes for a all_reduce test
|
||||
/// Each node will connect to the global orchestrator and run an all-reduce operation.
|
||||
/// The topology is a vector where each element represents the number of devices in that node.
|
||||
async fn launch_nodes(
|
||||
topology: &[usize],
|
||||
launcher_endpoint: &str,
|
||||
) -> Result<Vec<NodeProcessHandle>, ()> {
|
||||
println!(
|
||||
"Launching {} nodes with topology: {:?}",
|
||||
topology.len(),
|
||||
topology
|
||||
);
|
||||
|
||||
// Listen for node connections
|
||||
let listener = TcpListener::bind(launcher_endpoint).await.unwrap();
|
||||
println!("Server listening on {launcher_endpoint}");
|
||||
|
||||
let mut nodes = vec![];
|
||||
|
||||
for node_idx in 0..topology.len() {
|
||||
// Create log file
|
||||
let output_filename = format!("target/test_files/node_{}_log.txt", node_idx + 1);
|
||||
let out = File::create(output_filename).expect("Could't open node log file");
|
||||
|
||||
// Start a process for each node. Pass on our feature flags
|
||||
let node_process: Child = Command::new("cargo")
|
||||
.args([
|
||||
"run",
|
||||
"--release",
|
||||
"--features",
|
||||
#[cfg(feature = "ndarray")]
|
||||
"ndarray",
|
||||
"--bin",
|
||||
"node",
|
||||
"--",
|
||||
launcher_endpoint,
|
||||
&node_idx.to_string(),
|
||||
])
|
||||
.stdout(out.try_clone().unwrap())
|
||||
.stderr(out)
|
||||
.spawn()
|
||||
.expect("node failed");
|
||||
|
||||
// Wait for child to connect for io
|
||||
let (socket, _peer_addr) = listener.accept().await.unwrap();
|
||||
let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());
|
||||
let channel: TestChannel = tokio_serde::Framed::new(
|
||||
length_delimited,
|
||||
MessagePack::<NodeTestResult, NodeTest>::default(),
|
||||
);
|
||||
|
||||
nodes.push(NodeProcessHandle {
|
||||
process: node_process,
|
||||
channel,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
async fn test_all_reduce(
|
||||
topology: &[usize],
|
||||
test: AllReduceTest,
|
||||
nodes: &mut [NodeProcessHandle],
|
||||
) -> Result<Vec<Duration>, usize> {
|
||||
dispatch_all_reduce_test(topology, test, nodes).await;
|
||||
|
||||
let mut all_durations = vec![];
|
||||
for (idx, handle) in nodes.iter_mut().enumerate() {
|
||||
match handle.channel.next().await {
|
||||
Some(Ok(mut result)) => {
|
||||
if !result.success {
|
||||
return Err(idx);
|
||||
}
|
||||
all_durations.append(&mut result.durations);
|
||||
}
|
||||
_ => {
|
||||
return Err(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_durations)
|
||||
}
|
||||
|
||||
async fn dispatch_all_reduce_test(
|
||||
topology: &[usize],
|
||||
test: AllReduceTest,
|
||||
nodes: &mut [NodeProcessHandle],
|
||||
) {
|
||||
let total_device_count: usize = topology.iter().sum();
|
||||
let (mut all_inputs, expected) =
|
||||
generate_random_input(test.shape, test.op, total_device_count, 42);
|
||||
|
||||
// URL for the global orchestrator on port 3000
|
||||
let global_url = "ws://localhost:3000";
|
||||
let global_address = Address::from_str(global_url).unwrap();
|
||||
|
||||
for (node_idx, &device_count) in topology.iter().enumerate() {
|
||||
// Construct URL for node
|
||||
// Ports 3001... are for each node
|
||||
let data_service_port = node_idx as u16 + 3001;
|
||||
let node_url = format!("ws://localhost:{data_service_port}");
|
||||
let node_address = Address::from_str(&node_url).unwrap();
|
||||
|
||||
// take input tensors for each device
|
||||
let inputs = all_inputs[0..device_count].to_vec();
|
||||
all_inputs = all_inputs[device_count..].to_vec();
|
||||
|
||||
let test = NodeTest {
|
||||
device_count,
|
||||
node_id: node_idx.into(),
|
||||
node_count: topology.len() as u32,
|
||||
global_address: global_address.clone(),
|
||||
node_address,
|
||||
data_service_port,
|
||||
all_reduce_op: test.op,
|
||||
global_strategy: test.global_strategy,
|
||||
local_strategy: test.local_strategy,
|
||||
inputs,
|
||||
expected: expected.clone(),
|
||||
};
|
||||
let handle = &mut nodes[node_idx];
|
||||
|
||||
handle.channel.send(test).await.unwrap();
|
||||
}
|
||||
|
||||
assert!(
|
||||
all_inputs.is_empty(),
|
||||
"Not all inputs have been sent to tests"
|
||||
);
|
||||
}
|
||||
|
||||
/// Run the test sequentially with no collective operations to get the optimal single-threaded speed
|
||||
fn test_all_reduce_centralized_no_collective<B: Backend>(
|
||||
topology: &[usize],
|
||||
test: AllReduceTest,
|
||||
) -> Duration {
|
||||
let total_device_count: usize = topology.iter().sum();
|
||||
let (all_inputs, _expected) =
|
||||
generate_random_input(test.shape, test.op, total_device_count, 42);
|
||||
|
||||
let mut all_inputs = all_inputs
|
||||
.into_iter()
|
||||
.map(|data| Tensor::<B, 3>::from_data(data, &B::Device::default()));
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Sequential test
|
||||
let mut result = all_inputs.next().unwrap();
|
||||
for other in all_inputs {
|
||||
result = result.add(other);
|
||||
}
|
||||
if test.op == ReduceOperation::Mean {
|
||||
result.div_scalar(total_device_count as u32);
|
||||
}
|
||||
|
||||
start.elapsed()
|
||||
}
|
||||
|
||||
/// Generates random input tensors and expected output based on the provided shape and reduce kind.
|
||||
fn generate_random_input(
|
||||
shape: [usize; 3],
|
||||
reduce_kind: ReduceOperation,
|
||||
input_count: usize,
|
||||
seed: u64,
|
||||
) -> (Vec<TensorData>, TensorData) {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
|
||||
// A random tensor for each device
|
||||
let input: Vec<TensorData> = (0..input_count)
|
||||
.map(|_| {
|
||||
TensorData::random::<f32, _, _>(shape, burn::tensor::Distribution::Default, &mut rng)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sum up the inputs
|
||||
let device = <NdArray as Backend>::Device::default();
|
||||
let mut expected_tensor = Tensor::<NdArray, TENSOR_RANK>::zeros(shape, &device);
|
||||
for item in input.iter().take(input_count) {
|
||||
let input_tensor = Tensor::<NdArray, TENSOR_RANK>::from_data(item.clone(), &device);
|
||||
expected_tensor = expected_tensor.add(input_tensor);
|
||||
}
|
||||
|
||||
if reduce_kind == ReduceOperation::Mean {
|
||||
expected_tensor = expected_tensor.div_scalar(input_count as u32);
|
||||
}
|
||||
|
||||
// All-Reduce results should have this value
|
||||
let expected = expected_tensor.to_data();
|
||||
|
||||
(input, expected)
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
/// Types shared between the test launcher and the node binaries.
pub mod shared;
|
||||
@@ -0,0 +1,43 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use burn::tensor::TensorData;
|
||||
use burn_collective::{AllReduceStrategy, NodeId, ReduceOperation};
|
||||
use burn_communication::Address;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Rank (number of dimensions) of every input and output tensor used in the tests
|
||||
pub const TENSOR_RANK: usize = 3;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeTest {
|
||||
/// How many threads to start on this node
|
||||
pub device_count: usize,
|
||||
/// ID for this node
|
||||
pub node_id: NodeId,
|
||||
/// How many nodes in the cluster
|
||||
pub node_count: u32,
|
||||
/// Global server address
|
||||
pub global_address: Address,
|
||||
/// Node address
|
||||
pub node_address: Address,
|
||||
/// Node's data service port, for initializing the p2p tensor data service
|
||||
pub data_service_port: u16,
|
||||
/// What kind of all-reduce
|
||||
pub all_reduce_op: ReduceOperation,
|
||||
/// Node's data service port, for initializing the p2p tensor data service
|
||||
pub global_strategy: AllReduceStrategy,
|
||||
/// What kind of aggregation
|
||||
pub local_strategy: AllReduceStrategy,
|
||||
|
||||
/// Input data for test: all tensors are D=3
|
||||
pub inputs: Vec<TensorData>,
|
||||
/// Expected output for test
|
||||
pub expected: TensorData,
|
||||
}
|
||||
|
||||
/// Result sent back from each node for each test
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeTestResult {
|
||||
pub success: bool,
|
||||
pub durations: Vec<Duration>,
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use crate::{
|
||||
CollectiveConfig, PeerId, ReduceOperation, global::shared::GlobalCollectiveError,
|
||||
local::server::get_collective_client,
|
||||
};
|
||||
|
||||
/// Errors from collective operations
#[allow(unused)]
#[derive(Debug, Clone)]
pub enum CollectiveError {
    /// The [config](CollectiveConfig) was invalid.
    /// Usually happens if only some global parameters have been defined
    InvalidConfig,
    /// Cannot un-register a node twice
    MultipleUnregister,
    /// Cannot register a node twice
    MultipleRegister,
    /// Trying to register a different way than is currently being done
    RegisterParamsMismatch,
    /// Trying to all-reduce tensors of different shapes: shape must match
    AllReduceShapeMismatch,
    /// Trying to all-reduce a different way than is currently being done: op must match
    AllReduceOperationMismatch,
    /// Trying to reduce tensors of different shapes: shape must match
    ReduceShapeMismatch,
    /// Trying to reduce a different way than is currently being done: op must match
    ReduceOperationMismatch,
    /// Trying to reduce with different roots
    ReduceRootMismatch,
    /// Trying to broadcast with different roots
    BroadcastRootMismatch,
    /// Trying to broadcast but no peer sent a tensor
    BroadcastNoTensor,
    /// Trying to broadcast but multiple peers sent a tensor
    BroadcastMultipleTensors,
    /// Local collective server couldn't respond
    LocalServerMissing,
    /// Another operation was called before Register
    RegisterNotFirstOperation,
    /// The global orchestrator had an error
    Global(GlobalCollectiveError),

    /// Any other error, carried as a free-form message
    #[allow(unused)]
    Other(String),
}
|
||||
|
||||
/// Registers a device. `num_devices` must be the same for every register,
|
||||
/// and `device_id` must be unique.
|
||||
///
|
||||
/// * `id` - The peer id of the caller
|
||||
///
|
||||
/// With auto-diff backends, make sure to use the inner backend.
|
||||
pub fn register<B: Backend>(
|
||||
id: PeerId,
|
||||
device: B::Device,
|
||||
config: CollectiveConfig,
|
||||
) -> Result<(), CollectiveError> {
|
||||
log::info!("Registering peer {id} with config: {config}");
|
||||
let mut client = get_collective_client::<B>();
|
||||
client.register(id, device, config)
|
||||
}
|
||||
|
||||
/// Calls for an all-reduce operation with the given parameters, and returns the result.
|
||||
/// The `params` must be the same as the parameters passed by the other nodes.
|
||||
///
|
||||
/// * `id` - The peer id of the caller
|
||||
/// * `tensor` - The input tensor to reduce with the peers' tensors
|
||||
/// * `config` - Config of the collective operation, must be coherent with the other calls
|
||||
pub fn all_reduce<B: Backend>(
|
||||
id: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
) -> Result<B::FloatTensorPrimitive, CollectiveError> {
|
||||
let client = get_collective_client::<B>();
|
||||
client.all_reduce(id, tensor, op)
|
||||
}
|
||||
|
||||
/// Broadcasts, or receives a broadcasted tensor.
|
||||
///
|
||||
/// * `id` - The peer id of the caller
|
||||
/// * `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive
|
||||
/// the broadcasted tensor.
|
||||
///
|
||||
/// Returns the broadcasted tensor.
|
||||
pub fn broadcast<B: Backend>(
|
||||
id: PeerId,
|
||||
tensor: Option<B::FloatTensorPrimitive>,
|
||||
) -> Result<B::FloatTensorPrimitive, CollectiveError> {
|
||||
let client = get_collective_client::<B>();
|
||||
client.broadcast(id, tensor)
|
||||
}
|
||||
|
||||
/// Reduces a tensor onto one device.
|
||||
///
|
||||
/// * `id` - The peer id of the caller
|
||||
/// * `tensor` - The tensor to send as input
|
||||
/// * `root` - The ID of the peer that will receive the result.
|
||||
///
|
||||
/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.
|
||||
pub fn reduce<B: Backend>(
|
||||
id: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
root: PeerId,
|
||||
) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {
|
||||
let client = get_collective_client::<B>();
|
||||
client.reduce(id, tensor, op, root)
|
||||
}
|
||||
|
||||
/// Closes the collective session, unregistering the device
|
||||
pub fn finish_collective<B: Backend>(id: PeerId) -> Result<(), CollectiveError> {
|
||||
let client = get_collective_client::<B>();
|
||||
client.finish(id)
|
||||
}
|
||||
|
||||
/// Resets the local collective server. All registered callers and ongoing operations are forgotten
|
||||
pub fn reset_collective<B: Backend>() {
|
||||
let client = get_collective_client::<B>();
|
||||
client.reset();
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use burn_communication::Address;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Parameter struct for setting up and getting parameters for collective operations.
/// Used in most collective api calls.
/// This config is per-node. It is passed to [register](crate::register).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CollectiveConfig {
    // Number of devices (local peers) on the current node
    pub(crate) num_devices: usize,
    // Strategies used on the local level
    pub(crate) local_all_reduce_strategy: AllReduceStrategy,
    pub(crate) local_reduce_strategy: ReduceStrategy,
    pub(crate) local_broadcast_strategy: BroadcastStrategy,

    // Global parameters (all are optional, but if one is defined they should all be)
    // Number of nodes in the collective
    pub(crate) num_nodes: Option<u32>,
    // Network address of the global collective orchestrator
    pub(crate) global_address: Option<Address>,
    // Address of this node
    pub(crate) node_address: Option<Address>,
    // Port on which the tensor data service is exposed for peer-to-peer transfers
    pub(crate) data_service_port: Option<u16>,

    // These strategies may be defined when no other global params are defined
    pub(crate) global_all_reduce_strategy: Option<AllReduceStrategy>,
    pub(crate) global_reduce_strategy: Option<ReduceStrategy>,
    pub(crate) global_broadcast_strategy: Option<BroadcastStrategy>,
}
|
||||
|
||||
/// The default config is the one built by the private `CollectiveConfig::new`.
impl Default for CollectiveConfig {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
impl Display for CollectiveConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let num_devices = self.num_devices;
|
||||
let local_all_reduce_strategy = self.local_all_reduce_strategy;
|
||||
let local_reduce_strategy = self.local_reduce_strategy;
|
||||
let local_broadcast_strategy = self.local_broadcast_strategy;
|
||||
let num_nodes = self.num_nodes;
|
||||
let global_address = &self.global_address;
|
||||
let node_address = &self.node_address;
|
||||
let data_service_port = self.data_service_port;
|
||||
let global_all_reduce_strategy = self.global_all_reduce_strategy;
|
||||
let global_reduce_strategy = self.global_reduce_strategy;
|
||||
let global_broadcast_strategy = self.global_broadcast_strategy;
|
||||
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
CollectiveConfig {{
|
||||
num_devices: {num_devices:?},
|
||||
local_all_reduce_strategy: {local_all_reduce_strategy:?},
|
||||
local_reduce_strategy: {local_reduce_strategy:?},
|
||||
local_broadcast_strategy: {local_broadcast_strategy:?},
|
||||
num_nodes: {num_nodes:?},
|
||||
global_address: {global_address:?},
|
||||
node_address: {node_address:?},
|
||||
data_service_port: {data_service_port:?},
|
||||
global_all_reduce_strategy: {global_all_reduce_strategy:?},
|
||||
global_reduce_strategy: {global_reduce_strategy:?},
|
||||
global_broadcast_strategy: {global_broadcast_strategy:?},
|
||||
}}
|
||||
"#
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CollectiveConfig {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
num_devices: 1,
|
||||
local_all_reduce_strategy: AllReduceStrategy::Tree(2),
|
||||
local_reduce_strategy: ReduceStrategy::Tree(2),
|
||||
local_broadcast_strategy: BroadcastStrategy::Tree(2),
|
||||
|
||||
num_nodes: None,
|
||||
global_address: None,
|
||||
node_address: None,
|
||||
data_service_port: None,
|
||||
global_all_reduce_strategy: Some(AllReduceStrategy::Ring),
|
||||
global_reduce_strategy: Some(ReduceStrategy::Tree(2)),
|
||||
global_broadcast_strategy: Some(BroadcastStrategy::Tree(2)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Selects the number of devices (local peers) on the current node
|
||||
pub fn with_num_devices(mut self, num: usize) -> Self {
|
||||
self.num_devices = num;
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects an all-reduce strategy to use on the local level.
|
||||
///
|
||||
/// In multi-node contexts, use of the Ring strategy in the local level may be less
|
||||
/// advantageous. With multiple nodes, the global all-reduce step is enabled, and its result
|
||||
/// is redistributed to all devices.
|
||||
/// The Ring strategy inherently distributes the result, which in this context would not be
|
||||
/// necessary.
|
||||
///
|
||||
/// It is recommended to use a tree strategy locally, and a ring strategy globally.
|
||||
pub fn with_local_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
|
||||
self.local_all_reduce_strategy = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects a reduce strategy to use on the local level.
|
||||
pub fn with_local_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
|
||||
self.local_reduce_strategy = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects a broadcast strategy to use on the local level.
|
||||
pub fn with_local_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
|
||||
self.local_broadcast_strategy = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of nodes in the collective
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts
|
||||
pub fn with_num_nodes(mut self, n: u32) -> Self {
|
||||
self.num_nodes = Some(n);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the network address of the Global Collective Orchestrator
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts
|
||||
pub fn with_global_address(mut self, addr: Address) -> Self {
|
||||
self.global_address = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
/// Define the address for this node
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts
|
||||
pub fn with_node_address(mut self, addr: Address) -> Self {
|
||||
self.node_address = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects the network port on which to expose the tensor data service
|
||||
/// used for peer-to-peer tensor downloading.
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts
|
||||
pub fn with_data_service_port(mut self, port: u16) -> Self {
|
||||
self.data_service_port = Some(port);
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects an all-reduce strategy to use on the global level.
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts.
|
||||
/// See [the local strategy](Self::with_local_all_reduce_strategy)
|
||||
pub fn with_global_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {
|
||||
self.global_all_reduce_strategy = Some(strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects an reduce strategy to use on the global level.
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts.
|
||||
/// See [the local strategy](Self::with_local_reduce_strategy)
|
||||
pub fn with_global_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {
|
||||
self.global_reduce_strategy = Some(strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Selects an broadcst strategy to use on the global level.
|
||||
///
|
||||
/// This parameter is a global parameter and should only be set in multi-node contexts.
|
||||
/// See [the local strategy](Self::with_local_broadcast_strategy)
|
||||
pub fn with_global_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {
|
||||
self.global_broadcast_strategy = Some(strategy);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns whether the config is valid. If only some required global-level parameters are
|
||||
/// defined and others are not, the config is invalid.
|
||||
pub fn is_valid(&self) -> bool {
|
||||
match (
|
||||
self.num_nodes,
|
||||
&self.global_address,
|
||||
&self.node_address,
|
||||
self.data_service_port,
|
||||
) {
|
||||
(None, None, None, None) => true,
|
||||
(Some(_), Some(_), Some(_), Some(_)) => true,
|
||||
// Global parameters have only been partially defined!
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the global parameters for registering in a multi-node context.
|
||||
///
|
||||
/// If only some global parameters are defined, returns None. Use [is_valid](Self::is_valid) to check for
|
||||
/// validity in this case.
|
||||
pub(crate) fn global_register_params(&self) -> Option<GlobalRegisterParams> {
|
||||
match (
|
||||
self.num_nodes,
|
||||
&self.global_address,
|
||||
&self.node_address,
|
||||
self.data_service_port,
|
||||
) {
|
||||
// Only local collective
|
||||
(None, None, None, None) => None,
|
||||
// Local + global collective
|
||||
(Some(num_nodes), Some(global_addr), Some(node_addr), Some(data_service_port)) => {
|
||||
Some(GlobalRegisterParams {
|
||||
num_nodes,
|
||||
global_address: global_addr.clone(),
|
||||
node_address: node_addr.clone(),
|
||||
data_service_port,
|
||||
})
|
||||
}
|
||||
// Config is invalid!
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper struct for parameters in a multi-node register operation. Either they are all defined,
/// or all not defined. Passed to the global client for registering on the global level and
/// opening the p2p tensor service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalRegisterParams {
    /// The address for the connection to the global orchestrator.
    pub global_address: Address,
    /// The address for the connection to this node.
    pub node_address: Address,
    /// The port on which to open the tensor data service for peer-to-peer tensor transfers with
    /// other nodes. Should match the port given in the node url.
    pub data_service_port: u16,

    /// The number of nodes globally. Should be the same between different nodes
    pub num_nodes: u32,
}
|
||||
|
||||
/// Parameters for an all-reduce that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedAllReduceParams {
    /// The reduce operation to apply
    pub op: ReduceOperation,
    /// The all-reduce strategy on the local level
    pub local_strategy: AllReduceStrategy,
    /// The all-reduce strategy on the global level, if one is defined
    pub global_strategy: Option<AllReduceStrategy>,
}
|
||||
|
||||
/// Parameters for a reduce that should be the same between all devices.
///
/// Currently has no fields.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedReduceParams {}
|
||||
|
||||
/// Parameters for a broadcast that should be the same between all devices
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct SharedBroadcastParams {
    /// A reduce operation.
    // NOTE(review): it is unclear from here how a broadcast uses a reduce operation — confirm
    // whether this field is needed.
    pub op: ReduceOperation,
    /// The broadcast strategy on the local level
    pub local_strategy: BroadcastStrategy,
    /// The broadcast strategy on the global level, if one is defined
    pub global_strategy: Option<BroadcastStrategy>,
}
|
||||
|
||||
/// Reduce can be done different ways
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceOperation {
    /// The tensors are summed
    Sum,
    /// The tensors are summed, then the sum is divided by the number of participating devices
    Mean,
}
|
||||
|
||||
/// All reduce can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AllReduceStrategy {
    /// One device is the "central". The other devices, "peripherals", send their tensors to the
    /// central. The central does the reduction, and sends the result back to each peripheral.
    Centralized,

    /// Devices are organized in a tree structure (with a given arity). Each node reduces its
    /// children's tensors with its own, and sends the result to its parent. Leaf nodes will
    /// simply send their tensors to their parents.
    /// When the root node calculates the result, it is propagated down the tree.
    ///
    /// The payload is the arity of the tree.
    Tree(u32),

    /// Devices are organized in a ring. The tensors are split into N slices, where N is the
    /// number of devices participating. The slices are progressively sent around the ring until
    /// every device has one fully reduced slice of the tensor. Then, the resulting slices are sent
    /// around until every device has the full result.
    /// See `ring.rs` for details.
    Ring,
}
|
||||
|
||||
/// Reduce can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceStrategy {
    /// See [all-reduce](AllReduceStrategy::Centralized)
    Centralized,

    /// See [all-reduce](AllReduceStrategy::Tree). The payload is the arity of the tree.
    Tree(u32),
}
|
||||
|
||||
/// Broadcast can be implemented with different algorithms, which all have the same result.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum BroadcastStrategy {
    /// See [all-reduce](AllReduceStrategy::Centralized)
    Centralized,

    /// See [all-reduce](AllReduceStrategy::Tree). The payload is the arity of the tree.
    Tree(u32),
}
|
||||
|
||||
/// A unique identifier for a peer in the context of collective operations.
/// They must be unique, even in multi-node contexts.
///
/// This is like the rank in NCCL.
///
/// Built from an integer through the `From` conversions; the wrapped value is private.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId(u32);
|
||||
|
||||
impl Display for PeerId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PeerId({})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for PeerId {
|
||||
fn from(value: u32) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for PeerId {
|
||||
fn from(value: i32) -> Self {
|
||||
Self(value as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for PeerId {
|
||||
fn from(value: usize) -> Self {
|
||||
Self(value as u32)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Unique identifier for any node in the global collective.
///
/// Ids are totally ordered (`Ord` is derived on the wrapped `u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct NodeId(u32);
|
||||
|
||||
impl From<u32> for NodeId {
|
||||
fn from(value: u32) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for NodeId {
|
||||
fn from(value: usize) -> Self {
|
||||
Self(value as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for NodeId {
|
||||
fn from(value: i32) -> Self {
|
||||
Self(value as u32)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
pub(crate) mod node;
|
||||
pub(crate) mod shared;
|
||||
|
||||
#[cfg(feature = "orchestrator")]
|
||||
pub mod orchestrator;
|
||||
#[cfg(feature = "orchestrator")]
|
||||
pub use orchestrator::*;
|
||||
|
||||
mod base;
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,203 @@
|
||||
use burn_communication::Protocol;
|
||||
use burn_communication::data_service::TensorDataServer;
|
||||
use burn_communication::{Address, ProtocolServer, data_service::TensorDataService};
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::collections::HashMap;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::node::sync::SyncService;
|
||||
use crate::{
|
||||
AllReduceStrategy, BroadcastStrategy, GlobalRegisterParams, NodeId, PeerId, ReduceStrategy,
|
||||
};
|
||||
use crate::{
|
||||
ReduceOperation,
|
||||
global::{
|
||||
node::{
|
||||
centralized::centralized_all_reduce_sum, ring::ring_all_reduce_sum,
|
||||
tree::tree_all_reduce_sum, worker::GlobalClientWorker,
|
||||
},
|
||||
shared::{GlobalCollectiveError, RemoteRequest, RemoteResponse},
|
||||
},
|
||||
local::server::get_collective_server_runtime,
|
||||
};
|
||||
|
||||
/// Must be synchronized between all nodes for collective operations to work.
///
/// Filled in from the orchestrator's response to a register request.
pub(crate) struct NodeState {
    /// Id of this node
    pub node_id: NodeId,
    /// Address of each node, keyed by node id
    pub nodes: HashMap<NodeId, Address>,
    /// Device count reported by the orchestrator; used as the divisor for a `Mean` all-reduce
    pub num_global_devices: u32,
}
|
||||
|
||||
/// A node talks to the global orchestrator as well as other nodes with a peer-to-peer service
pub(crate) struct Node<B, P>
where
    B: Backend,
    P: Protocol,
{
    // State is written during `register` and read during other operations,
    // sometimes by multiple threads (ex. syncing during an all-reduce)
    state: Arc<RwLock<Option<NodeState>>>,
    // Service for peer-to-peer tensor transfers between nodes
    data_service: Arc<TensorDataService<B, P>>,
    // Service used to wait for the other nodes at the end of an operation
    sync_service: Arc<SyncService<P>>,
    // Sends requests to the global orchestrator
    worker: GlobalClientWorker<P::Client>,
    _n: PhantomData<P>,
}
|
||||
|
||||
impl<B, P> Node<B, P>
where
    B: Backend,
    P: Protocol,
{
    /// Creates the node and starts its services.
    ///
    /// Spawns the communication server on the collective server runtime, with the tensor
    /// data service and a `/sync` route, and creates the worker that talks to the global
    /// orchestrator at `global_address`. A single cancellation token is shared by the data
    /// service, the server and the worker.
    pub fn new(global_address: &Address, comms_server: P::Server) -> Self {
        let state = Arc::new(tokio::sync::RwLock::new(None));
        let cancel_token = CancellationToken::new();
        let data_service = Arc::new(TensorDataService::new(cancel_token.clone()));
        let sync_service = Arc::new(SyncService::new(state.clone()));

        let runtime = get_collective_server_runtime();
        let server = comms_server
            .route_tensor_data_service(data_service.clone())
            .route("/sync", {
                let sync_service = sync_service.clone();
                async move |channel: <P::Server as ProtocolServer>::Channel| {
                    sync_service.handle_sync_connection(channel).await;
                }
            })
            .serve({
                // The server's shutdown future resolves when the token is cancelled
                let cancel_token = cancel_token.clone();
                async move { cancel_token.cancelled().await }
            });

        runtime.spawn(server);

        let worker = GlobalClientWorker::new(&runtime, cancel_token.clone(), global_address);

        Self {
            state,
            data_service,
            sync_service,
            worker,
            _n: PhantomData,
        }
    }

    /// Registers this node and its local `peers` with the global orchestrator.
    ///
    /// On a `Register` response, the returned node id, node map and global device count are
    /// stored as the node state. An `Error` response is returned as the error; any other
    /// response is logged and reported as `WrongOrchestratorResponse`.
    pub async fn register(
        &mut self,
        peers: Vec<PeerId>,
        global_params: GlobalRegisterParams,
    ) -> Result<(), GlobalCollectiveError> {
        let req = RemoteRequest::Register {
            node_addr: global_params.node_address,
            num_nodes: global_params.num_nodes,
            peers,
        };
        match self.worker.request(req).await {
            RemoteResponse::Register {
                node_id,
                nodes,
                num_global_devices,
            } => {
                let mut state = self.state.write().await;
                *state = Some(NodeState {
                    node_id,
                    nodes,
                    num_global_devices,
                });
            }
            RemoteResponse::Error(err) => {
                return Err(err);
            }
            resp => {
                // NOTE(review): the message says "ack", but the response accepted above is
                // `RemoteResponse::Register`.
                log::error!("Response to a register request should be an ack, not {resp:?}");
                return Err(GlobalCollectiveError::WrongOrchestratorResponse);
            }
        }

        Ok(())
    }

    /// Performs an all-reduce
    ///
    /// Reads the NodeState; the read guard is held until the function returns. Fails with
    /// `AllReduceBeforeRegister` when the node has not been registered yet.
    ///
    /// Every strategy computes a sum. For `ReduceOperation::Mean`, the sum is then divided
    /// by the global device count from the node state.
    pub async fn all_reduce(
        &self,
        tensor: B::FloatTensorPrimitive,
        strategy: AllReduceStrategy,
        op: ReduceOperation,
    ) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError> {
        let state = self.state.read().await;
        let Some(ref state) = *state else {
            return Err(GlobalCollectiveError::AllReduceBeforeRegister);
        };
        let node = state.node_id;
        let nodes = &state.nodes;

        let mut result = match strategy {
            AllReduceStrategy::Centralized => {
                centralized_all_reduce_sum(
                    node,
                    nodes,
                    &self.data_service,
                    self.sync_service.clone(),
                    tensor,
                )
                .await?
            }
            AllReduceStrategy::Tree(arity) => {
                tree_all_reduce_sum(
                    node,
                    nodes,
                    self.data_service.clone(),
                    self.sync_service.clone(),
                    tensor,
                    arity,
                )
                .await?
            }
            AllReduceStrategy::Ring => {
                ring_all_reduce_sum(
                    node,
                    nodes,
                    self.data_service.clone(),
                    self.sync_service.clone(),
                    tensor,
                )
                .await?
            }
        };

        if op == ReduceOperation::Mean {
            result = B::float_div_scalar(result, (state.num_global_devices as f32).into());
        }

        Ok(result)
    }

    /// Global reduce. Not implemented yet: calling this panics.
    pub async fn reduce(
        &self,
        _tensor: B::FloatTensorPrimitive,
        _strategy: ReduceStrategy,
        _root: PeerId,
        _op: ReduceOperation,
    ) -> Result<Option<B::FloatTensorPrimitive>, GlobalCollectiveError> {
        unimplemented!("Global reduce unimplemented");
    }

    /// Global broadcast. Not implemented yet: calling this panics.
    pub async fn broadcast(
        &self,
        _tensor: Option<B::FloatTensorPrimitive>,
        _strategy: BroadcastStrategy,
    ) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError> {
        unimplemented!("Global broadcast unimplemented");
    }

    /// Closes the connection to the global orchestrator, then closes the data service.
    ///
    /// A failure to close the orchestrator connection is only logged.
    pub async fn finish(&mut self) {
        let res = self.worker.close_connection().await;
        if let Err(err) = res {
            log::error!("Global collective client error: {err:?}");
        }
        self.data_service.close().await;
    }
}
|
||||
@@ -0,0 +1,96 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService};
|
||||
use burn_communication::data_service::TensorDataService;
|
||||
use burn_communication::{Address, Protocol};
|
||||
use burn_tensor::TensorMetadata;
|
||||
use burn_tensor::backend::Backend;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
|
||||
/// Global all-reduce, using a centralized strategy.
|
||||
///
|
||||
/// Returns the resulting tensor on the same device as the input tensor
|
||||
pub(crate) async fn centralized_all_reduce_sum<B, P>(
|
||||
node: NodeId,
|
||||
nodes: &HashMap<NodeId, Address>,
|
||||
data_service: &Arc<TensorDataService<B, P>>,
|
||||
sync_service: Arc<SyncService<P>>,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>
|
||||
where
|
||||
B: Backend,
|
||||
P: Protocol,
|
||||
{
|
||||
let ids = nodes.keys().cloned().collect::<Vec<_>>();
|
||||
let central = get_central_node(ids.clone());
|
||||
|
||||
let shape = tensor.shape();
|
||||
let device = &B::float_device(&tensor);
|
||||
|
||||
let res = if central == node {
|
||||
// Transfer 1: download tensors from other nodes
|
||||
let mut futures = ids
|
||||
.iter()
|
||||
.filter(|id| **id != central) // Only non-central nodes
|
||||
.map(|id| {
|
||||
let address = nodes.get(id).unwrap();
|
||||
let device = device.clone();
|
||||
let data_service = data_service.clone();
|
||||
async move {
|
||||
let data = data_service
|
||||
.download_tensor((*address).clone(), 0.into())
|
||||
.await
|
||||
.expect("Couldn't find the tensor for transfer id 0");
|
||||
B::float_from_data(data, &device)
|
||||
}
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>();
|
||||
|
||||
// Sum all downloads async
|
||||
let mut sum = tensor;
|
||||
while let Some(res) = futures.next().await {
|
||||
if shape != res.shape() {
|
||||
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
|
||||
}
|
||||
sum = B::float_add(sum, res);
|
||||
}
|
||||
|
||||
// Transfer 2: Expose result
|
||||
let other_nodes_count = ids.len() as u32 - 1;
|
||||
data_service
|
||||
.expose(sum.clone(), other_nodes_count, 1.into())
|
||||
.await;
|
||||
|
||||
sum
|
||||
} else {
|
||||
// Transfer 1: Expose input
|
||||
data_service.expose(tensor, 1, 0.into()).await;
|
||||
|
||||
// Transfer 2: Download result
|
||||
let central_addr = nodes.get(¢ral).unwrap().clone();
|
||||
let data = data_service
|
||||
.download_tensor(central_addr, 1.into())
|
||||
.await
|
||||
.expect("Couldn't find the tensor for transfer id 1");
|
||||
|
||||
let res = B::float_from_data(data, device);
|
||||
if shape != res.shape() {
|
||||
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
|
||||
}
|
||||
|
||||
res
|
||||
};
|
||||
|
||||
// Wait for all nodes to finish
|
||||
sync_service.sync().await;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Get the central node for a centralized all-reduce
|
||||
pub(crate) fn get_central_node(mut nodes: Vec<NodeId>) -> NodeId {
|
||||
nodes.sort();
|
||||
|
||||
*nodes.first().unwrap()
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
pub mod base;
|
||||
pub mod centralized;
|
||||
pub mod ring;
|
||||
pub mod sync;
|
||||
pub mod tree;
|
||||
pub mod worker;
|
||||
@@ -0,0 +1,216 @@
|
||||
//! Implements the collective ring all-reduce algorithm on the global level
|
||||
|
||||
use core::ops::Range;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
NodeId,
|
||||
global::shared::GlobalCollectiveError,
|
||||
local::{get_ring_reduce_slice_ranges, get_slice_dim},
|
||||
node::sync::SyncService,
|
||||
};
|
||||
use burn_communication::{Address, Protocol, data_service::TensorDataService};
|
||||
use burn_tensor::{Slice, TensorMetadata, backend::Backend};
|
||||
|
||||
// https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model
|
||||
|
||||
// Example: tensors=3, slices=3
|
||||
|
||||
// phase 1
|
||||
// o->o o
|
||||
// o o->o
|
||||
//>o o o->
|
||||
|
||||
// o 1->o
|
||||
//>o o 1->
|
||||
// 1->o o
|
||||
|
||||
// o 1 2
|
||||
// 2 o 1
|
||||
// 1 2 o
|
||||
|
||||
// phase 2
|
||||
//>o 1 2->
|
||||
// 2->o 1
|
||||
// 1 2->o
|
||||
|
||||
// 2->1 2
|
||||
// 2 2->1
|
||||
//>1 2 2->
|
||||
|
||||
// 2 2 2
|
||||
// 2 2 2
|
||||
// 2 2 2
|
||||
|
||||
/// Ring all-reduce algorithm with summation
///
/// The tensor is cut into one slice per node along the dimension picked by `get_slice_dim`.
/// Two rounds of cycles are run (phase 1 adds received slices, phase 2 replaces them), each
/// slice being downloaded from the previous node in the ring. The slices are then
/// concatenated back along the same dimension.
///
/// * `node` - The id of the current node
/// * `nodes` - Map of all nodes in the operation
/// * `data_service` - The data service handles peer-to-peer tensor transfers
/// * `sync_service` - The sync service handles syncing with peers
/// * `tensor` - The tensor to reduce. The size of the sliced dimension must be at least the
///   number of nodes, otherwise `RingReduceImpossible` is returned
pub(crate) async fn ring_all_reduce_sum<B, P>(
    node: NodeId,
    nodes: &HashMap<NodeId, Address>,
    data_service: Arc<TensorDataService<B, P>>,
    sync_service: Arc<SyncService<P>>,
    tensor: B::FloatTensorPrimitive,
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>
where
    B: Backend,
    P: Protocol,
{
    let shape = tensor.shape();

    let device = &B::float_device(&tensor);
    // Slice tensors in N parts, N is node count
    let slice_dim = get_slice_dim(&shape);
    if shape[slice_dim] < nodes.len() {
        return Err(GlobalCollectiveError::RingReduceImpossible);
    }

    let ring = get_ring_topology(nodes.keys().cloned().collect::<Vec<_>>());
    let slice_ranges = get_ring_reduce_slice_ranges(shape[slice_dim], ring.len());
    let mut slices = slice_tensor::<B>(tensor, slice_dim, slice_ranges);

    // The first slice this node sends is the one at its own position in the ring
    let mut send_slice_idx = ring
        .iter()
        .position(|id| *id == node)
        .expect("Node is in ring");
    // Adding ring.len() before subtracting keeps the index from underflowing at position 0
    let prev_node_idx = (send_slice_idx + ring.len() - 1) % ring.len();
    let prev_node = nodes.get(&ring[prev_node_idx]).unwrap();
    // Shared by both phases, so phase 2 continues from the value phase 1 left
    let mut transfer_counter: u64 = 0;

    // Phase 1: add
    do_cycles::<B, P>(
        &mut slices,
        &mut transfer_counter,
        &mut send_slice_idx,
        true,
        prev_node.clone(),
        &data_service,
        device,
    )
    .await?;

    // Phase 2: replace
    do_cycles::<B, P>(
        &mut slices,
        &mut transfer_counter,
        &mut send_slice_idx,
        false,
        prev_node.clone(),
        &data_service,
        device,
    )
    .await?;

    // Wait for all nodes to finish
    sync_service.sync().await;

    // merge slices
    Ok(B::float_cat(slices, slice_dim))
}
|
||||
|
||||
/// Do N-1 cycles of ring-reduce
|
||||
///
|
||||
/// * `slices` - Slices of the original tensor, len equal to node count
|
||||
/// * `transfer_counter` - counter for each step (one send one receive)
|
||||
/// * `send_slice_idx` - counter for the index of each slice to send
|
||||
/// * `is_phase_one` - In phase 1, the tensors are aggregated. Otherwise, they are overridden
|
||||
/// * `data_service` - TensorDataService for peer-to-peer tensor transfers
|
||||
/// * `device` - The device on which all local tensors are stored. Should match `slices`
|
||||
async fn do_cycles<B, P>(
|
||||
slices: &mut [B::FloatTensorPrimitive],
|
||||
transfer_counter: &mut u64,
|
||||
send_slice_idx: &mut usize,
|
||||
is_phase_one: bool,
|
||||
prev_node: Address,
|
||||
data_service: &Arc<TensorDataService<B, P>>,
|
||||
device: &B::Device,
|
||||
) -> Result<(), GlobalCollectiveError>
|
||||
where
|
||||
B: Backend,
|
||||
P: Protocol,
|
||||
{
|
||||
let slice_count = slices.len();
|
||||
for _ in 0..(slice_count - 1) {
|
||||
let transfer_id = (*transfer_counter).into();
|
||||
// +slice_count to avoid overflow
|
||||
let recv_slice_idx = (*send_slice_idx + slice_count - 1) % slice_count;
|
||||
let slice_send = slices[*send_slice_idx].clone();
|
||||
|
||||
let upload = {
|
||||
let data_service = data_service.clone();
|
||||
tokio::spawn(async move {
|
||||
data_service
|
||||
.expose(slice_send.clone(), 1, transfer_id)
|
||||
.await
|
||||
})
|
||||
};
|
||||
let download = {
|
||||
let data_client = data_service.clone();
|
||||
let next_node = prev_node.clone();
|
||||
tokio::spawn(async move { data_client.download_tensor(next_node, transfer_id).await })
|
||||
};
|
||||
|
||||
upload.await.unwrap();
|
||||
let download = download.await.unwrap();
|
||||
if is_phase_one {
|
||||
let download = download.expect("Peer closed download connection");
|
||||
let tensor = B::float_from_data(download, device);
|
||||
slices[recv_slice_idx] = B::float_add(slices[recv_slice_idx].clone(), tensor);
|
||||
} else {
|
||||
let tensor = B::float_from_data(download.unwrap(), device);
|
||||
let old_shape = slices[recv_slice_idx].shape();
|
||||
if old_shape != tensor.shape() {
|
||||
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
|
||||
}
|
||||
slices[recv_slice_idx] = tensor;
|
||||
}
|
||||
|
||||
// Move slice index
|
||||
*send_slice_idx = recv_slice_idx;
|
||||
*transfer_counter += 1;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// But a tensor into even slices across a dimension
|
||||
///
|
||||
/// * `tensor` - the tensor to slice
|
||||
/// * `slice_dim` - the dimension to slice across
|
||||
/// * `slice_ranges` - The ranges of indices on `slice_dim` to use when slicing the tensor
|
||||
fn slice_tensor<B: Backend>(
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
slice_dim: usize,
|
||||
slice_ranges: Vec<Range<usize>>,
|
||||
) -> Vec<B::FloatTensorPrimitive> {
|
||||
let shape = tensor.shape();
|
||||
// full range across all dims as Slice
|
||||
let full_range = shape
|
||||
.iter()
|
||||
.map(|dim| Slice::from(0..*dim))
|
||||
.collect::<Vec<Slice>>();
|
||||
|
||||
// Slice tensors
|
||||
let mut slices = vec![];
|
||||
for range in &slice_ranges {
|
||||
let mut all_ranges = full_range.clone();
|
||||
all_ranges[slice_dim] = Slice::from(range.clone());
|
||||
let slice = B::float_slice(tensor.clone(), &all_ranges);
|
||||
slices.push(slice);
|
||||
}
|
||||
|
||||
slices
|
||||
}
|
||||
|
||||
/// Get the ring topology
|
||||
fn get_ring_topology(mut nodes: Vec<NodeId>) -> Vec<NodeId> {
|
||||
// This ordering could be more sophisticated, using node proximities etc
|
||||
nodes.sort();
|
||||
|
||||
nodes
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
sync::{Arc, Mutex},
|
||||
vec,
|
||||
};
|
||||
|
||||
use burn_communication::{CommunicationChannel, Message, Protocol, ProtocolClient};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::{Notify, RwLock};
|
||||
|
||||
use crate::{NodeId, node::base::NodeState};
|
||||
|
||||
/// Handles the status of sync requests from other nodes
|
||||
pub(crate) struct SyncService<P: Protocol> {
|
||||
/// Current node's state, shared with the thread that does aggregations
|
||||
node_state: Arc<RwLock<Option<NodeState>>>,
|
||||
/// The number of peers that have requested to sync with us since the last successful sync.
|
||||
syncing_peers: Mutex<Vec<NodeId>>,
|
||||
/// Notification on each incoming sync request
|
||||
sync_notif: Notify,
|
||||
|
||||
_p: PhantomData<P>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SyncRequest(NodeId);
|
||||
|
||||
impl<P: Protocol> SyncService<P> {
|
||||
pub fn new(node_state: Arc<RwLock<Option<NodeState>>>) -> Self {
|
||||
Self {
|
||||
node_state,
|
||||
syncing_peers: Mutex::new(vec![]),
|
||||
sync_notif: Notify::new(),
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_syncing_peer(&self, peer: NodeId) {
|
||||
let mut syncing_peers = self.syncing_peers.lock().unwrap();
|
||||
syncing_peers.push(peer);
|
||||
}
|
||||
|
||||
/// Sync with all peers.
|
||||
pub async fn sync(&self) {
|
||||
// we can't sync while we register
|
||||
let node_state = self.node_state.read().await;
|
||||
let node_state = node_state
|
||||
.as_ref()
|
||||
.expect("Trying to sync a node before having registered to the orchestrator");
|
||||
|
||||
// this peer is syncing
|
||||
self.add_syncing_peer(node_state.node_id);
|
||||
for (id, addr) in &node_state.nodes {
|
||||
if *id == node_state.node_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut connection = P::Client::connect(addr.clone(), "sync")
|
||||
.await
|
||||
.expect("Couldn't connect to peer for sync");
|
||||
let msg = SyncRequest(node_state.node_id);
|
||||
let sync_bytes = rmp_serde::to_vec(&msg).unwrap();
|
||||
connection
|
||||
.send(Message::new(sync_bytes.into()))
|
||||
.await
|
||||
.expect("Peer closed connection unexpectedly");
|
||||
}
|
||||
loop {
|
||||
{
|
||||
// compare currently synced peers with list of all nodes
|
||||
let mut syncing_peers = self.syncing_peers.lock().unwrap().to_vec();
|
||||
syncing_peers.sort();
|
||||
|
||||
let mut all_node_ids = node_state.nodes.keys().cloned().collect::<Vec<_>>();
|
||||
all_node_ids.sort();
|
||||
|
||||
if syncing_peers == all_node_ids {
|
||||
// all nodes have synced
|
||||
syncing_peers.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Wait for the next sync to come in
|
||||
self.sync_notif.notified().await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_sync_connection<C: CommunicationChannel>(&self, mut channel: C) {
|
||||
let msg = channel.recv().await.unwrap();
|
||||
let Some(msg) = msg else {
|
||||
return;
|
||||
};
|
||||
|
||||
let msg = rmp_serde::from_slice::<SyncRequest>(&msg.data).unwrap();
|
||||
|
||||
self.add_syncing_peer(msg.0);
|
||||
|
||||
self.sync_notif.notify_waiters();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService};
|
||||
use burn_communication::{Address, Protocol, data_service::TensorDataService};
|
||||
use burn_tensor::{TensorMetadata, backend::Backend};
|
||||
use futures::{StreamExt, stream::FuturesUnordered};
|
||||
|
||||
struct TreeTopology {
|
||||
parents: HashMap<NodeId, NodeId>,
|
||||
children: HashMap<NodeId, Vec<NodeId>>,
|
||||
}
|
||||
|
||||
/// Global all-reduce, using a b-tree strategy.
|
||||
///
|
||||
/// Returns the resulting tensor on the same device as the input tensor
|
||||
pub(crate) async fn tree_all_reduce_sum<B, P>(
|
||||
node: NodeId,
|
||||
nodes: &HashMap<NodeId, Address>,
|
||||
data_service: Arc<TensorDataService<B, P>>,
|
||||
sync_service: Arc<SyncService<P>>,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
arity: u32,
|
||||
) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>
|
||||
where
|
||||
B: Backend,
|
||||
P: Protocol,
|
||||
{
|
||||
let shape = tensor.shape();
|
||||
let device = &B::float_device(&tensor);
|
||||
|
||||
// Topology could be cached based on (nodes.keys().cloned(), arity)
|
||||
let strategy = get_tree_topology(nodes.keys().cloned().collect::<Vec<_>>(), arity);
|
||||
|
||||
// Transfer 1: Download and sum tensors from children
|
||||
let mut result = tensor;
|
||||
|
||||
if let Some(children) = strategy.children.get(&node) {
|
||||
let mut downloads = children
|
||||
.iter()
|
||||
.map(|child| {
|
||||
let child_addr = nodes.get(child).unwrap().clone();
|
||||
let data_service = data_service.clone();
|
||||
async move {
|
||||
let data = data_service
|
||||
.download_tensor(child_addr.clone(), 0.into())
|
||||
.await
|
||||
.ok_or(GlobalCollectiveError::PeerLost(*child))?;
|
||||
Ok::<B::FloatTensorPrimitive, GlobalCollectiveError>(B::float_from_data(
|
||||
data, device,
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>();
|
||||
|
||||
for _ in children {
|
||||
let res = downloads.next().await.unwrap().unwrap();
|
||||
if res.shape() != shape {
|
||||
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
|
||||
}
|
||||
result = B::float_add(result, res);
|
||||
}
|
||||
}
|
||||
|
||||
// Transfer 2: Expose result to parent and download final result if not root
|
||||
if let Some(parent) = strategy.parents.get(&node) {
|
||||
data_service.expose(result.clone(), 1, 0.into()).await;
|
||||
|
||||
let parent_addr = nodes.get(parent).unwrap().clone();
|
||||
|
||||
let data = data_service
|
||||
.download_tensor(parent_addr.clone(), 1.into())
|
||||
.await
|
||||
.ok_or(GlobalCollectiveError::PeerLost(*parent))?;
|
||||
|
||||
let parent_tensor = B::float_from_data(data, device);
|
||||
if parent_tensor.shape() != shape {
|
||||
return Err(GlobalCollectiveError::PeerSentIncoherentTensor);
|
||||
}
|
||||
result = parent_tensor;
|
||||
}
|
||||
|
||||
// Transfer 3: Expose final result to children (if any)
|
||||
if let Some(children) = strategy.children.get(&node)
|
||||
&& !children.is_empty()
|
||||
{
|
||||
data_service
|
||||
.expose(result.clone(), children.len() as u32, 1.into())
|
||||
.await;
|
||||
}
|
||||
|
||||
// Final barrier
|
||||
sync_service.sync().await;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get the tree topology.
|
||||
///
|
||||
/// * `nodes` - List of node ids. Order doesn't matter. Nodes must be unique.
|
||||
fn get_tree_topology(mut nodes: Vec<NodeId>, arity: u32) -> TreeTopology {
|
||||
assert!(arity >= 1, "Arity must be ≥ 1");
|
||||
|
||||
nodes.sort(); // Sort
|
||||
|
||||
let n = nodes.len();
|
||||
let k = arity as usize;
|
||||
|
||||
let mut parents: HashMap<_, _> = HashMap::with_capacity(n);
|
||||
let mut children: HashMap<_, _> = HashMap::with_capacity(n);
|
||||
|
||||
for (i, &parent_id) in nodes.iter().enumerate() {
|
||||
// compute the window [first_child, last_child)
|
||||
let first = i * k + 1;
|
||||
if first < n {
|
||||
let last = usize::min(first + k, n);
|
||||
let mut ch = Vec::with_capacity(last - first);
|
||||
for &child_id in &nodes[first..last] {
|
||||
parents.insert(child_id, parent_id);
|
||||
ch.push(child_id);
|
||||
}
|
||||
children.insert(parent_id, ch);
|
||||
} else {
|
||||
// leaf‐node: no children
|
||||
children.insert(parent_id, Vec::new());
|
||||
}
|
||||
}
|
||||
|
||||
TreeTopology { parents, children }
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the parent and child tables produced for 7 nodes with arity 2
    #[test]
    fn test_get_tree_topology_arity2_size7() {
        let mut nodes = vec![];
        for i in 0..7 {
            nodes.push(i.into());
        }

        let topology = get_tree_topology(nodes, 2);

        // Node 0 is the root and therefore has no parent
        assert!(!topology.parents.contains_key(&0.into()));

        // (child, parent) pairs: 1 and 2 hang under 0, 3 and 4 under 1, 5 and 6 under 2
        let expected_parents: [(NodeId, NodeId); 6] = [
            (1.into(), 0.into()),
            (2.into(), 0.into()),
            (3.into(), 1.into()),
            (4.into(), 1.into()),
            (5.into(), 2.into()),
            (6.into(), 2.into()),
        ];
        for (child, parent) in &expected_parents {
            assert_eq!(
                topology.parents.get(child),
                Some(parent),
                "wrong parent for {child:?}"
            );
        }
        // No parent entries beyond the expected ones
        assert_eq!(topology.parents.len(), expected_parents.len());

        // Inner nodes: 0 → [1, 2], 1 → [3, 4], 2 → [5, 6]
        let expected_children: [(NodeId, Vec<NodeId>); 3] = [
            (0.into(), vec![1.into(), 2.into()]),
            (1.into(), vec![3.into(), 4.into()]),
            (2.into(), vec![5.into(), 6.into()]),
        ];
        for (parent, kids) in &expected_children {
            assert_eq!(topology.children.get(parent), Some(kids));
        }

        // Leaves 3, 4, 5 and 6 have an entry with no children
        for leaf in 3..7 {
            assert_eq!(
                topology.children.get(&leaf.into()),
                Some(&Vec::new()),
                "leaf {leaf:?} should have no children"
            );
        }

        // One children entry per node
        assert_eq!(topology.children.len(), 7);
    }
}
|
||||
@@ -0,0 +1,297 @@
|
||||
use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration};
|
||||
|
||||
use burn_communication::{Address, CommunicationChannel, Message, ProtocolClient};
|
||||
use tokio::{
|
||||
runtime::Runtime,
|
||||
sync::{
|
||||
Mutex,
|
||||
mpsc::{Receiver, Sender},
|
||||
},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::global::shared::{
|
||||
CollectiveMessage, CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest,
|
||||
RemoteResponse, RequestId, SessionId,
|
||||
};
|
||||
|
||||
/// Worker that handles communication with the orchestrator for global collective operations.
|
||||
pub(crate) struct GlobalClientWorker<P: ProtocolClient> {
|
||||
handle: Option<JoinHandle<Result<(), GlobalCollectiveError>>>,
|
||||
cancel_token: CancellationToken,
|
||||
request_sender: Sender<ClientRequest>,
|
||||
_phantom_data: PhantomData<P>,
|
||||
}
|
||||
|
||||
// Rename
|
||||
struct GlobalClientWorkerState {
|
||||
requests: HashMap<RequestId, Sender<RemoteResponse>>,
|
||||
}
|
||||
|
||||
impl GlobalClientWorkerState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
requests: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ClientRequest {
|
||||
pub request: RemoteRequest,
|
||||
pub callback: Sender<RemoteResponse>,
|
||||
}
|
||||
|
||||
impl ClientRequest {
|
||||
pub(crate) fn new(request: RemoteRequest, callback: Sender<RemoteResponse>) -> Self {
|
||||
Self { request, callback }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProtocolClient> GlobalClientWorker<C> {
|
||||
/// Create a new global client worker and start the tasks.
|
||||
pub(crate) fn new(
|
||||
runtime: &Runtime,
|
||||
cancel_token: CancellationToken,
|
||||
global_address: &Address,
|
||||
) -> Self {
|
||||
let (request_sender, request_recv) = tokio::sync::mpsc::channel::<ClientRequest>(10);
|
||||
|
||||
let state = Arc::new(Mutex::new(GlobalClientWorkerState::new()));
|
||||
|
||||
let handle = runtime.spawn(Self::start(
|
||||
state,
|
||||
cancel_token.clone(),
|
||||
global_address.clone(),
|
||||
request_recv,
|
||||
));
|
||||
|
||||
Self {
|
||||
handle: Some(handle),
|
||||
cancel_token,
|
||||
request_sender,
|
||||
_phantom_data: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the global client tasks
|
||||
async fn start(
|
||||
state: Arc<Mutex<GlobalClientWorkerState>>,
|
||||
cancel_token: CancellationToken,
|
||||
global_address: Address,
|
||||
request_recv: Receiver<ClientRequest>,
|
||||
) -> Result<(), GlobalCollectiveError> {
|
||||
// Init the connection.
|
||||
let (request, response) = Self::init_connection(&global_address).await?;
|
||||
|
||||
// Websocket async worker loading responses from the server.
|
||||
let response_handle = tokio::spawn(Self::response_loader(
|
||||
state.clone(),
|
||||
response,
|
||||
cancel_token.clone(),
|
||||
));
|
||||
|
||||
// Channel async worker sending operations to the server.
|
||||
let request_handle = tokio::spawn(Self::request_sender(
|
||||
request_recv,
|
||||
state,
|
||||
request,
|
||||
cancel_token.clone(),
|
||||
));
|
||||
|
||||
if let Err(e) = response_handle.await {
|
||||
log::error!("Response handler failed: {e:?}");
|
||||
}
|
||||
if let Err(e) = request_handle.await {
|
||||
log::error!("Request handler failed: {e:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn init_connection(
|
||||
address: &Address,
|
||||
) -> Result<(C::Channel, C::Channel), GlobalCollectiveError> {
|
||||
let session_id = SessionId::new();
|
||||
|
||||
let stream_request = tokio::spawn(Self::connect_with_retry(
|
||||
address.clone(),
|
||||
"request",
|
||||
std::time::Duration::from_secs(1),
|
||||
None,
|
||||
session_id,
|
||||
));
|
||||
let stream_response = tokio::spawn(Self::connect_with_retry(
|
||||
address.clone(),
|
||||
"response",
|
||||
std::time::Duration::from_secs(1),
|
||||
None,
|
||||
session_id,
|
||||
));
|
||||
|
||||
let Ok(Some(request)) = stream_request.await else {
|
||||
return Err(GlobalCollectiveError::OrchestratorUnreachable);
|
||||
};
|
||||
let Ok(Some(response)) = stream_response.await else {
|
||||
return Err(GlobalCollectiveError::OrchestratorUnreachable);
|
||||
};
|
||||
|
||||
Ok((request, response))
|
||||
}
|
||||
|
||||
/// Connect with websocket with retries.
|
||||
async fn connect_with_retry(
|
||||
address: Address,
|
||||
route: &str,
|
||||
retry_pause: Duration,
|
||||
retry_max: Option<u32>,
|
||||
session_id: SessionId,
|
||||
) -> Option<C::Channel> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
if let Some(max) = retry_max
|
||||
&& retries >= max
|
||||
{
|
||||
log::warn!("Failed to connect to {address} after {max} retries.");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Try to connect to the request address.
|
||||
println!("Connecting to {address} ...");
|
||||
let result = C::connect(address.clone(), route).await;
|
||||
|
||||
if let Some(mut stream) = result {
|
||||
let init_msg = CollectiveMessage::Init(session_id);
|
||||
let bytes: bytes::Bytes = rmp_serde::to_vec(&init_msg).unwrap().into();
|
||||
stream
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Can send the init message on the websocket.");
|
||||
return Some(stream);
|
||||
}
|
||||
|
||||
println!("Failed to connect to {address}, retrying... Attempt #{retries}");
|
||||
tokio::time::sleep(retry_pause).await;
|
||||
retries += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Unregister the worker and close the connection.
|
||||
pub(crate) async fn close_connection(&mut self) -> Result<(), GlobalCollectiveError> {
|
||||
if let Some(handle) = self.handle.take() {
|
||||
// Un-register from server
|
||||
let req = RemoteRequest::Finish;
|
||||
let resp = self.request(req).await;
|
||||
if resp != RemoteResponse::FinishAck {
|
||||
log::error!("Requested to finish, did not get FinishAck; got {resp:?}");
|
||||
return Err(GlobalCollectiveError::WrongOrchestratorResponse);
|
||||
}
|
||||
|
||||
self.cancel_token.cancel();
|
||||
|
||||
if let Err(e) = handle.await.unwrap() {
|
||||
log::error!("Connection error {e:?}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn response_loader(
|
||||
state: Arc<Mutex<GlobalClientWorkerState>>,
|
||||
mut stream_response: C::Channel,
|
||||
cancel_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Check if the cancel token is cancelled
|
||||
_ = cancel_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
// .. Or get a message from the websocket
|
||||
response = stream_response.recv() => {
|
||||
match response {
|
||||
Err(err) => {
|
||||
log::error!("Error receiving message from websocket: {err:?}");
|
||||
break;
|
||||
}
|
||||
Ok(response) => {
|
||||
let Some(response) = response else {
|
||||
log::warn!("Closed connection");
|
||||
break;
|
||||
};
|
||||
|
||||
let response: CollectiveMessageResponse = rmp_serde::from_slice(&response.data)
|
||||
.expect("Can deserialize messages from the websocket.");
|
||||
let state_resp = state.lock().await;
|
||||
let response_callback = state_resp
|
||||
.requests
|
||||
.get(&response.request_id)
|
||||
.expect("Got a response to an unknown request");
|
||||
response_callback.send(response.content).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Worker closing connection");
|
||||
stream_response
|
||||
.close()
|
||||
.await
|
||||
.expect("Can close the websocket stream.");
|
||||
}
|
||||
|
||||
async fn request_sender(
|
||||
mut request_recv: Receiver<ClientRequest>,
|
||||
worker: Arc<Mutex<GlobalClientWorkerState>>,
|
||||
mut stream_request: C::Channel,
|
||||
cancel_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel_token.cancelled() => {
|
||||
break;
|
||||
},
|
||||
request = request_recv.recv() => {
|
||||
let Some(request) = request else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let id = RequestId::new();
|
||||
|
||||
// Register the callback if there is one
|
||||
{
|
||||
let mut state = worker.lock().await;
|
||||
state.requests.insert(id, request.callback);
|
||||
}
|
||||
|
||||
let request = CollectiveMessage::Request(id, request.request);
|
||||
|
||||
let bytes = rmp_serde::to_vec::<CollectiveMessage>(&request)
|
||||
.expect("Can serialize tasks to bytes.")
|
||||
.into();
|
||||
stream_request
|
||||
.send(Message::new(bytes))
|
||||
.await
|
||||
.expect("Can send the message on the websocket.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Worker closing connection");
|
||||
stream_request
|
||||
.close()
|
||||
.await
|
||||
.expect("Can send the close message on the websocket.");
|
||||
}
|
||||
|
||||
pub(crate) async fn request(&self, req: RemoteRequest) -> RemoteResponse {
|
||||
let (callback, mut response_recv) = tokio::sync::mpsc::channel::<RemoteResponse>(10);
|
||||
let client_req = ClientRequest::new(req, callback);
|
||||
self.request_sender.send(client_req).await.unwrap();
|
||||
|
||||
response_recv.recv().await.unwrap()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::global::{
|
||||
orchestrator::state::GlobalCollectiveState,
|
||||
shared::{CollectiveMessage, GlobalCollectiveError},
|
||||
};
|
||||
use burn_communication::{
|
||||
CommunicationChannel, Message, ProtocolServer, util::os_shutdown_signal, websocket::WsServer,
|
||||
};
|
||||
|
||||
/// The global collective state manages collective operations on the global level
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct GlobalOrchestrator {
|
||||
state: Arc<Mutex<GlobalCollectiveState>>,
|
||||
}
|
||||
|
||||
impl GlobalOrchestrator {
|
||||
/// Starts the comms server with two routes: "/request" and "/response"
|
||||
pub(crate) async fn start<F, S: ProtocolServer + Debug>(
|
||||
shutdown_signal: F,
|
||||
comms_server: S,
|
||||
) -> Result<(), GlobalCollectiveError>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let state = GlobalCollectiveState::new();
|
||||
let server = Self {
|
||||
state: Arc::new(tokio::sync::Mutex::new(state)),
|
||||
};
|
||||
|
||||
comms_server
|
||||
.route("/response", {
|
||||
let server = server.clone();
|
||||
async move |socket| {
|
||||
if let Err(err) = server.handle_socket_response::<S>(socket).await {
|
||||
log::error!("[Response Handler] Error: {err:?}")
|
||||
}
|
||||
}
|
||||
})
|
||||
.route("/request", {
|
||||
let server = server.clone();
|
||||
async move |socket| {
|
||||
if let Err(err) = server.handle_socket_request::<S>(socket).await {
|
||||
log::error!("[Request Handler] Error: {err:?}")
|
||||
}
|
||||
}
|
||||
})
|
||||
.serve(shutdown_signal)
|
||||
.await
|
||||
.map_err(|err| GlobalCollectiveError::Server(format!("{err:?}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_socket_response<S: ProtocolServer>(
|
||||
self,
|
||||
mut stream: S::Channel,
|
||||
) -> Result<(), GlobalCollectiveError> {
|
||||
log::info!("[Response Handler] On new connection.");
|
||||
|
||||
let msg = stream
|
||||
.recv()
|
||||
.await
|
||||
.map_err(|err| GlobalCollectiveError::Server(format!("{err:?}")))?;
|
||||
let Some(msg) = msg else {
|
||||
log::warn!("Response socket closed early!");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let msg = rmp_serde::from_slice::<CollectiveMessage>(&msg.data)
|
||||
.map_err(|_| GlobalCollectiveError::InvalidMessage)?;
|
||||
|
||||
let CollectiveMessage::Init(id) = msg else {
|
||||
return Err(GlobalCollectiveError::FirstMsgNotInit);
|
||||
};
|
||||
|
||||
let mut receiver = {
|
||||
let mut state = self.state.lock().await;
|
||||
state.get_session_responder(id)
|
||||
};
|
||||
|
||||
while let Some(response) = receiver.recv().await {
|
||||
let bytes = rmp_serde::to_vec(&response).unwrap();
|
||||
|
||||
stream.send(Message::new(bytes.into())).await?;
|
||||
}
|
||||
|
||||
log::info!("[Response Handler] Closing connection.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_socket_request<S: ProtocolServer>(
|
||||
self,
|
||||
mut stream: S::Channel,
|
||||
) -> Result<(), GlobalCollectiveError> {
|
||||
log::info!("[Request Handler] On new connection.");
|
||||
|
||||
let mut session_id = None;
|
||||
|
||||
loop {
|
||||
let packet = stream.recv().await?;
|
||||
let Some(msg) = packet else {
|
||||
log::info!("Peer closed the connection");
|
||||
break;
|
||||
};
|
||||
|
||||
let mut state = self.state.lock().await;
|
||||
|
||||
let msg = rmp_serde::from_slice::<CollectiveMessage>(&msg.data)
|
||||
.map_err(|_| GlobalCollectiveError::InvalidMessage)?;
|
||||
match msg {
|
||||
CollectiveMessage::Init(id) => {
|
||||
state.init_session(id);
|
||||
session_id = Some(id);
|
||||
}
|
||||
CollectiveMessage::Request(request_id, remote_request) => {
|
||||
let session_id = session_id.ok_or(GlobalCollectiveError::FirstMsgNotInit)?;
|
||||
state
|
||||
.process_request(session_id, request_id, remote_request)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a global orchestrator with WebSocket on the given port
|
||||
pub async fn start_global_orchestrator(port: u16) {
|
||||
let server = WsServer::new(port);
|
||||
let res = GlobalOrchestrator::start(os_shutdown_signal(), server).await;
|
||||
if let Err(err) = res {
|
||||
log::error!("Global Collective Orchestrator error: {err:?}");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub(crate) mod base;
|
||||
pub(crate) mod state;
|
||||
|
||||
pub use base::start_global_orchestrator;
|
||||
@@ -0,0 +1,219 @@
|
||||
use crate::{
|
||||
PeerId,
|
||||
global::{
|
||||
NodeId,
|
||||
shared::{
|
||||
CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest, RemoteResponse,
|
||||
RequestId, SessionId,
|
||||
},
|
||||
},
|
||||
};
|
||||
use burn_communication::Address;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
pub(crate) struct Session {
|
||||
response_sender: Sender<CollectiveMessageResponse>,
|
||||
response_receiver: Option<Receiver<CollectiveMessageResponse>>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn new() -> Self {
|
||||
let (response_sender, recv) = tokio::sync::mpsc::channel::<CollectiveMessageResponse>(1);
|
||||
Self {
|
||||
response_sender,
|
||||
response_receiver: Some(recv),
|
||||
}
|
||||
}
|
||||
|
||||
async fn respond(&mut self, response: CollectiveMessageResponse) {
|
||||
self.response_sender.send(response).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GlobalCollectiveState {
|
||||
/// The ids passed to each register so far, and their addresses
|
||||
registered_nodes: HashMap<SessionId, NodeId>,
|
||||
/// Address for each node
|
||||
node_addresses: HashMap<NodeId, Address>,
|
||||
/// Peer on each node
|
||||
node_peers: HashMap<NodeId, Vec<PeerId>>,
|
||||
|
||||
/// How many total nodes for the current register operation, as defined by the first caller
|
||||
cur_num_nodes: Option<u32>,
|
||||
/// How many peers have registered total
|
||||
num_global_peers: u32,
|
||||
|
||||
register_requests: Vec<(SessionId, RequestId, NodeId)>,
|
||||
|
||||
sessions: HashMap<SessionId, Session>,
|
||||
}
|
||||
|
||||
impl GlobalCollectiveState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registered_nodes: HashMap::new(),
|
||||
node_addresses: HashMap::new(),
|
||||
node_peers: HashMap::new(),
|
||||
cur_num_nodes: None,
|
||||
num_global_peers: 0,
|
||||
register_requests: Vec::new(),
|
||||
sessions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn init_session(&mut self, id: SessionId) {
|
||||
if self.sessions.contains_key(&id) {
|
||||
return;
|
||||
}
|
||||
self.sessions.insert(id, Session::new());
|
||||
}
|
||||
|
||||
/// Create the session with given id if necessary, and get the response receiver
|
||||
pub(crate) fn get_session_responder(
|
||||
&mut self,
|
||||
id: SessionId,
|
||||
) -> Receiver<CollectiveMessageResponse> {
|
||||
self.init_session(id);
|
||||
let session = self.sessions.get_mut(&id).unwrap();
|
||||
let response_recv = session.response_receiver.take();
|
||||
|
||||
response_recv.unwrap()
|
||||
}
|
||||
|
||||
pub(crate) async fn respond(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
response: CollectiveMessageResponse,
|
||||
) {
|
||||
let session = self.sessions.get_mut(&session_id).unwrap();
|
||||
session.respond(response).await;
|
||||
}
|
||||
|
||||
/// Process an incoming node's request
|
||||
pub(crate) async fn process_request(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
request_id: RequestId,
|
||||
request: RemoteRequest,
|
||||
) {
|
||||
if let Err(err) = match request {
|
||||
RemoteRequest::Register {
|
||||
node_addr,
|
||||
num_nodes,
|
||||
peers,
|
||||
} => {
|
||||
self.register(session_id, request_id, node_addr, num_nodes, peers)
|
||||
.await
|
||||
}
|
||||
RemoteRequest::Finish => self.finish(session_id, request_id).await,
|
||||
} {
|
||||
// Error occurred, send it as response
|
||||
let content = RemoteResponse::Error(err);
|
||||
self.respond(
|
||||
session_id,
|
||||
CollectiveMessageResponse {
|
||||
request_id,
|
||||
content,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Un-register a node. Any pending requests will be cancelled, returning error responses.
|
||||
async fn finish(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
request_id: RequestId,
|
||||
) -> Result<(), GlobalCollectiveError> {
|
||||
let node_id = self
|
||||
.registered_nodes
|
||||
.remove(&session_id)
|
||||
.ok_or(GlobalCollectiveError::NotRegisteredOnFinish)?;
|
||||
self.node_addresses.remove(&node_id);
|
||||
self.node_peers.remove(&node_id);
|
||||
self.num_global_peers = 0;
|
||||
|
||||
let mut register_requests = vec![];
|
||||
core::mem::swap(&mut register_requests, &mut self.register_requests);
|
||||
for (session, req, node_id) in register_requests {
|
||||
if session == session_id {
|
||||
// Send a response if we are finishing a session with a pending register request
|
||||
let content = RemoteResponse::Error(GlobalCollectiveError::PendingRegisterOnFinish);
|
||||
let response = CollectiveMessageResponse {
|
||||
request_id: req,
|
||||
content,
|
||||
};
|
||||
self.respond(session_id, response).await;
|
||||
} else {
|
||||
// keep the register request
|
||||
self.register_requests.push((session, req, node_id));
|
||||
}
|
||||
}
|
||||
|
||||
self.respond(
|
||||
session_id,
|
||||
CollectiveMessageResponse {
|
||||
request_id,
|
||||
content: RemoteResponse::FinishAck,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn register(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
request_id: RequestId,
|
||||
node_addr: Address,
|
||||
num_nodes: u32,
|
||||
peers: Vec<PeerId>,
|
||||
) -> Result<(), GlobalCollectiveError> {
|
||||
match &self.cur_num_nodes {
|
||||
Some(cur_num_nodes) => {
|
||||
if *cur_num_nodes != num_nodes {
|
||||
return Err(GlobalCollectiveError::RegisterParamsMismatch);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.cur_num_nodes = Some(num_nodes);
|
||||
}
|
||||
}
|
||||
|
||||
self.num_global_peers += peers.len() as u32;
|
||||
|
||||
let node_id: NodeId = self.registered_nodes.len().into();
|
||||
self.registered_nodes.insert(session_id, node_id);
|
||||
if self.node_addresses.values().any(|addr| node_addr == *addr) {
|
||||
return Err(GlobalCollectiveError::DoubleRegister);
|
||||
}
|
||||
self.node_addresses.insert(node_id, node_addr);
|
||||
self.node_peers.insert(node_id, peers);
|
||||
|
||||
self.register_requests
|
||||
.push((session_id, request_id, node_id));
|
||||
|
||||
if self.registered_nodes.len() == num_nodes as usize {
|
||||
let mut callbacks = vec![];
|
||||
core::mem::swap(&mut callbacks, &mut self.register_requests);
|
||||
|
||||
for (session, request, node_id) in callbacks {
|
||||
let content = RemoteResponse::Register {
|
||||
node_id,
|
||||
nodes: self.node_addresses.clone(),
|
||||
num_global_devices: self.num_global_peers,
|
||||
};
|
||||
let resp = CollectiveMessageResponse {
|
||||
request_id: request,
|
||||
content,
|
||||
};
|
||||
self.respond(session, resp).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
use std::{collections::HashMap, sync::atomic::AtomicU32};
|
||||
|
||||
use crate::{NodeId, PeerId};
|
||||
use burn_communication::{Address, CommunicationError};
|
||||
use burn_std::id::IdGenerator;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A unique identifier for each request made to a global orchestrator
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct RequestId(u32);
|
||||
|
||||
static REQ_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
|
||||
impl RequestId {
|
||||
pub(crate) fn new() -> Self {
|
||||
let id = REQ_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
Self(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RequestId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier that can represent a session between a node and the orchestrator.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub(crate) struct SessionId {
|
||||
id: u64,
|
||||
}
|
||||
|
||||
impl SessionId {
|
||||
/// Create a new [session id](SessionId).
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
id: IdGenerator::generate(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Requests sent from the client
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) enum CollectiveMessage {
|
||||
Init(SessionId),
|
||||
Request(RequestId, RemoteRequest),
|
||||
}
|
||||
|
||||
/// Responses sent to the client
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct CollectiveMessageResponse {
|
||||
pub request_id: RequestId,
|
||||
pub content: RemoteResponse,
|
||||
}
|
||||
|
||||
/// Requests made from a client to a server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) enum RemoteRequest {
|
||||
// Register a node
|
||||
Register {
|
||||
/// Endpoint for this node
|
||||
node_addr: Address,
|
||||
/// Number of total nodes
|
||||
num_nodes: u32,
|
||||
/// List of peers on this node
|
||||
peers: Vec<PeerId>,
|
||||
},
|
||||
|
||||
/// Unregister node
|
||||
Finish,
|
||||
}
|
||||
|
||||
/// Responses for each server request
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) enum RemoteResponse {
|
||||
/// Response to a register request
|
||||
Register {
|
||||
/// The orchestrator gives the node its id
|
||||
node_id: NodeId,
|
||||
/// All the nodes in the collective: including self
|
||||
nodes: HashMap<NodeId, Address>,
|
||||
/// How many devices exist globally? For averaging values
|
||||
num_global_devices: u32,
|
||||
},
|
||||
|
||||
// Finish
|
||||
FinishAck,
|
||||
|
||||
// There was a server-side error
|
||||
Error(GlobalCollectiveError),
|
||||
}
|
||||
|
||||
/// Errors that occur during collective operations on the global level
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GlobalCollectiveError {
|
||||
/// Operations that can't be done before registering
|
||||
AllReduceBeforeRegister,
|
||||
/// Ring all-reduce can't be done if all tensor dimensions are smaller than the number of nodes.
|
||||
RingReduceImpossible,
|
||||
|
||||
/// Either a node has unregistered twice, or a Finish has been called before a Register
|
||||
NotRegisteredOnFinish,
|
||||
/// Finish has been called before a Register operation was finished
|
||||
PendingRegisterOnFinish,
|
||||
/// Trying to register a different way than is currently being done
|
||||
RegisterParamsMismatch,
|
||||
/// Trying to register while already registered
|
||||
DoubleRegister,
|
||||
/// Trying to aggregate a different way than is currently being done
|
||||
AllReduceParamsMismatch,
|
||||
|
||||
/// First message on socket should be Message::Init
|
||||
FirstMsgNotInit,
|
||||
/// Messages should be rmp_serde serialized `Message` types
|
||||
InvalidMessage,
|
||||
/// A peer behaved unexpectedly
|
||||
PeerSentIncoherentTensor,
|
||||
/// Tried to download from a peer, but the peer closed or lost the connection
|
||||
PeerLost(NodeId),
|
||||
/// Error from the coordinator
|
||||
Server(String),
|
||||
|
||||
/// The node received an invalid response
|
||||
WrongOrchestratorResponse,
|
||||
/// Node couldn't connect to coordinator
|
||||
OrchestratorUnreachable,
|
||||
}
|
||||
|
||||
impl<E: CommunicationError> From<E> for GlobalCollectiveError {
|
||||
fn from(err: E) -> Self {
|
||||
Self::Server(format!("{err:?}"))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
mod global;
|
||||
pub use global::*;
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod api;
|
||||
pub use api::*;
|
||||
|
||||
mod local;
|
||||
|
||||
#[cfg(all(
|
||||
test,
|
||||
any(
|
||||
feature = "test-ndarray",
|
||||
feature = "test-wgpu",
|
||||
feature = "test-cuda",
|
||||
feature = "test-metal"
|
||||
)
|
||||
))]
|
||||
mod tests;
|
||||
@@ -0,0 +1,118 @@
|
||||
use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices};
|
||||
use crate::{
|
||||
AllReduceStrategy, CollectiveConfig, CollectiveError, ReduceOperation,
|
||||
local::{
|
||||
all_reduce_sum_centralized, all_reduce_sum_ring, all_reduce_sum_tree,
|
||||
broadcast_centralized, broadcast_tree, reduce_sum_centralized, reduce_sum_tree,
|
||||
},
|
||||
node::base::Node,
|
||||
};
|
||||
use burn_communication::Protocol;
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
use tracing::Instrument;
|
||||
|
||||
/// Perform an all-reduce with no multi-node operations (global ops)
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors, config))
|
||||
)]
|
||||
pub(crate) async fn all_reduce_local_only<B: Backend>(
|
||||
tensors: CollectiveTensorMap<B>,
|
||||
op: ReduceOperation,
|
||||
config: &CollectiveConfig,
|
||||
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
|
||||
let local_strategy = &config.local_all_reduce_strategy;
|
||||
|
||||
let mut reduced_tensors = match local_strategy {
|
||||
AllReduceStrategy::Centralized => all_reduce_sum_centralized::<B>(tensors),
|
||||
AllReduceStrategy::Tree(arity) => all_reduce_sum_tree::<B>(tensors, *arity),
|
||||
AllReduceStrategy::Ring => all_reduce_sum_ring::<B>(tensors),
|
||||
};
|
||||
|
||||
if op == ReduceOperation::Mean {
|
||||
#[cfg(feature = "tracing")]
|
||||
let _span = tracing::info_span!("mean_reduction").entered();
|
||||
|
||||
// Apply mean division
|
||||
let div = (reduced_tensors.len() as f32).into();
|
||||
|
||||
reduced_tensors = reduced_tensors
|
||||
.into_iter()
|
||||
.map(|(id, t)| (id, B::float_div_scalar(t, div)))
|
||||
.collect();
|
||||
}
|
||||
Ok(reduced_tensors)
|
||||
}
|
||||
|
||||
/// Do an all-reduce in a multi-node context
|
||||
///
|
||||
/// With Tree and Centralized strategies, the all-reduce is split between a
|
||||
/// reduce (all tensors are reduced to one device), and a broadcast (the result is sent to all
|
||||
/// other devices). The all-reduce on the global level is done between both steps.
|
||||
/// Due to the nature of the Ring strategy, this separation can't be done.
|
||||
///
|
||||
/// For the Ring strategy, this isn't possible, because it is more like a
|
||||
/// reduce-scatter plus an all-gather, so using a Ring strategy locally in a multi-node
|
||||
/// setup may be unadvantageous.
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors, config, global_client))
|
||||
)]
|
||||
pub(crate) async fn all_reduce_with_global<B: Backend, P: Protocol>(
|
||||
tensors: CollectiveTensorMap<B>,
|
||||
op: ReduceOperation,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Node<B, P>,
|
||||
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
|
||||
let peer_devices = get_peer_devices::<B>(&tensors);
|
||||
|
||||
// For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later
|
||||
let main_device = *tensors.keys().next().unwrap();
|
||||
|
||||
let mut main_tensor = match config.local_all_reduce_strategy {
|
||||
AllReduceStrategy::Centralized => reduce_sum_centralized::<B>(tensors, &main_device),
|
||||
AllReduceStrategy::Tree(arity) => reduce_sum_tree::<B>(tensors, &main_device, arity),
|
||||
AllReduceStrategy::Ring => all_reduce_sum_ring::<B>(tensors)
|
||||
.remove(&main_device)
|
||||
.unwrap(),
|
||||
};
|
||||
|
||||
// Do aggregation on global level with the main tensor
|
||||
main_tensor = {
|
||||
let fut = async {
|
||||
let global_strategy = config
|
||||
.global_all_reduce_strategy
|
||||
.expect("global_all_reduce_strategy must be set");
|
||||
|
||||
global_client
|
||||
.all_reduce(main_tensor, global_strategy, op)
|
||||
.await
|
||||
};
|
||||
#[cfg(feature = "tracing")]
|
||||
{
|
||||
fut.instrument(tracing::info_span!("global_all_reduce"))
|
||||
}
|
||||
#[cfg(not(feature = "tracing"))]
|
||||
{
|
||||
fut
|
||||
}
|
||||
}
|
||||
.await
|
||||
.map_err(CollectiveError::Global)?;
|
||||
|
||||
// Broadcast result to all devices
|
||||
let tensors = match config.local_all_reduce_strategy {
|
||||
AllReduceStrategy::Tree(arity) => {
|
||||
broadcast_tree::<B>(peer_devices, main_device, main_tensor, arity)
|
||||
}
|
||||
// If we chose the ring strategy and we must still broadcast the global result,
|
||||
// we use the centralized strategy for broadcasting, but the tree may be better.
|
||||
AllReduceStrategy::Centralized | AllReduceStrategy::Ring => {
|
||||
broadcast_centralized::<B>(peer_devices, main_device, main_tensor)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(tensors)
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices};
|
||||
use crate::local::{broadcast_centralized, reduce_sum_centralized};
|
||||
|
||||
/// Perform an all-reduce operation by reducing all tensors on one device, and broadcasting the
|
||||
/// result to all other devices
|
||||
///
|
||||
/// Internally, this is just a call to `reduce` followed by a `broadcast`
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors))
|
||||
)]
|
||||
pub(crate) fn all_reduce_sum_centralized<B: Backend>(
|
||||
tensors: CollectiveTensorMap<B>,
|
||||
) -> CollectiveTensorMap<B> {
|
||||
// Get corresponding devices for each peer
|
||||
let peer_devices = get_peer_devices::<B>(&tensors);
|
||||
let central_device = *tensors.keys().next().unwrap();
|
||||
|
||||
// Reduce to central device
|
||||
let central_tensor = reduce_sum_centralized::<B>(tensors, ¢ral_device);
|
||||
|
||||
// Broadcast result to all
|
||||
broadcast_centralized::<B>(peer_devices, central_device, central_tensor)
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
mod base;
|
||||
mod centralized;
|
||||
mod op;
|
||||
mod ring;
|
||||
mod tree;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(crate) use centralized::*;
|
||||
pub(crate) use op::*;
|
||||
pub(crate) use ring::*;
|
||||
pub(crate) use tree::*;
|
||||
@@ -0,0 +1,141 @@
|
||||
use crate::global::node::base::Node;
|
||||
use crate::local::tensor_map::CollectiveTensorMap;
|
||||
use crate::{CollectiveConfig, CollectiveError, PeerId, ReduceOperation, local};
|
||||
use burn_communication::Protocol;
|
||||
use burn_std::Shape;
|
||||
use burn_tensor::TensorMetadata;
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::sync::mpsc::SyncSender;
|
||||
|
||||
/// An on-going all-reduce operation
|
||||
#[derive(Debug)]
|
||||
pub struct AllReduceOp<B: Backend> {
|
||||
/// all-reduce calls, one for each calling device
|
||||
calls: Vec<AllReduceOpCall<B>>,
|
||||
/// The reduce operation of the current all-reduce, as defined by the first caller
|
||||
op: ReduceOperation,
|
||||
/// The shape of the current all-reduce, as defined by the first caller
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
/// Struct for each device that calls an all-reduce operation
|
||||
#[derive(Debug)]
|
||||
pub struct AllReduceOpCall<B: Backend> {
|
||||
/// Id of the caller for this operation
|
||||
caller: PeerId,
|
||||
/// The tensor primitive passed as input
|
||||
input: B::FloatTensorPrimitive,
|
||||
/// Callback for the result of the all-reduce
|
||||
result_sender: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
|
||||
}
|
||||
|
||||
/// Type sent to the collective client upon completion of a all-reduce aggregation
|
||||
pub(crate) type AllReduceResult<T> = Result<T, CollectiveError>;
|
||||
|
||||
impl<B: Backend> AllReduceOp<B> {
|
||||
pub fn new(shape: Shape, reduce_op: ReduceOperation) -> Self {
|
||||
Self {
|
||||
calls: vec![],
|
||||
op: reduce_op,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a list of the peers.
|
||||
fn peers(&self) -> Vec<PeerId> {
|
||||
self.calls.iter().map(|c| c.caller).collect()
|
||||
}
|
||||
|
||||
/// Register a call to all-reduce in this operation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `true` if enough peers have registered, and the all-reduce is ready
|
||||
pub fn register_call(
|
||||
&mut self,
|
||||
caller: PeerId,
|
||||
input: B::FloatTensorPrimitive,
|
||||
result_sender: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
|
||||
op: ReduceOperation,
|
||||
peer_count: usize,
|
||||
) -> Result<bool, CollectiveError> {
|
||||
if self.shape != input.shape() {
|
||||
return Err(CollectiveError::AllReduceShapeMismatch);
|
||||
}
|
||||
if self.op != op {
|
||||
return Err(CollectiveError::AllReduceOperationMismatch);
|
||||
}
|
||||
|
||||
self.calls.push(AllReduceOpCall {
|
||||
caller,
|
||||
input,
|
||||
result_sender,
|
||||
});
|
||||
|
||||
Ok(self.calls.len() == peer_count)
|
||||
}
|
||||
|
||||
/// Runs the all-reduce if the operation is ready. Otherwise, do nothing
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(
|
||||
level = "trace",
|
||||
skip(self, config, global_client),
|
||||
fields(
|
||||
?self.op,
|
||||
?self.shape,
|
||||
self.peers = ?self.peers(),
|
||||
)
|
||||
))]
|
||||
pub async fn execute<P: Protocol>(
|
||||
mut self,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Option<Node<B, P>>,
|
||||
) {
|
||||
// all registered callers have sent a tensor to aggregate
|
||||
match self.all_reduce(config, global_client).await {
|
||||
Ok(mut tensors) => {
|
||||
// Return resulting tensors
|
||||
self.calls.iter().for_each(|call| {
|
||||
let result = tensors
|
||||
.remove(&call.caller)
|
||||
.expect("tensor/peer internal mismatch.");
|
||||
call.result_sender.send(Ok(result)).unwrap();
|
||||
});
|
||||
assert_eq!(tensors.len(), 0, "tensor/peer internal mismatch.");
|
||||
}
|
||||
Err(err) => {
|
||||
// Send error to all subscribers
|
||||
self.fail(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform an all-reduce operation.
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(self, config, global_client))
|
||||
)]
|
||||
async fn all_reduce<P: Protocol>(
|
||||
&mut self,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Option<Node<B, P>>,
|
||||
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
|
||||
let tensors = self
|
||||
.calls
|
||||
.iter()
|
||||
.map(|call| (call.caller, call.input.clone()))
|
||||
.collect();
|
||||
|
||||
if let Some(global_client) = global_client.as_mut() {
|
||||
local::all_reduce_with_global(tensors, self.op, config, global_client).await
|
||||
} else {
|
||||
local::all_reduce_local_only::<B>(tensors, self.op, config).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a collective error as result to operation caller
|
||||
pub fn fail(self, err: CollectiveError) {
|
||||
self.calls.iter().for_each(|op| {
|
||||
op.result_sender.send(Err(err.clone())).unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
use super::tree::all_reduce_sum_tree;
|
||||
use crate::PeerId;
|
||||
use crate::local::tensor_map;
|
||||
use crate::local::tensor_map::CollectiveTensorMap;
|
||||
use burn_tensor::{Shape, Slice, TensorMetadata, backend::Backend};
|
||||
use std::{collections::HashMap, ops::Range};
|
||||
|
||||
/// Ring implementation of All-Reduce (Ring-Reduce)
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors))
|
||||
)]
|
||||
pub(crate) fn all_reduce_sum_ring<B: Backend>(
|
||||
tensors: CollectiveTensorMap<B>,
|
||||
) -> CollectiveTensorMap<B> {
|
||||
// https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model
|
||||
|
||||
// Example: tensors=3, slices=3
|
||||
|
||||
// phase 1
|
||||
// o->o o
|
||||
// o o->oå
|
||||
// o o o->
|
||||
|
||||
// o 1->o
|
||||
// o o 1->
|
||||
// 1->o o
|
||||
|
||||
// o 1 2
|
||||
// 2 o 1
|
||||
// 1 2 o
|
||||
|
||||
// phase 2
|
||||
// o 1 2->
|
||||
// 2->o 1
|
||||
// 1 2->o
|
||||
|
||||
// 2->1 2
|
||||
// 2 2->1
|
||||
// 1 2 2->
|
||||
|
||||
// 2 2 2
|
||||
// 2 2 2
|
||||
// 2 2 2
|
||||
|
||||
// Verify all shapes are the same
|
||||
let shape = tensor_map::get_common_shape::<B>(&tensors)
|
||||
.expect("Cannot aggregate tensors with different sizes");
|
||||
|
||||
// Chose an axis
|
||||
let slice_dim = get_slice_dim(&shape);
|
||||
|
||||
let slice_dim_size = shape[slice_dim];
|
||||
let tensor_count = tensors.len();
|
||||
if slice_dim_size < tensor_count {
|
||||
// Tensor cannot be split into N slices! Use a fallback algorithm: binary tree
|
||||
return all_reduce_sum_tree::<B>(tensors, 2);
|
||||
}
|
||||
|
||||
// Split tensors into slices
|
||||
let mut sliced_tensors = slice_tensors::<B>(tensors, shape, slice_dim);
|
||||
|
||||
// phase 1: aggregate in ring N-1 times (Reduce-Scatter)
|
||||
ring_cycles::<B>(&mut sliced_tensors, true);
|
||||
|
||||
// phase 2: share (overwrite) in a ring N-1 times (All-Gather)
|
||||
ring_cycles::<B>(&mut sliced_tensors, false);
|
||||
|
||||
// merge slices and put back in result
|
||||
sliced_tensors
|
||||
.into_iter()
|
||||
.map(|(id, slices)| (id, B::float_cat(slices, slice_dim)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the dimension to slice across: the largest dimension of the shape
|
||||
pub(crate) fn get_slice_dim(shape: &Shape) -> usize {
|
||||
// get dimension with the greatest size.
|
||||
shape
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.cmp(b))
|
||||
.map(|(index, _)| index)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// With a ring of N tensors, send the tensors N-1 times, either for the first of second phase.
|
||||
/// During the first phase, the tensor slices are summed.
|
||||
/// During the second, the slices are replaced.
|
||||
fn ring_cycles<B: Backend>(
|
||||
sliced_tensors: &mut [(PeerId, Vec<B::FloatTensorPrimitive>)],
|
||||
is_phase_one: bool,
|
||||
) {
|
||||
let tensor_count = sliced_tensors.len();
|
||||
for cycle in 0..(tensor_count - 1) {
|
||||
for i in 0..tensor_count {
|
||||
let src_tensor_idx = i;
|
||||
let dest_tensor_idx = (i + 1) % tensor_count;
|
||||
|
||||
let slice_idx = if is_phase_one {
|
||||
(i + (tensor_count - 1) * cycle) % tensor_count
|
||||
} else {
|
||||
// in phase 2, the starting slice is different (see diagrams)
|
||||
(i + 1 + (tensor_count - 1) * cycle) % tensor_count
|
||||
};
|
||||
|
||||
let src_slice = sliced_tensors[src_tensor_idx].1.remove(slice_idx);
|
||||
let mut dest_slice = sliced_tensors[dest_tensor_idx].1.remove(slice_idx);
|
||||
|
||||
let dest_device = B::float_device(&dest_slice);
|
||||
let src_slice_on_dest = B::float_to_device(src_slice.clone(), &dest_device);
|
||||
if is_phase_one {
|
||||
dest_slice = B::float_add(dest_slice, src_slice_on_dest);
|
||||
} else {
|
||||
let slices: Vec<Slice> = dest_slice
|
||||
.shape()
|
||||
.iter()
|
||||
.map(|&d| Slice::new(0, Some(d as isize), 1))
|
||||
.collect();
|
||||
|
||||
// in phase 2, we don't sum the two slices, we replace with the new one.
|
||||
dest_slice =
|
||||
B::float_slice_assign(dest_slice, slices.as_slice(), src_slice_on_dest);
|
||||
}
|
||||
|
||||
sliced_tensors[src_tensor_idx]
|
||||
.1
|
||||
.insert(slice_idx, src_slice);
|
||||
sliced_tensors[dest_tensor_idx]
|
||||
.1
|
||||
.insert(slice_idx, dest_slice);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Slice a list of tensors the same way, evenly across a given dimension.
|
||||
/// The given `shape` should be the same for every tensor.
|
||||
fn slice_tensors<B: Backend>(
|
||||
mut tensors: HashMap<PeerId, B::FloatTensorPrimitive>,
|
||||
shape: Shape,
|
||||
slice_dim: usize,
|
||||
) -> Vec<(PeerId, Vec<<B as Backend>::FloatTensorPrimitive>)> {
|
||||
// Get slice index ranges
|
||||
let ranges = get_ring_reduce_slice_ranges(shape[slice_dim], tensors.len());
|
||||
|
||||
// Slice tensors
|
||||
let mut sliced_tensors = vec![];
|
||||
for (id, tensor) in tensors.drain() {
|
||||
let mut slices = vec![];
|
||||
for range in &ranges {
|
||||
let full_range = shape
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(dim_idx, dim)| {
|
||||
if dim_idx == slice_dim {
|
||||
Slice::from(range.clone())
|
||||
} else {
|
||||
Slice::from(0..*dim)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let slice = B::float_slice(tensor.clone(), &full_range);
|
||||
slices.push(slice);
|
||||
}
|
||||
sliced_tensors.push((id, slices));
|
||||
}
|
||||
|
||||
sliced_tensors
|
||||
}
|
||||
|
||||
/// Get the index ranges for the slices to split a tensor evenly across a given axis.
///
/// * `slice_dim_size` - The size of the dim to slice on
/// * `slice_count` - The number of slices
///
/// Returns a vector of `slice_count` contiguous index ranges that together cover exactly
/// `0..slice_dim_size`. The sizes differ by at most one element, with the larger slices
/// first. No range ever extends past `slice_dim_size`; when `slice_dim_size` is smaller than
/// `slice_count` the trailing ranges are empty. An empty vector is returned when
/// `slice_count` is zero.
pub(crate) fn get_ring_reduce_slice_ranges(
    slice_dim_size: usize,
    slice_count: usize,
) -> Vec<Range<usize>> {
    if slice_count == 0 {
        return Vec::new();
    }

    // Rounding the slice size up (ceil) can push the later ranges past the end of the
    // dimension, e.g. 5 elements in 4 slices. Spread the remainder one element at a time.
    let base_size = slice_dim_size / slice_count;
    let remainder = slice_dim_size % slice_count;

    let mut start = 0;
    (0..slice_count)
        .map(|i| {
            let end = start + base_size + usize::from(i < remainder);
            let range = start..end;
            start = end;
            range
        })
        .collect()
}
|
||||
@@ -0,0 +1,89 @@
|
||||
use crate::PeerId;
|
||||
use crate::local::tensor_map::CollectiveTensorMap;
|
||||
use burn_tensor::backend::{Backend, DeviceOps};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Performs an all-reduce on the provided tensors in a b-tree structure with `arity`.
|
||||
/// Similar to [reduce_sum_tree](reduce_sum_tree), but this function broadcasts the result with
|
||||
/// the same tree algorithm.
|
||||
/// The returned tensors are on the same devices as the corresponding inputs
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors))
|
||||
)]
|
||||
pub(crate) fn all_reduce_sum_tree<B: Backend>(
|
||||
tensors: CollectiveTensorMap<B>,
|
||||
arity: u32,
|
||||
) -> CollectiveTensorMap<B> {
|
||||
let mut input = tensors.into_iter().collect::<Vec<_>>();
|
||||
|
||||
// Sort to put devices of the same type together
|
||||
input.sort_by(|a, b| {
|
||||
let dev_a = B::float_device(&a.1);
|
||||
let dev_b = B::float_device(&b.1);
|
||||
dev_a.id().cmp(&dev_b.id())
|
||||
});
|
||||
// Recursive all-reduce
|
||||
let out = all_reduce_sum_tree_inner::<B>(input, arity);
|
||||
|
||||
let mut tensors = HashMap::new();
|
||||
for (id, tensor) in out {
|
||||
tensors.insert(id, tensor);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
|
||||
/// Recursive function that sums `tensors` and redistributes the result to the host devices
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors))
|
||||
)]
|
||||
fn all_reduce_sum_tree_inner<B: Backend>(
|
||||
mut tensors: Vec<(PeerId, B::FloatTensorPrimitive)>,
|
||||
arity: u32,
|
||||
) -> Vec<(PeerId, B::FloatTensorPrimitive)> {
|
||||
let mut parent_tensors = vec![];
|
||||
let mut children_groups = vec![];
|
||||
|
||||
// Phase 1: Sum tensors in groups of `arity` + 1
|
||||
while !tensors.is_empty() {
|
||||
// Maps ids to devices for each child of this parent
|
||||
let mut children = vec![];
|
||||
let (parent, mut parent_tensor) = tensors.remove(0);
|
||||
let parent_device = B::float_device(&parent_tensor);
|
||||
|
||||
for _ in 0..arity {
|
||||
if tensors.is_empty() {
|
||||
break;
|
||||
}
|
||||
let (child, mut child_tensor) = tensors.remove(0);
|
||||
let child_device = B::float_device(&child_tensor);
|
||||
children.push((child, child_device));
|
||||
child_tensor = B::float_to_device(child_tensor, &parent_device);
|
||||
parent_tensor = B::float_add(parent_tensor, child_tensor);
|
||||
}
|
||||
|
||||
parent_tensors.push((parent, parent_tensor));
|
||||
children_groups.push(children);
|
||||
}
|
||||
|
||||
if parent_tensors.len() > 1 {
|
||||
// Parents are not yet at the root, do the upper part of the tree
|
||||
parent_tensors = all_reduce_sum_tree_inner::<B>(parent_tensors, arity);
|
||||
}
|
||||
|
||||
// Phase 2: Redistribute result from each parent to the respective devices
|
||||
for (parent, parent_tensor) in parent_tensors {
|
||||
let children = children_groups.remove(0);
|
||||
for (child, child_device) in children {
|
||||
// replace child tensors with result
|
||||
tensors.push((
|
||||
child,
|
||||
B::float_to_device(parent_tensor.clone(), &child_device),
|
||||
));
|
||||
}
|
||||
tensors.push((parent, parent_tensor));
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::PeerId;
|
||||
use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
/// Broadcasts the tensor from one device in a map to all the others
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(devices, tensor))
|
||||
)]
|
||||
pub(crate) fn broadcast_centralized<B: Backend>(
|
||||
mut devices: PeerDeviceMap<B>,
|
||||
central: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
) -> CollectiveTensorMap<B> {
|
||||
let mut output = HashMap::new();
|
||||
|
||||
devices
|
||||
.remove(¢ral)
|
||||
.expect("Central device id is in `devices`");
|
||||
for (dest, dest_device) in devices {
|
||||
let tensor = B::float_to_device(tensor.clone(), &dest_device);
|
||||
output.insert(dest, tensor);
|
||||
}
|
||||
output.insert(central, tensor);
|
||||
|
||||
output
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod centralized;
|
||||
mod op;
|
||||
mod tree;
|
||||
|
||||
pub(crate) use centralized::*;
|
||||
pub(crate) use op::*;
|
||||
pub(crate) use tree::*;
|
||||
@@ -0,0 +1,172 @@
|
||||
use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};
|
||||
use crate::{
|
||||
BroadcastStrategy, CollectiveConfig, CollectiveError, PeerId,
|
||||
local::{broadcast_centralized, broadcast_tree},
|
||||
node::base::Node,
|
||||
};
|
||||
use burn_communication::Protocol;
|
||||
#[allow(unused_imports)] // TensorMetadata is used by tracing::instrument.
|
||||
use burn_tensor::TensorMetadata;
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::sync::mpsc::SyncSender;
|
||||
|
||||
/// An on-going broadcast operation
|
||||
pub struct BroadcastOp<B: Backend> {
|
||||
/// broadcast calls, one for each calling device
|
||||
calls: Vec<BroadcastOpCall<B>>,
|
||||
/// The tensor to broadcast, as defined by the root. Should be defined before all
|
||||
/// peers call the operation.
|
||||
tensor: Option<B::FloatTensorPrimitive>,
|
||||
|
||||
/// ID of the root (or use the first call's peer).
|
||||
root: Option<PeerId>,
|
||||
}
|
||||
|
||||
/// Struct for each device that calls an broadcast operation
|
||||
pub struct BroadcastOpCall<B: Backend> {
|
||||
/// Id of the caller of the operation
|
||||
caller: PeerId,
|
||||
/// Device of the calling peer
|
||||
device: B::Device,
|
||||
/// Callback for the result of the broadcast
|
||||
result_sender: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
|
||||
}
|
||||
|
||||
/// Type sent to the collective client upon completion of a broadcast op
|
||||
pub(crate) type BroadcastResult<T> = Result<T, CollectiveError>;
|
||||
|
||||
impl<B: Backend> BroadcastOp<B> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
calls: vec![],
|
||||
tensor: None,
|
||||
root: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the effective root of the broadcast operation.
|
||||
/// If the root is set, return it. Otherwise, return the first caller's peer.
|
||||
pub fn effective_root(&self) -> PeerId {
|
||||
self.root.unwrap_or(self.calls.first().unwrap().caller)
|
||||
}
|
||||
|
||||
/// Ids of the peers that have called the broadcast so far, in call order.
pub fn peers(&self) -> Vec<PeerId> {
    let mut callers = Vec::with_capacity(self.calls.len());
    for call in &self.calls {
        callers.push(call.caller);
    }
    callers
}
|
||||
|
||||
/// Map every registered caller to a clone of its device.
fn peer_devices(&self) -> PeerDeviceMap<B> {
    let entries = self
        .calls
        .iter()
        .map(|registered| (registered.caller, registered.device.clone()));

    entries.collect()
}
|
||||
|
||||
/// Register a call to reduce in this operation.
|
||||
/// When the last caller registers a reduce, the operation is executed.
|
||||
pub fn register_call(
|
||||
&mut self,
|
||||
caller: PeerId,
|
||||
input: Option<B::FloatTensorPrimitive>,
|
||||
result_sender: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
|
||||
device: B::Device,
|
||||
peer_count: usize,
|
||||
) -> Result<bool, CollectiveError> {
|
||||
if input.is_some() {
|
||||
if self.tensor.is_some() {
|
||||
return Err(CollectiveError::BroadcastMultipleTensors);
|
||||
}
|
||||
self.tensor = input;
|
||||
}
|
||||
|
||||
self.calls.push(BroadcastOpCall {
|
||||
caller,
|
||||
device,
|
||||
result_sender,
|
||||
});
|
||||
|
||||
Ok(self.calls.len() == peer_count)
|
||||
}
|
||||
|
||||
/// Runs the broadcast if the operation is ready. Otherwise, do nothing
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(
|
||||
level="trace",
|
||||
skip(self, config, global_client),
|
||||
fields(
|
||||
self.peers = ?self.peers(),
|
||||
self.shape = ?self.tensor.as_ref().map(|t| t.shape()),
|
||||
self.dtype = ?self.tensor.as_ref().map(|t| t.dtype()),
|
||||
)
|
||||
))]
|
||||
pub async fn execute<P: Protocol>(
|
||||
mut self,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Option<Node<B, P>>,
|
||||
) {
|
||||
// all registered callers have sent a tensor to aggregate
|
||||
match self.broadcast(config, global_client).await {
|
||||
Ok(mut tensors) => {
|
||||
// Return resulting tensors
|
||||
self.calls.iter().for_each(|call| {
|
||||
let result = tensors
|
||||
.remove(&call.caller)
|
||||
.expect("tensor/peer internal mismatch.");
|
||||
call.result_sender.send(Ok(result)).unwrap();
|
||||
});
|
||||
assert_eq!(tensors.len(), 0, "tensor/peer internal mismatch.");
|
||||
}
|
||||
Err(err) => {
|
||||
// Send error to all subscribers
|
||||
self.fail(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(self, config, global_client))
|
||||
)]
|
||||
async fn broadcast<P: Protocol>(
|
||||
&mut self,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Option<Node<B, P>>,
|
||||
) -> Result<CollectiveTensorMap<B>, CollectiveError> {
|
||||
// Do broadcast on global level with the main tensor
|
||||
if let Some(global_client) = &global_client {
|
||||
let strategy = config
|
||||
.global_broadcast_strategy
|
||||
.expect("global_broadcast_strategy not defined");
|
||||
|
||||
self.tensor = Some(
|
||||
global_client
|
||||
.broadcast(self.tensor.clone(), strategy)
|
||||
.await
|
||||
.map_err(CollectiveError::Global)?,
|
||||
)
|
||||
}
|
||||
|
||||
// At this point tensor must be defined
|
||||
let Some(tensor) = self.tensor.take() else {
|
||||
return Err(CollectiveError::BroadcastNoTensor);
|
||||
};
|
||||
|
||||
let root = self.effective_root();
|
||||
let peer_devices = self.peer_devices();
|
||||
|
||||
// Broadcast locally
|
||||
Ok(match config.local_broadcast_strategy {
|
||||
BroadcastStrategy::Tree(arity) => {
|
||||
broadcast_tree::<B>(peer_devices, root, tensor, arity)
|
||||
}
|
||||
BroadcastStrategy::Centralized => {
|
||||
broadcast_centralized::<B>(peer_devices, root, tensor)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a collective error as result to operation caller
|
||||
pub fn fail(self, err: CollectiveError) {
|
||||
self.calls.iter().for_each(|call| {
|
||||
call.result_sender.send(Err(err.clone())).unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use burn_tensor::backend::{Backend, DeviceOps};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::PeerId;
|
||||
use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};
|
||||
|
||||
/// Performs a broadcast on the provided tensors in a b-tree structure with `arity`.
|
||||
///
|
||||
/// Tensor must be on the device in the `devices` map corresponding to the `root` key.
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(devices, tensor))
|
||||
)]
|
||||
pub(crate) fn broadcast_tree<B: Backend>(
|
||||
mut devices: PeerDeviceMap<B>,
|
||||
root: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
arity: u32,
|
||||
) -> CollectiveTensorMap<B> {
|
||||
// Convert hash map to vector of key-value pairs because order matters
|
||||
let mut devices_vec = vec![];
|
||||
let root_device = devices.remove(&root).unwrap();
|
||||
for (id, tensor) in devices.drain() {
|
||||
devices_vec.push((id, tensor));
|
||||
}
|
||||
|
||||
// Sort to put devices of the same type together
|
||||
devices_vec.sort_by(|a, b| {
|
||||
let dev_a = &a.1;
|
||||
let dev_b = &b.1;
|
||||
dev_a.id().cmp(&dev_b.id())
|
||||
});
|
||||
|
||||
// put the root first
|
||||
devices_vec.insert(0, (root, root_device));
|
||||
|
||||
// Recursive broadcast
|
||||
let out = broadcast_tree_inner::<B>(tensor, devices_vec, arity);
|
||||
|
||||
// put results in a hash map
|
||||
let mut tensors = HashMap::new();
|
||||
for (id, tensor) in out {
|
||||
tensors.insert(id, tensor);
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
|
||||
/// Recursive function that broadcasts tensor across the other devices. Tensor should be on the
|
||||
/// first device of the list
|
||||
///
|
||||
/// Broadcasts the tensor across the devices in the tree in a pre-order traversal.
|
||||
fn broadcast_tree_inner<B: Backend>(
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
mut all_devices: Vec<(PeerId, B::Device)>,
|
||||
arity: u32,
|
||||
) -> Vec<(PeerId, B::FloatTensorPrimitive)> {
|
||||
let mut parents = vec![];
|
||||
let mut children_groups = vec![];
|
||||
|
||||
// Put devices in groups of `arity` + the parent
|
||||
while !all_devices.is_empty() {
|
||||
let mut children = vec![];
|
||||
let parent = all_devices.remove(0);
|
||||
|
||||
for _ in 0..arity {
|
||||
if all_devices.is_empty() {
|
||||
break;
|
||||
}
|
||||
children.push(all_devices.remove(0));
|
||||
}
|
||||
|
||||
parents.push(parent);
|
||||
children_groups.push(children);
|
||||
}
|
||||
|
||||
let mut parents = if parents.len() > 1 {
|
||||
broadcast_tree_inner::<B>(tensor, parents, arity)
|
||||
} else {
|
||||
let root = parents.first().unwrap();
|
||||
// `tensor` should already be on the root's device, no need to call B::float_to_device
|
||||
vec![(root.0, tensor)]
|
||||
};
|
||||
|
||||
// Redistribute result from each parent to the respective devices
|
||||
let mut tensors = vec![];
|
||||
for children in children_groups {
|
||||
let parent = parents.remove(0);
|
||||
for (child_id, child_device) in children {
|
||||
// replace child's tensor with parent's
|
||||
let child_tensor = B::float_to_device(parent.1.clone(), &child_device);
|
||||
tensors.push((child_id, child_tensor));
|
||||
}
|
||||
tensors.push(parent);
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
use crate::local::all_reduce::AllReduceResult;
|
||||
use crate::{
|
||||
CollectiveConfig, CollectiveError, PeerId, ReduceOperation,
|
||||
local::{
|
||||
BroadcastResult, ReduceResult,
|
||||
server::{FinishResult, Message, RegisterResult},
|
||||
},
|
||||
};
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::sync::mpsc::{Receiver, SyncSender};
|
||||
|
||||
/// Local client to communicate with the local server. Each thread has a client.
///
/// Cloning the client clones the sender, so all clones feed the same message channel.
#[derive(Clone)]
pub(crate) struct LocalCollectiveClient<B: Backend> {
    /// Sending half of the channel on which [`Message`]s are submitted.
    pub channel: SyncSender<Message<B>>,
}
|
||||
|
||||
/// A pending operation that can be waited on.
///
/// Wraps the receiving half of the one-shot result channel created when the operation was
/// started.
pub(crate) struct PendingCollectiveOperation<T> {
    /// Receiver on which the operation's `Result<T, CollectiveError>` is delivered.
    rx: Receiver<Result<T, CollectiveError>>,
}
|
||||
|
||||
impl<T> From<PendingCollectiveOperation<T>> for Receiver<Result<T, CollectiveError>> {
|
||||
fn from(value: PendingCollectiveOperation<T>) -> Self {
|
||||
value.rx
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PendingCollectiveOperation<T> {
|
||||
/// Wait on the operation.
|
||||
///
|
||||
/// Given a `Receiver<Result<T, CollectiveError>>`, this function will wait:
|
||||
/// - Unwraps `Ok(Result<T, CollectiveError>)` into `Result<T, CollectiveError>`;
|
||||
/// - maps `Err(RecvError)` to `Err(CollectiveError::LocalServerMissing)`.
|
||||
pub(crate) fn wait(self) -> Result<T, CollectiveError> {
|
||||
let tensor = self
|
||||
.rx
|
||||
.recv()
|
||||
.unwrap_or(Err(CollectiveError::LocalServerMissing))?;
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> LocalCollectiveClient<B> {
|
||||
/// Common logic for starting a collective operation.
|
||||
///
|
||||
/// - Allocates `(callback, recv)` channels,
|
||||
/// - Passes the `callback` to the `Message<B>` builder,
|
||||
/// - Sends the message through the collective channel,
|
||||
/// - Returns the `recv`.
|
||||
pub(crate) fn start_operation<T, F>(&self, builder: F) -> PendingCollectiveOperation<T>
|
||||
where
|
||||
F: FnOnce(SyncSender<Result<T, CollectiveError>>) -> Message<B>,
|
||||
{
|
||||
let (tx, rx) = std::sync::mpsc::sync_channel(1);
|
||||
self.channel.send((builder)(tx)).unwrap();
|
||||
PendingCollectiveOperation { rx }
|
||||
}
|
||||
|
||||
/// Common logic for starting a collective operation, with validation.
|
||||
///
|
||||
/// When `valid` is `Err`, this function returns a `Receiver<Result<T, CollectiveError>>` that
|
||||
/// immediately returns `Err(valid)`;
|
||||
/// otherwise, it behaves like [`LocalCollectiveClient::start_operation`].
|
||||
pub(crate) fn start_valid_operation<T, F>(
|
||||
&self,
|
||||
valid: Result<(), CollectiveError>,
|
||||
builder: F,
|
||||
) -> PendingCollectiveOperation<T>
|
||||
where
|
||||
F: FnOnce(SyncSender<Result<T, CollectiveError>>) -> Message<B>,
|
||||
{
|
||||
match valid {
|
||||
Err(e) => {
|
||||
let (tx, rx) = std::sync::mpsc::sync_channel(1);
|
||||
tx.send(Err(e)).unwrap();
|
||||
PendingCollectiveOperation { rx }
|
||||
}
|
||||
_ => self.start_operation(builder),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn reset(&self) {
|
||||
self.channel.send(Message::Reset).unwrap();
|
||||
}
|
||||
|
||||
pub(crate) fn register(
|
||||
&mut self,
|
||||
id: PeerId,
|
||||
device: B::Device,
|
||||
config: CollectiveConfig,
|
||||
) -> RegisterResult {
|
||||
self.register_start(id, device, config).wait()
|
||||
}
|
||||
|
||||
pub(crate) fn register_start(
|
||||
&mut self,
|
||||
id: PeerId,
|
||||
device: B::Device,
|
||||
config: CollectiveConfig,
|
||||
) -> PendingCollectiveOperation<()> {
|
||||
self.start_valid_operation(
|
||||
match config.is_valid() {
|
||||
true => Ok(()),
|
||||
false => Err(CollectiveError::InvalidConfig),
|
||||
},
|
||||
|callback| Message::Register {
|
||||
device_id: id,
|
||||
device,
|
||||
config,
|
||||
callback,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Calls for an all-reduce operation with the given parameters and returns the result.
|
||||
/// The `params` must be the same as the parameters passed by the other nodes.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `id` - The peer id of the caller
|
||||
/// * `tensor` - The input tensor to reduce with the peers' tensors
|
||||
/// * `config` - Config of the collective operation. Must be coherent with the other calls.
|
||||
///
|
||||
/// # Result
|
||||
/// - `Ok(tensor)` if the operation was successful
|
||||
/// - `Err(CollectiveError)` on error.
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(self, tensor))
|
||||
)]
|
||||
pub fn all_reduce(
|
||||
&self,
|
||||
id: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
) -> AllReduceResult<B::FloatTensorPrimitive> {
|
||||
self.all_reduce_start(id, tensor, op).wait()
|
||||
}
|
||||
|
||||
/// Starts an all-reduce operation with the given parameters.
|
||||
///
|
||||
/// The `params` must be the same as the parameters passed by the other nodes.
|
||||
///
|
||||
/// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `id` - The peer id of the caller
|
||||
/// * `tensor` - The input tensor to reduce with the peers' tensors
|
||||
/// * `config` - Config of the collective operation. Must be coherent with the other calls.
|
||||
///
|
||||
/// # Result
|
||||
///
|
||||
/// A `Receiver<>` that will yield:
|
||||
/// - `Ok(AllReduceResult<B::FloatTensorPrimitive>)` if the operation was successful
|
||||
/// - `Err(SendError)` if the channel was dropped.
|
||||
pub(crate) fn all_reduce_start(
|
||||
&self,
|
||||
id: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
) -> PendingCollectiveOperation<B::FloatTensorPrimitive> {
|
||||
self.start_operation(|callback| Message::AllReduce {
|
||||
device_id: id,
|
||||
tensor,
|
||||
op,
|
||||
callback,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reduces a tensor onto one device.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `id` - The peer id of the caller.
|
||||
/// - `tensor` - The tensor to send as input.
|
||||
/// - `op` - The reduce operation to apply.
|
||||
/// - `root` - The ID of the peer that will receive the result.
|
||||
///
|
||||
/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.
|
||||
pub fn reduce(
|
||||
&self,
|
||||
id: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
root: PeerId,
|
||||
) -> ReduceResult<B::FloatTensorPrimitive> {
|
||||
self.reduce_start(id, tensor, op, root).wait()
|
||||
}
|
||||
|
||||
/// Starts a reduce operation on a tensor onto one device.
|
||||
///
|
||||
/// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `id` - The peer id of the caller.
|
||||
/// - `tensor` - The tensor to send as input.
|
||||
/// - `op` - The reduce operation to apply.
|
||||
/// - `root` - The ID of the peer that will receive the result.
|
||||
///
|
||||
/// # Result
|
||||
///
|
||||
/// A `Receiver<>` that will yield:
|
||||
/// - `Ok(ReduceResult<B::FloatTensorPrimitive>)` if the operation was successful
|
||||
/// - `Err(SendError)` if the channel was dropped.
|
||||
pub(crate) fn reduce_start(
|
||||
&self,
|
||||
id: PeerId,
|
||||
tensor: B::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
root: PeerId,
|
||||
) -> PendingCollectiveOperation<Option<B::FloatTensorPrimitive>> {
|
||||
self.start_operation(|callback| Message::Reduce {
|
||||
device_id: id,
|
||||
tensor,
|
||||
op,
|
||||
root,
|
||||
callback,
|
||||
})
|
||||
}
|
||||
|
||||
/// Broadcasts, or receives a broadcasted tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `id` - The peer id of the caller
|
||||
/// - `tensor` - If defined, this tensor will be broadcasted.
|
||||
/// Otherwise, this call will receive the broadcasted tensor.
|
||||
///
|
||||
/// # Result
|
||||
/// Synchronously waits on the broadcasted tensor.
|
||||
pub fn broadcast(
|
||||
&self,
|
||||
id: PeerId,
|
||||
tensor: Option<B::FloatTensorPrimitive>,
|
||||
) -> BroadcastResult<B::FloatTensorPrimitive> {
|
||||
self.broadcast_start(id, tensor).wait()
|
||||
}
|
||||
|
||||
/// Starts a Broadcast, or receives a broadcasted tensor.
|
||||
///
|
||||
/// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `id` - The peer id of the caller
|
||||
/// - `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive
|
||||
/// the broadcasted tensor.
|
||||
///
|
||||
/// # Result
|
||||
///
|
||||
/// A `Receiver<>` that will yield:
|
||||
/// - `Ok(BroadcastResult<B::FloatTensorPrimitive>)` if the operation was successful
|
||||
/// - `Err(SendError)` if the channel was dropped.
|
||||
pub(crate) fn broadcast_start(
|
||||
&self,
|
||||
id: PeerId,
|
||||
tensor: Option<B::FloatTensorPrimitive>,
|
||||
) -> PendingCollectiveOperation<B::FloatTensorPrimitive> {
|
||||
self.start_operation(|callback| Message::Broadcast {
|
||||
device_id: id,
|
||||
tensor,
|
||||
callback,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn finish(&self, id: PeerId) -> FinishResult {
|
||||
self.finish_start(id).wait()
|
||||
}
|
||||
|
||||
pub(crate) fn finish_start(&self, id: PeerId) -> PendingCollectiveOperation<()> {
|
||||
self.start_operation(|callback| Message::Finish { id, callback })
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
mod all_reduce;
|
||||
mod broadcast;
|
||||
mod reduce;
|
||||
|
||||
pub(crate) mod tensor_map;
|
||||
|
||||
pub(crate) use all_reduce::*;
|
||||
pub(crate) use broadcast::*;
|
||||
pub(crate) use reduce::*;
|
||||
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod server;
|
||||
@@ -0,0 +1,30 @@
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use crate::PeerId;
|
||||
use crate::local::tensor_map::CollectiveTensorMap;
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
use crate::local::tensor_map::get_common_shape;
|
||||
|
||||
/// Sums the tensors on one device and returns the result
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(
|
||||
level="trace",
|
||||
skip(tensors),
|
||||
fields(shape = ?get_common_shape::<B>(&tensors).unwrap())
|
||||
))]
|
||||
pub(crate) fn reduce_sum_centralized<B: Backend>(
|
||||
mut tensors: CollectiveTensorMap<B>,
|
||||
central: &PeerId,
|
||||
) -> B::FloatTensorPrimitive {
|
||||
let mut central_tensor = tensors
|
||||
.remove(central)
|
||||
.expect("Source device id is in the map");
|
||||
let central_device = B::float_device(¢ral_tensor);
|
||||
|
||||
for (_, tensor) in tensors {
|
||||
let rhs = B::float_to_device(tensor.clone(), ¢ral_device);
|
||||
central_tensor = B::float_add(central_tensor, rhs);
|
||||
}
|
||||
|
||||
central_tensor
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mod centralized;
|
||||
mod op;
|
||||
mod tree;
|
||||
|
||||
pub(crate) use centralized::*;
|
||||
pub(crate) use op::*;
|
||||
pub(crate) use tree::*;
|
||||
@@ -0,0 +1,163 @@
|
||||
use burn_communication::Protocol;
|
||||
use burn_tensor::{Shape, TensorMetadata, backend::Backend};
|
||||
use std::sync::mpsc::SyncSender;
|
||||
|
||||
use crate::{
|
||||
CollectiveConfig, CollectiveError, PeerId, ReduceOperation, ReduceStrategy,
|
||||
local::{reduce_sum_centralized, reduce_sum_tree},
|
||||
node::base::Node,
|
||||
};
|
||||
|
||||
/// An on-going reduce operation.
///
/// Collects one call per participating peer; `op`, `root` and `shape` are fixed at
/// construction and every registered call is checked against them.
pub struct ReduceOp<B: Backend> {
    /// reduce calls, one for each calling device
    calls: Vec<ReduceOpCall<B>>,
    /// The reduce operation, as defined by the first caller
    op: ReduceOperation,
    /// The peer that receives the reduce result, as defined by the first caller
    root: PeerId,
    /// The shape of the tensor to reduce, as defined by the first caller
    shape: Shape,
}
|
||||
|
||||
/// Struct for each device that calls a reduce operation
pub struct ReduceOpCall<B: Backend> {
    /// Id of the caller of the operation
    caller: PeerId,
    /// The tensor primitive passed as input
    input: B::FloatTensorPrimitive,
    /// Callback for the result of the reduce
    result_sender: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
}
|
||||
|
||||
/// Type sent to the collective client upon completion of a reduce aggregation.
///
/// `Ok(Some(tensor))` carries the reduced tensor, `Ok(None)` is sent to callers that do not
/// receive it, and `Err` reports a failed operation.
pub(crate) type ReduceResult<T> = Result<Option<T>, CollectiveError>;
|
||||
|
||||
impl<B: Backend> ReduceOp<B> {
|
||||
pub fn new(shape: Shape, reduce_op: ReduceOperation, root: PeerId) -> Self {
|
||||
Self {
|
||||
calls: vec![],
|
||||
op: reduce_op,
|
||||
root,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
fn peers(&self) -> Vec<PeerId> {
|
||||
self.calls.iter().map(|c| c.caller).collect()
|
||||
}
|
||||
|
||||
/// Register a call to reduce in this operation.
|
||||
/// When the last caller registers a reduce, the operation is executed.
|
||||
pub fn register_call(
|
||||
&mut self,
|
||||
caller: PeerId,
|
||||
input: B::FloatTensorPrimitive,
|
||||
result_sender: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
|
||||
op: ReduceOperation,
|
||||
root: PeerId,
|
||||
peer_count: usize,
|
||||
) -> Result<bool, CollectiveError> {
|
||||
if self.shape != input.shape() {
|
||||
return Err(CollectiveError::ReduceShapeMismatch);
|
||||
}
|
||||
if self.op != op {
|
||||
return Err(CollectiveError::ReduceOperationMismatch);
|
||||
}
|
||||
if self.root != root {
|
||||
return Err(CollectiveError::ReduceRootMismatch);
|
||||
}
|
||||
|
||||
self.calls.push(ReduceOpCall {
|
||||
caller,
|
||||
input,
|
||||
result_sender,
|
||||
});
|
||||
|
||||
Ok(self.calls.len() == peer_count)
|
||||
}
|
||||
|
||||
/// Runs the all-reduce if the operation is ready. Otherwise, do nothing
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(
|
||||
level="trace",
|
||||
skip(self, config, global_client),
|
||||
fields(
|
||||
?self.op,
|
||||
?self.shape,
|
||||
self.peers = ?self.peers(),
|
||||
)
|
||||
))]
|
||||
pub async fn execute<P: Protocol>(
|
||||
mut self,
|
||||
root: PeerId,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Option<Node<B, P>>,
|
||||
) {
|
||||
match self.reduce(config, global_client).await {
|
||||
Ok(mut result) => {
|
||||
// Return resulting tensor to root, None to others
|
||||
self.calls.iter().for_each(|op| {
|
||||
let msg = if op.caller == root {
|
||||
Ok(result.take())
|
||||
} else {
|
||||
Ok(None)
|
||||
};
|
||||
op.result_sender.send(msg).unwrap();
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
self.fail(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(self, config, global_client))
|
||||
)]
|
||||
async fn reduce<P: Protocol>(
|
||||
&mut self,
|
||||
config: &CollectiveConfig,
|
||||
global_client: &mut Option<Node<B, P>>,
|
||||
) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {
|
||||
let tensors = self
|
||||
.calls
|
||||
.iter()
|
||||
.map(|call| (call.caller, call.input.clone()))
|
||||
.collect();
|
||||
|
||||
// For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later
|
||||
let mut local_sum = match config.local_reduce_strategy {
|
||||
ReduceStrategy::Centralized => reduce_sum_centralized::<B>(tensors, &self.root),
|
||||
ReduceStrategy::Tree(arity) => reduce_sum_tree::<B>(tensors, &self.root, arity),
|
||||
};
|
||||
|
||||
// Do aggregation on a global level with the main tensor
|
||||
let result = if let Some(global_client) = global_client {
|
||||
let strategy = config
|
||||
.global_reduce_strategy
|
||||
.expect("global_reduce_strategy not defined");
|
||||
|
||||
global_client
|
||||
.reduce(local_sum, strategy, self.root, self.op)
|
||||
.await
|
||||
.map_err(CollectiveError::Global)?
|
||||
} else {
|
||||
// Mean division locally
|
||||
if self.op == ReduceOperation::Mean {
|
||||
let local_tensor_count = self.calls.len() as f32;
|
||||
local_sum = B::float_div_scalar(local_sum, local_tensor_count.into())
|
||||
}
|
||||
Some(local_sum)
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Send a collective error as result to operation caller
|
||||
pub fn fail(self, err: CollectiveError) {
|
||||
self.calls.iter().for_each(|op| {
|
||||
op.result_sender.send(Err(err.clone())).unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
use crate::PeerId;
|
||||
use crate::local::tensor_map::CollectiveTensorMap;
|
||||
use burn_tensor::backend::{Backend, DeviceOps};
|
||||
|
||||
/// Performs a reduce on the provided tensors in a b-tree structure with `arity`.
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors))
|
||||
)]
|
||||
pub(crate) fn reduce_sum_tree<B: Backend>(
|
||||
mut tensors: CollectiveTensorMap<B>,
|
||||
root: &PeerId,
|
||||
arity: u32,
|
||||
) -> B::FloatTensorPrimitive {
|
||||
// Convert hash map to vector of key-value pairs because order matters
|
||||
let mut input = vec![];
|
||||
let root_tensor = tensors.remove(root).unwrap();
|
||||
for (_, tensor) in tensors.drain() {
|
||||
input.push(tensor);
|
||||
}
|
||||
|
||||
// Sort to put devices of the same type together
|
||||
input.sort_by(|a, b| {
|
||||
let dev_a = B::float_device(a);
|
||||
let dev_b = B::float_device(b);
|
||||
dev_a.id().cmp(&dev_b.id())
|
||||
});
|
||||
|
||||
// put the root first
|
||||
input.insert(0, root_tensor);
|
||||
|
||||
reduce_sum_tree_inner::<B>(input, arity)
|
||||
}
|
||||
|
||||
/// Recursive function that sums `tensors`
|
||||
///
|
||||
/// Traverses `tensors` and reduces in a post-order traversal. The first tensor in the list is
|
||||
/// chosen as the root
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
tracing::instrument(level = "trace", skip(tensors))
|
||||
)]
|
||||
fn reduce_sum_tree_inner<B: Backend>(
|
||||
mut tensors: Vec<B::FloatTensorPrimitive>,
|
||||
arity: u32,
|
||||
) -> B::FloatTensorPrimitive {
|
||||
let mut parents = vec![];
|
||||
let mut children_groups = vec![];
|
||||
|
||||
// Sum tensors in groups of `arity` + 1
|
||||
while !tensors.is_empty() {
|
||||
let mut children = vec![];
|
||||
let mut parent_tensor = tensors.remove(0);
|
||||
let parent_device = B::float_device(&parent_tensor);
|
||||
|
||||
for _ in 0..arity {
|
||||
if tensors.is_empty() {
|
||||
break;
|
||||
}
|
||||
let child_tensor = tensors.remove(0);
|
||||
children.push(B::float_device(&child_tensor));
|
||||
let rhs = B::float_to_device(child_tensor, &parent_device);
|
||||
parent_tensor = B::float_add(parent_tensor, rhs);
|
||||
}
|
||||
|
||||
parents.push(parent_tensor);
|
||||
children_groups.push(children);
|
||||
}
|
||||
|
||||
if parents.len() > 1 {
|
||||
// Parents are not yet at the root, do the upper part of the tree
|
||||
reduce_sum_tree_inner::<B>(parents, arity)
|
||||
} else {
|
||||
// Root of tree
|
||||
parents.remove(0)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,495 @@
|
||||
use crate::{
|
||||
CollectiveConfig, CollectiveError, PeerId, ReduceOperation,
|
||||
global::node::base::Node,
|
||||
local::{
|
||||
AllReduceOp, AllReduceResult, BroadcastOp, BroadcastResult, ReduceOp, ReduceResult,
|
||||
client::LocalCollectiveClient,
|
||||
},
|
||||
};
|
||||
use burn_communication::websocket::{WebSocket, WsServer};
|
||||
use burn_tensor::{TensorMetadata, backend::Backend};
|
||||
use std::sync::{MutexGuard, OnceLock};
|
||||
use std::{
|
||||
any::{Any, TypeId},
|
||||
collections::HashMap,
|
||||
fmt::Debug,
|
||||
sync::{
|
||||
Arc, Mutex,
|
||||
mpsc::{Receiver, SyncSender},
|
||||
},
|
||||
};
|
||||
use tokio::runtime::{Builder, Runtime};
|
||||
|
||||
/// Define the client/server communication on the network.
///
/// Used as the protocol parameter of the server's global [`Node`].
type Network = WebSocket;
/// Type sent to the collective client upon completion of a register request
pub(crate) type RegisterResult = Result<(), CollectiveError>;
/// Type sent to the collective client upon completion of a finish request
pub(crate) type FinishResult = Result<(), CollectiveError>;
|
||||
|
||||
/// The local collective server that manages all the collective aggregation operations
/// (like all-reduce) between local threads.
/// This thread takes in messages from different clients. The clients must register, then they can
/// send an aggregate message. They must all use the same parameters for the same aggregate
/// operation.
pub(crate) struct LocalCollectiveServer<B: Backend> {
    /// Channel receiver for messages from clients
    message_rec: Receiver<Message<B>>,

    /// The collective configuration. Must be the same by every peer when calling register
    config: Option<CollectiveConfig>,

    /// The ids passed to each register so far
    peers: Vec<PeerId>,

    /// Callbacks for when all registers are done
    callbacks_register: Vec<SyncSender<RegisterResult>>,

    /// Map of each peer's id and its device
    devices: HashMap<PeerId, B::Device>,

    /// Current uncompleted all-reduce operation
    all_reduce_op: Option<AllReduceOp<B>>,

    /// Current uncompleted reduce call
    reduce_op: Option<ReduceOp<B>>,

    /// Current uncompleted broadcast operation, holding one call per calling device.
    broadcast_op: Option<BroadcastOp<B>>,

    /// Client for global collective operations
    global_client: Option<Node<B, Network>>,
}
|
||||
|
||||
/// Requests sent from a [`LocalCollectiveClient`] to the local collective server.
///
/// Every variant except [`Message::Reset`] carries a `callback` sender on which the server
/// reports the outcome of the request.
#[derive(Debug)]
pub(crate) enum Message<B: Backend> {
    /// Register a new peer with the collective.
    Register {
        device_id: PeerId,
        device: B::Device,
        config: CollectiveConfig,
        callback: SyncSender<RegisterResult>,
    },
    /// Perform an all-reduce operation.
    AllReduce {
        device_id: PeerId,
        tensor: B::FloatTensorPrimitive,
        op: ReduceOperation,
        callback: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
    },
    /// Perform a reduce operation.
    Reduce {
        device_id: PeerId,
        tensor: B::FloatTensorPrimitive,
        op: ReduceOperation,
        root: PeerId,
        callback: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
    },
    /// Perform a broadcast operation (one-sender, many-receiver).
    Broadcast {
        device_id: PeerId,
        tensor: Option<B::FloatTensorPrimitive>,
        callback: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
    },
    /// Reset the collective server.
    Reset,
    /// Finish request for the peer `id`.
    Finish {
        id: PeerId,
        callback: SyncSender<FinishResult>,
    },
}
|
||||
|
||||
/// The type-erased box type for [`LocalCollectiveClient<B>`].
type LocalClientBox = Box<dyn Any + Send + Sync>;

/// Global state map from [`Backend`] to boxed [`LocalCollectiveClient<B>`].
static BACKEND_CLIENT_MAP: OnceLock<Mutex<HashMap<TypeId, LocalClientBox>>> = OnceLock::new();

/// Gets a locked mutable view of the `BACKEND_CLIENT_MAP`, creating the empty map on first use.
///
/// # Panics
///
/// Panics if the mutex guarding the map is poisoned.
pub(crate) fn get_backend_client_map() -> MutexGuard<'static, HashMap<TypeId, LocalClientBox>> {
    let shared_map = BACKEND_CLIENT_MAP.get_or_init(|| Mutex::new(HashMap::new()));
    shared_map.lock().unwrap()
}
|
||||
|
||||
/// Get a [`LocalCollectiveClient`] for the given [`Backend`].
|
||||
///
|
||||
/// Will start the local collective client/server pair if necessary.
|
||||
pub(crate) fn get_collective_client<B: Backend>() -> LocalCollectiveClient<B> {
|
||||
let typeid = TypeId::of::<B>();
|
||||
let mut state_map = get_backend_client_map();
|
||||
match state_map.get(&typeid) {
|
||||
Some(val) => val.downcast_ref().cloned().unwrap(),
|
||||
None => {
|
||||
let client = LocalCollectiveServer::<B>::setup(LocalCollectiveClientConfig::default());
|
||||
state_map.insert(typeid, Box::new(client.clone()));
|
||||
client
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Global runtime.
|
||||
static SERVER_RUNTIME: OnceLock<Arc<Runtime>> = OnceLock::new();
|
||||
|
||||
/// Get the global [`Runtime`].
|
||||
pub(crate) fn get_collective_server_runtime() -> Arc<Runtime> {
|
||||
SERVER_RUNTIME
|
||||
.get_or_init(|| {
|
||||
Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("Unable to initialize runtime")
|
||||
.into()
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Configuration for the local collective client/server pair.
pub struct LocalCollectiveClientConfig {
    /// Channel capacity for the messaging queue from client to server.
    pub channel_capacity: usize,
}

/// The default configuration uses a channel capacity of 50 messages.
impl Default for LocalCollectiveClientConfig {
    fn default() -> Self {
        let channel_capacity = 50;
        Self { channel_capacity }
    }
}

/// Builds a configuration from a bare channel capacity.
impl From<usize> for LocalCollectiveClientConfig {
    fn from(channel_capacity: usize) -> Self {
        Self { channel_capacity }
    }
}
|
||||
|
||||
impl<B: Backend> LocalCollectiveServer<B> {
|
||||
fn new(rec: Receiver<Message<B>>) -> Self {
|
||||
Self {
|
||||
message_rec: rec,
|
||||
config: None,
|
||||
peers: vec![],
|
||||
devices: HashMap::new(),
|
||||
all_reduce_op: None,
|
||||
reduce_op: None,
|
||||
broadcast_op: None,
|
||||
callbacks_register: vec![],
|
||||
global_client: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Setup a client/server pair with the given config.
|
||||
pub(crate) fn setup<C>(cfg: C) -> LocalCollectiveClient<B>
|
||||
where
|
||||
C: Into<LocalCollectiveClientConfig>,
|
||||
{
|
||||
let cfg = cfg.into();
|
||||
let (tx, rx) = std::sync::mpsc::sync_channel(cfg.channel_capacity);
|
||||
|
||||
get_collective_server_runtime().spawn(async {
|
||||
let typeid = TypeId::of::<B>();
|
||||
log::info!("Starting server for backend: {typeid:?}");
|
||||
let mut server = LocalCollectiveServer::new(rx);
|
||||
|
||||
loop {
|
||||
match server.message_rec.recv() {
|
||||
Ok(message) => server.process_message(message).await,
|
||||
Err(err) => {
|
||||
log::error!(
|
||||
"Error receiving message from local collective server: {err:?}"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
LocalCollectiveClient { channel: tx }
|
||||
}
|
||||
|
||||
async fn process_message(&mut self, message: Message<B>) {
|
||||
match message {
|
||||
Message::Register {
|
||||
device_id,
|
||||
device,
|
||||
config,
|
||||
callback,
|
||||
} => {
|
||||
self.process_register_message(device_id, device, config, &callback)
|
||||
.await
|
||||
}
|
||||
Message::AllReduce {
|
||||
device_id,
|
||||
tensor,
|
||||
op,
|
||||
callback,
|
||||
} => {
|
||||
self.process_all_reduce_message(device_id, tensor, op, callback)
|
||||
.await
|
||||
}
|
||||
Message::Reduce {
|
||||
device_id,
|
||||
tensor,
|
||||
op,
|
||||
root,
|
||||
callback,
|
||||
} => {
|
||||
self.process_reduce_message(device_id, tensor, op, root, callback)
|
||||
.await
|
||||
}
|
||||
Message::Broadcast {
|
||||
device_id,
|
||||
tensor,
|
||||
callback,
|
||||
} => {
|
||||
self.process_broadcast_message(device_id, tensor, callback)
|
||||
.await
|
||||
}
|
||||
Message::Reset => self.reset(),
|
||||
Message::Finish { id, callback } => self.process_finish_message(id, callback).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_register_message(
|
||||
&mut self,
|
||||
device_id: PeerId,
|
||||
device: B::Device,
|
||||
config: CollectiveConfig,
|
||||
callback: &SyncSender<RegisterResult>,
|
||||
) {
|
||||
if !config.is_valid() {
|
||||
callback.send(Err(CollectiveError::InvalidConfig)).unwrap();
|
||||
return;
|
||||
}
|
||||
if self.peers.contains(&device_id) {
|
||||
callback
|
||||
.send(Err(CollectiveError::MultipleRegister))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
if self.peers.is_empty() || self.config.is_none() {
|
||||
self.config = Some(config);
|
||||
} else if let Some(cfg) = &self.config
|
||||
&& *cfg != config
|
||||
{
|
||||
callback
|
||||
.send(Err(CollectiveError::RegisterParamsMismatch))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
self.peers.push(device_id);
|
||||
self.callbacks_register.push(callback.clone());
|
||||
self.devices.insert(device_id, device);
|
||||
|
||||
let config = self.config.as_ref().unwrap();
|
||||
let global_params = config.global_register_params();
|
||||
if let Some(global_params) = &global_params
|
||||
&& self.global_client.is_none()
|
||||
{
|
||||
let server = WsServer::new(global_params.data_service_port);
|
||||
let client = Node::new(&global_params.global_address, server);
|
||||
self.global_client = Some(client)
|
||||
}
|
||||
|
||||
// All have registered, callback
|
||||
if self.peers.len() == config.num_devices {
|
||||
let mut register_result = Ok(());
|
||||
|
||||
// if an error occurs on the global register, it must be passed back to every local peer
|
||||
if let Some(global_params) = global_params {
|
||||
let client = self
|
||||
.global_client
|
||||
.as_mut()
|
||||
.expect("Global client should be initialized");
|
||||
|
||||
register_result = client
|
||||
.register(self.peers.clone(), global_params)
|
||||
.await
|
||||
.map_err(CollectiveError::Global);
|
||||
};
|
||||
|
||||
// Send results to all callbacks.
|
||||
self.callbacks_register
|
||||
.drain(..)
|
||||
.for_each(|tx| tx.send(register_result.clone()).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes an Message::AllReduce.
|
||||
async fn process_all_reduce_message(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
tensor: <B as Backend>::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
callback: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,
|
||||
) {
|
||||
if !self.peers.contains(&peer_id) {
|
||||
callback
|
||||
.send(Err(CollectiveError::RegisterNotFirstOperation))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
if self.all_reduce_op.is_none() {
|
||||
// First call to all-reduce
|
||||
self.all_reduce_op = Some(AllReduceOp::new(tensor.shape(), op));
|
||||
}
|
||||
// Take the operation, we'll put it back if we're not done
|
||||
let mut all_reduce_op = self.all_reduce_op.take().unwrap();
|
||||
|
||||
// On the last caller, the all-reduce is done here
|
||||
let res =
|
||||
all_reduce_op.register_call(peer_id, tensor, callback.clone(), op, self.peers.len());
|
||||
|
||||
// Upon an error or the last call, the all_reduce_op is dropped
|
||||
match res {
|
||||
Ok(is_ready) => {
|
||||
if is_ready {
|
||||
all_reduce_op
|
||||
.execute(self.config.as_ref().unwrap(), &mut self.global_client)
|
||||
.await;
|
||||
} else {
|
||||
// Put operation back, we're waiting for more calls
|
||||
self.all_reduce_op = Some(all_reduce_op)
|
||||
}
|
||||
}
|
||||
Err(err) => all_reduce_op.fail(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes a Message::Reduce.
|
||||
async fn process_reduce_message(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
tensor: <B as Backend>::FloatTensorPrimitive,
|
||||
op: ReduceOperation,
|
||||
root: PeerId,
|
||||
callback: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,
|
||||
) {
|
||||
if !self.peers.contains(&root) {
|
||||
callback
|
||||
.send(Err(CollectiveError::RegisterNotFirstOperation))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
if self.reduce_op.is_none() {
|
||||
// First call to reduce
|
||||
self.reduce_op = Some(ReduceOp::new(tensor.shape(), op, root));
|
||||
}
|
||||
let mut reduce_op = self.reduce_op.take().unwrap();
|
||||
|
||||
// On the last caller, the all-reduce is done here
|
||||
let res = reduce_op.register_call(
|
||||
peer_id,
|
||||
tensor,
|
||||
callback.clone(),
|
||||
op,
|
||||
root,
|
||||
self.peers.len(),
|
||||
);
|
||||
|
||||
// Upon an error or the last call, the all_reduce_op is dropped
|
||||
match res {
|
||||
Ok(is_ready) => {
|
||||
if is_ready {
|
||||
reduce_op
|
||||
.execute(root, self.config.as_ref().unwrap(), &mut self.global_client)
|
||||
.await;
|
||||
} else {
|
||||
// Put operation back, we're waiting for more calls
|
||||
self.reduce_op = Some(reduce_op)
|
||||
}
|
||||
}
|
||||
Err(err) => reduce_op.fail(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes a Message::Broadcast.
|
||||
async fn process_broadcast_message(
|
||||
&mut self,
|
||||
caller: PeerId,
|
||||
tensor: Option<<B as Backend>::FloatTensorPrimitive>,
|
||||
callback: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,
|
||||
) {
|
||||
if self.config.is_none() {
|
||||
callback
|
||||
.send(Err(CollectiveError::RegisterNotFirstOperation))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
if !self.peers.contains(&caller) {
|
||||
callback
|
||||
.send(Err(CollectiveError::RegisterNotFirstOperation))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
if self.broadcast_op.is_none() {
|
||||
// First call to broadcast
|
||||
self.broadcast_op = Some(BroadcastOp::new());
|
||||
}
|
||||
let device = self.devices.get(&caller).unwrap().clone();
|
||||
|
||||
let mut broadcast_op = self.broadcast_op.take().unwrap();
|
||||
|
||||
// On the last caller, the all-reduce is done here
|
||||
let res =
|
||||
broadcast_op.register_call(caller, tensor, callback.clone(), device, self.peers.len());
|
||||
|
||||
// Upon an error or the last call, the all_reduce_op is dropped
|
||||
match res {
|
||||
Ok(is_ready) => {
|
||||
if is_ready {
|
||||
broadcast_op
|
||||
.execute(self.config.as_ref().unwrap(), &mut self.global_client)
|
||||
.await;
|
||||
} else {
|
||||
// Put operation back, we're waiting for more calls
|
||||
self.broadcast_op = Some(broadcast_op)
|
||||
}
|
||||
}
|
||||
Err(err) => broadcast_op.fail(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reinitializes the collective server
|
||||
fn reset(&mut self) {
|
||||
self.peers.clear();
|
||||
self.all_reduce_op = None;
|
||||
self.reduce_op = None;
|
||||
self.broadcast_op = None;
|
||||
}
|
||||
|
||||
/// Processes a Message::Finish.
|
||||
async fn process_finish_message(&mut self, id: PeerId, callback: SyncSender<RegisterResult>) {
|
||||
if self.config.is_none() {
|
||||
callback
|
||||
.send(Err(CollectiveError::RegisterNotFirstOperation))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
if !self.peers.contains(&id) {
|
||||
callback
|
||||
.send(Err(CollectiveError::MultipleUnregister))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove registered with id
|
||||
self.peers.retain(|x| *x != id);
|
||||
|
||||
if self.peers.is_empty()
|
||||
&& let Some(mut global_client) = self.global_client.take()
|
||||
{
|
||||
global_client.finish().await;
|
||||
}
|
||||
|
||||
callback.send(Ok(())).unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
//! # Common Tensor Map for Local Collective Operations
|
||||
use crate::PeerId;
|
||||
use burn_std::Shape;
|
||||
use burn_tensor::TensorMetadata;
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Map from a peer id to a float tensor primitive of backend `B`.
pub type CollectiveTensorMap<B> = HashMap<PeerId, <B as Backend>::FloatTensorPrimitive>;
|
||||
|
||||
/// Map from a peer id to a device of backend `B`.
pub type PeerDeviceMap<B> = HashMap<PeerId, <B as Backend>::Device>;
|
||||
|
||||
/// Get the shape of the tensors. They should all have the same shape, otherwise None is returned.
|
||||
pub fn get_common_shape<B: Backend>(tensors: &CollectiveTensorMap<B>) -> Option<Shape> {
|
||||
let mut it = tensors.values();
|
||||
if let Some(first) = it.next() {
|
||||
let shape = first.shape();
|
||||
for tensor in it {
|
||||
if tensor.shape() != shape {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
return Some(shape);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the `{ peer_id -> device }` mapping for the given tensors.
|
||||
pub fn get_peer_devices<B: Backend>(tensors: &CollectiveTensorMap<B>) -> PeerDeviceMap<B> {
|
||||
tensors
|
||||
.iter()
|
||||
.map(|(id, tensor)| (*id, B::float_device(tensor)))
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
mod tests {
|
||||
use std::sync::mpsc::SyncSender;
|
||||
|
||||
use burn_std::rand::get_seeded_rng;
|
||||
use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};
|
||||
|
||||
use serial_test::serial;
|
||||
|
||||
#[cfg(feature = "test-ndarray")]
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "test-cuda")]
|
||||
pub type TestBackend = burn_cuda::Cuda<f32>;
|
||||
|
||||
#[cfg(feature = "test-wgpu")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
#[cfg(feature = "test-metal")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
#[cfg(feature = "test-vulkan")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
use crate::{
|
||||
AllReduceStrategy, CollectiveConfig, PeerId, ReduceOperation, all_reduce, register,
|
||||
reset_collective,
|
||||
};
|
||||
|
||||
pub fn run_peer<B: Backend>(
|
||||
id: PeerId,
|
||||
config: CollectiveConfig,
|
||||
input: TensorData,
|
||||
op: ReduceOperation,
|
||||
output: SyncSender<Tensor<B, 1>>,
|
||||
) {
|
||||
let device = B::Device::default();
|
||||
|
||||
register::<B>(id, device.clone(), config).unwrap();
|
||||
|
||||
let tensor = Tensor::<B, 1>::from_data(input, &device);
|
||||
|
||||
let tensor = Tensor::from_primitive(TensorPrimitive::Float(
|
||||
all_reduce::<B>(id, tensor.into_primitive().tensor(), op).unwrap(),
|
||||
));
|
||||
|
||||
output.send(tensor).unwrap();
|
||||
}
|
||||
|
||||
fn generate_random_input(
|
||||
shape: Shape,
|
||||
op: ReduceOperation,
|
||||
thread_count: usize,
|
||||
) -> (Vec<TensorData>, TensorData) {
|
||||
let input: Vec<TensorData> = (0..thread_count)
|
||||
.map(|_| {
|
||||
TensorData::random::<f32, _, _>(
|
||||
shape.clone(),
|
||||
burn_tensor::Distribution::Default,
|
||||
&mut get_seeded_rng(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
|
||||
let mut expected_tensor = Tensor::<TestBackend, 1>::zeros(shape, &device);
|
||||
for item in input.iter().take(thread_count as usize) {
|
||||
let input_tensor = Tensor::<TestBackend, 1>::from_data(item.clone(), &device);
|
||||
expected_tensor = expected_tensor.add(input_tensor);
|
||||
}
|
||||
if op == ReduceOperation::Mean {
|
||||
expected_tensor = expected_tensor.div_scalar(thread_count as u32);
|
||||
}
|
||||
|
||||
let expected = expected_tensor.to_data();
|
||||
|
||||
(input, expected)
|
||||
}
|
||||
|
||||
fn test_all_reduce<B: Backend>(
|
||||
device_count: usize,
|
||||
op: ReduceOperation,
|
||||
strategy: AllReduceStrategy,
|
||||
tensor_size: usize,
|
||||
) {
|
||||
reset_collective::<TestBackend>();
|
||||
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(32);
|
||||
|
||||
let shape = Shape {
|
||||
dims: vec![tensor_size],
|
||||
};
|
||||
|
||||
let (input, expected) = generate_random_input(shape, op, device_count);
|
||||
|
||||
let config = CollectiveConfig::default()
|
||||
.with_num_devices(device_count)
|
||||
.with_local_all_reduce_strategy(strategy);
|
||||
|
||||
for id in 0..device_count {
|
||||
let send = send.clone();
|
||||
let input = input[id as usize].clone();
|
||||
|
||||
std::thread::spawn({
|
||||
let config = config.clone();
|
||||
move || run_peer::<B>(id.into(), config, input, op, send)
|
||||
});
|
||||
}
|
||||
|
||||
let first = recv.recv().unwrap().to_data();
|
||||
for _ in 1..device_count {
|
||||
let tensor = recv.recv().unwrap();
|
||||
tensor.to_data().assert_eq(&first, true);
|
||||
}
|
||||
|
||||
let tol: Tolerance<f32> = Tolerance::balanced();
|
||||
expected.assert_approx_eq(&first, tol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_centralized_sum() {
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Centralized, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_centralized_mean() {
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Centralized, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_binary_tree_sum() {
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Tree(2), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_binary_tree_mean() {
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Tree(2), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_5_tree_sum() {
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Tree(5), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_5_tree_mean() {
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Tree(5), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_ring_sum() {
|
||||
test_all_reduce::<TestBackend>(3, ReduceOperation::Sum, AllReduceStrategy::Ring, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_ring_mean() {
|
||||
test_all_reduce::<TestBackend>(3, ReduceOperation::Mean, AllReduceStrategy::Ring, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_all_reduce_ring_irregular_sum() {
|
||||
// this should trigger the fallback algorithm when the tensor is too small.
|
||||
test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Ring, 3);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
mod tests {
|
||||
use std::sync::mpsc::SyncSender;
|
||||
|
||||
use burn_std::rand::get_seeded_rng;
|
||||
use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};
|
||||
|
||||
use serial_test::serial;
|
||||
|
||||
#[cfg(feature = "test-ndarray")]
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "test-cuda")]
|
||||
pub type TestBackend = burn_cuda::Cuda<f32>;
|
||||
|
||||
#[cfg(feature = "test-wgpu")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
#[cfg(feature = "test-metal")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
#[cfg(feature = "test-vulkan")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
use crate::{
|
||||
BroadcastStrategy, CollectiveConfig, PeerId, broadcast, register, reset_collective,
|
||||
};
|
||||
|
||||
pub fn run_peer<B: Backend>(
|
||||
id: PeerId,
|
||||
config: CollectiveConfig,
|
||||
input: Option<TensorData>,
|
||||
output: SyncSender<Tensor<B, 1>>,
|
||||
) {
|
||||
let device = B::Device::default();
|
||||
|
||||
register::<B>(id, device.clone(), config).unwrap();
|
||||
|
||||
let tensor = input.map(|data| B::float_from_data(data, &device));
|
||||
let tensor = broadcast::<B>(id, tensor).unwrap();
|
||||
let tensor = Tensor::<B, 1>::from_primitive(TensorPrimitive::Float(tensor));
|
||||
|
||||
output.send(tensor).unwrap();
|
||||
}
|
||||
|
||||
fn generate_random_input(shape: Shape) -> TensorData {
|
||||
TensorData::random::<f32, _, _>(
|
||||
shape.clone(),
|
||||
burn_tensor::Distribution::Default,
|
||||
&mut get_seeded_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
fn test_broadcast<B: Backend>(
|
||||
device_count: usize,
|
||||
strategy: BroadcastStrategy,
|
||||
tensor_size: usize,
|
||||
) {
|
||||
reset_collective::<TestBackend>();
|
||||
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(32);
|
||||
|
||||
let shape = Shape {
|
||||
dims: vec![tensor_size],
|
||||
};
|
||||
|
||||
let input = generate_random_input(shape);
|
||||
|
||||
let config = CollectiveConfig::default()
|
||||
.with_num_devices(device_count)
|
||||
.with_local_broadcast_strategy(strategy);
|
||||
|
||||
for id in 0..device_count {
|
||||
// The peer #0 is the root: it sends the tensor
|
||||
let input = if id == 0 { Some(input.clone()) } else { None };
|
||||
|
||||
std::thread::spawn({
|
||||
let config = config.clone();
|
||||
let send = send.clone();
|
||||
move || run_peer::<B>(id.into(), config, input, send)
|
||||
});
|
||||
}
|
||||
|
||||
// Expect all peers to receive the input tensor
|
||||
let tol: Tolerance<f32> = Tolerance::balanced();
|
||||
for _ in 0..device_count {
|
||||
let tensor = recv.recv().unwrap().to_data();
|
||||
input.assert_approx_eq(&tensor, tol);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_broadcast_centralized_sum() {
|
||||
test_broadcast::<TestBackend>(4, BroadcastStrategy::Centralized, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_broadcast_centralized_mean() {
|
||||
test_broadcast::<TestBackend>(4, BroadcastStrategy::Centralized, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_broadcast_binary_tree_sum() {
|
||||
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(2), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_broadcast_binary_tree_mean() {
|
||||
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(2), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_broadcast_5_tree_sum() {
|
||||
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(5), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_broadcast_5_tree_mean() {
|
||||
test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(5), 4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod all_reduce;
|
||||
mod broadcast;
|
||||
mod reduce;
|
||||
@@ -0,0 +1,162 @@
|
||||
mod tests {
|
||||
use std::sync::mpsc::SyncSender;
|
||||
|
||||
use burn_std::rand::get_seeded_rng;
|
||||
use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};
|
||||
|
||||
use serial_test::serial;
|
||||
|
||||
#[cfg(feature = "test-ndarray")]
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "test-cuda")]
|
||||
pub type TestBackend = burn_cuda::Cuda<f32>;
|
||||
|
||||
#[cfg(feature = "test-wgpu")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
#[cfg(feature = "test-metal")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
#[cfg(feature = "test-vulkan")]
|
||||
pub type TestBackend = burn_wgpu::Wgpu<f32>;
|
||||
|
||||
use crate::{
|
||||
CollectiveConfig, PeerId, ReduceOperation, ReduceStrategy, reduce, register,
|
||||
reset_collective,
|
||||
};
|
||||
|
||||
pub fn run_peer<B: Backend>(
|
||||
id: PeerId,
|
||||
config: CollectiveConfig,
|
||||
input: TensorData,
|
||||
op: ReduceOperation,
|
||||
root: PeerId,
|
||||
output: SyncSender<Option<Tensor<B, 1>>>,
|
||||
) {
|
||||
let device = B::Device::default();
|
||||
|
||||
register::<B>(id, device.clone(), config).unwrap();
|
||||
|
||||
let tensor = Tensor::<B, 1>::from_data(input, &device);
|
||||
|
||||
let tensor = tensor.into_primitive().tensor();
|
||||
let tensor = reduce::<B>(id, tensor, op, root).unwrap();
|
||||
let tensor = tensor.map(|t| Tensor::<B, 1>::from_primitive(TensorPrimitive::Float(t)));
|
||||
|
||||
output.send(tensor).unwrap();
|
||||
}
|
||||
|
||||
fn generate_random_input(
|
||||
shape: Shape,
|
||||
op: ReduceOperation,
|
||||
thread_count: usize,
|
||||
) -> (Vec<TensorData>, TensorData) {
|
||||
let input: Vec<TensorData> = (0..thread_count)
|
||||
.map(|_| {
|
||||
TensorData::random::<f32, _, _>(
|
||||
shape.clone(),
|
||||
burn_tensor::Distribution::Default,
|
||||
&mut get_seeded_rng(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
|
||||
let mut expected_tensor = Tensor::<TestBackend, 1>::zeros(shape, &device);
|
||||
for item in input.iter().take(thread_count) {
|
||||
let input_tensor = Tensor::<TestBackend, 1>::from_data(item.clone(), &device);
|
||||
expected_tensor = expected_tensor.add(input_tensor);
|
||||
}
|
||||
if op == ReduceOperation::Mean {
|
||||
expected_tensor = expected_tensor.div_scalar(thread_count as u32);
|
||||
}
|
||||
|
||||
let expected = expected_tensor.to_data();
|
||||
|
||||
(input, expected)
|
||||
}
|
||||
|
||||
fn test_reduce<B: Backend>(
|
||||
device_count: usize,
|
||||
op: ReduceOperation,
|
||||
strategy: ReduceStrategy,
|
||||
tensor_size: usize,
|
||||
) {
|
||||
reset_collective::<TestBackend>();
|
||||
|
||||
let (send, recv) = std::sync::mpsc::sync_channel(32);
|
||||
|
||||
let shape = Shape {
|
||||
dims: vec![tensor_size],
|
||||
};
|
||||
|
||||
let (input, expected) = generate_random_input(shape, op, device_count);
|
||||
|
||||
let config = CollectiveConfig::default()
|
||||
.with_num_devices(device_count)
|
||||
.with_local_reduce_strategy(strategy);
|
||||
|
||||
let root: PeerId = 0.into();
|
||||
for id in 0..device_count {
|
||||
let send = send.clone();
|
||||
let input = input[id as usize].clone();
|
||||
|
||||
std::thread::spawn({
|
||||
let config = config.clone();
|
||||
move || run_peer::<B>(id.into(), config, input, op, root, send)
|
||||
});
|
||||
}
|
||||
|
||||
let mut result = None;
|
||||
for _ in 0..device_count {
|
||||
let tensor = recv.recv().unwrap();
|
||||
if tensor.is_some() {
|
||||
if result.is_some() {
|
||||
panic!("Two peers received the result of an reduce!");
|
||||
}
|
||||
result = tensor.map(|t| t.to_data());
|
||||
}
|
||||
}
|
||||
|
||||
let tol: Tolerance<f32> = Tolerance::balanced();
|
||||
expected.assert_approx_eq(&result.expect("One peer has received the result"), tol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_reduce_centralized_sum() {
|
||||
test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Centralized, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_reduce_centralized_mean() {
|
||||
test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Centralized, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_reduce_binary_tree_sum() {
|
||||
test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Tree(2), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_reduce_binary_tree_mean() {
|
||||
test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Tree(2), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_reduce_5_tree_sum() {
|
||||
test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Tree(5), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
pub fn test_reduce_5_tree_mean() {
|
||||
test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Tree(5), 4);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user