feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
[package]
|
||||
name = "burn-collective-multinode-tests"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["ndarray"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../burn", default-features = false, features = ["std"] }
|
||||
burn-std = { path = "../../burn-std", default-features = false }
|
||||
burn-collective = { path = "..", features = ["orchestrator"] }
|
||||
burn-communication = { path = "../../burn-communication" }
|
||||
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "process"] }
|
||||
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
interprocess = "2.3.1"
|
||||
rmp-serde = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
tokio-serde = { version = "0.9.0", features = ["messagepack"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
|
||||
[[bin]]
|
||||
name = "global"
|
||||
path = "src/bin/global.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "node"
|
||||
path = "src/bin/node.rs"
|
||||
@@ -0,0 +1,32 @@
|
||||
# Integration test for burn collective operations with multiple nodes and devices.
|
||||
|
||||
Run `cargo run --bin test_launcher`
|
||||
|
||||
There are 3 binaries:
|
||||
|
||||
## node.rs
|
||||
|
||||
Launches `n` threads each simulating a different device. Currently the backend is NdArray,
|
||||
so everything is CPU. The program takes a file with configurations and input data.
|
||||
|
||||
## global.rs
|
||||
|
||||
Runs the global orchestrator, who is responsible for responding to global collective operation
|
||||
requests. In the case of an all-reduce, the orchestrator responds with a strategy for reducing,
|
||||
and the node can do the reduction independently.
|
||||
|
||||
## test_launcher.rs
|
||||
|
||||
Generates input data, calculates the expected results, and launches the nodes each with their
|
||||
own inputs in a separate file.
|
||||
|
||||
The topology is [4, 4, 4, 4]. This means 4 nodes are launched,
|
||||
each with 4 threads (for each device).
|
||||
|
||||
The global orchestrator (`global.rs`) is also launched.
|
||||
|
||||
## Output
|
||||
|
||||
The outputs and inputs for each node and the orchestrator are written to the `target/test_files` folder
|
||||
|
||||
If the nodes or orchestrator stall, there is a timeout.
|
||||
@@ -0,0 +1,19 @@
|
||||
//! Global orchestrator
|
||||
//!
|
||||
//! Launches the orchestrator that responds to global collective operations for nodes for the
|
||||
//! integration test
|
||||
//!
|
||||
//! This is necessary for any node who needs global collective operations
|
||||
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
/// Start the global orchestrator on the port given as first arg
|
||||
pub async fn main() {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let port = args[1].parse::<u16>().expect("invalid port");
|
||||
|
||||
// Launch the global orchestrator, which will listen and respond to global collective op
|
||||
// requests from nodes
|
||||
burn_collective::start_global_orchestrator(port).await;
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
prelude::Backend,
|
||||
tensor::{Tensor, TensorPrimitive, Tolerance},
|
||||
};
|
||||
use burn_collective::{
|
||||
CollectiveConfig, PeerId, ReduceOperation, all_reduce, finish_collective, register,
|
||||
reset_collective,
|
||||
};
|
||||
use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};
|
||||
use std::{
|
||||
env,
|
||||
sync::mpsc::SyncSender,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::thread::JoinHandle;
|
||||
use tokio_serde::formats::MessagePack;
|
||||
use tokio_util::codec::LengthDelimitedCodec;
|
||||
|
||||
type TestBackend = NdArray;
|
||||
|
||||
/// Framed TCP connection channel
|
||||
type TestChannel = tokio_serde::Framed<
|
||||
tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,
|
||||
NodeTest,
|
||||
NodeTestResult,
|
||||
MessagePack<NodeTest, NodeTestResult>,
|
||||
>;
|
||||
|
||||
/// Start a node that will test all-reduce
|
||||
/// Args are the following:
|
||||
/// - launcher endpoint
|
||||
#[tokio::main]
|
||||
pub async fn main() {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let launcher_addr = args[1].clone();
|
||||
|
||||
let socket = TcpStream::connect(launcher_addr).await.unwrap();
|
||||
let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());
|
||||
let mut socket: TestChannel = tokio_serde::Framed::new(
|
||||
length_delimited,
|
||||
MessagePack::<NodeTest, NodeTestResult>::default(),
|
||||
);
|
||||
|
||||
// Loop: receive, do test, send result
|
||||
while let Some(Ok(test)) = socket.next().await {
|
||||
println!("Received test: {test:?}");
|
||||
|
||||
let result = run_test::<NdArray>(&test);
|
||||
|
||||
// send the result back
|
||||
socket.send(result).await.expect("failed to send Result");
|
||||
}
|
||||
|
||||
println!("Server closed connection; exiting.");
|
||||
}
|
||||
|
||||
/// Runs a test for one node
|
||||
fn run_test<B: Backend>(test_input: &NodeTest) -> NodeTestResult {
|
||||
reset_collective::<TestBackend>();
|
||||
|
||||
// Channel for results
|
||||
let (result_send, result_recv) = std::sync::mpsc::sync_channel(32);
|
||||
|
||||
// Launch a thread for each "device"
|
||||
let handles = launch_threads::<B>(test_input.clone(), result_send);
|
||||
|
||||
// Receive results
|
||||
let mut durations = vec![];
|
||||
let tol: Tolerance<f32> = Tolerance::balanced();
|
||||
for _ in 0..test_input.device_count {
|
||||
// Assert all results are equal to each other as well as expected result
|
||||
let (tensor, duration) = result_recv.recv().unwrap();
|
||||
test_input.expected.assert_approx_eq(&tensor.to_data(), tol);
|
||||
|
||||
durations.push(duration);
|
||||
}
|
||||
|
||||
// Threads finish
|
||||
for handle in handles {
|
||||
let _ = handle.join();
|
||||
}
|
||||
|
||||
NodeTestResult {
|
||||
success: true,
|
||||
durations,
|
||||
}
|
||||
}
|
||||
|
||||
/// Launch a thread for each device, and run the all-reduce
|
||||
fn launch_threads<B: Backend>(
|
||||
test_input: NodeTest,
|
||||
result_send: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,
|
||||
) -> Vec<JoinHandle<()>> {
|
||||
let mut handles = vec![];
|
||||
for id in 0..test_input.device_count {
|
||||
// Launch a thread to test
|
||||
|
||||
// Put all the parameters in the config
|
||||
let config = CollectiveConfig::default()
|
||||
.with_num_devices(test_input.device_count)
|
||||
.with_global_address(test_input.global_address.clone())
|
||||
.with_node_address(test_input.node_address.clone())
|
||||
.with_data_service_port(test_input.data_service_port)
|
||||
.with_num_nodes(test_input.node_count)
|
||||
.with_global_all_reduce_strategy(test_input.global_strategy)
|
||||
.with_local_all_reduce_strategy(test_input.local_strategy);
|
||||
|
||||
// Inputs and outputs for the test
|
||||
let tensor_data = test_input.inputs[id].clone();
|
||||
let tensor = Tensor::<B, TENSOR_RANK>::from_data(tensor_data, &B::Device::default());
|
||||
let result_send = result_send.clone();
|
||||
|
||||
let handle = std::thread::spawn(move || {
|
||||
run_peer::<B>(
|
||||
id.into(),
|
||||
config,
|
||||
tensor,
|
||||
result_send,
|
||||
test_input.all_reduce_op,
|
||||
)
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
handles
|
||||
}
|
||||
|
||||
/// Runs a thread in the all-reduce test.
|
||||
pub fn run_peer<B: Backend>(
|
||||
id: PeerId,
|
||||
config: CollectiveConfig,
|
||||
input: Tensor<B, TENSOR_RANK>,
|
||||
output: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,
|
||||
all_reduce_op: ReduceOperation,
|
||||
) {
|
||||
// Register the device
|
||||
register::<B>(id, input.device(), config).unwrap();
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// All-reduce
|
||||
let input = input.into_primitive().tensor();
|
||||
let tensor = all_reduce::<B>(id, input, all_reduce_op).unwrap();
|
||||
let tensor = Tensor::<B, TENSOR_RANK>::from_primitive(TensorPrimitive::Float(tensor));
|
||||
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Send result
|
||||
output.send((tensor, duration)).unwrap();
|
||||
|
||||
finish_collective::<B>(id).unwrap();
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
use burn::tensor::TensorData;
|
||||
use burn_communication::Address;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
fs::{self, File},
|
||||
str::FromStr,
|
||||
time::{Duration, Instant},
|
||||
vec,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_serde::formats::MessagePack;
|
||||
use tokio_util::codec::LengthDelimitedCodec;
|
||||
|
||||
use burn::{backend::NdArray, prelude::Backend, tensor::Tensor};
|
||||
use burn_collective::{AllReduceStrategy, ReduceOperation};
|
||||
use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};
|
||||
use burn_std::rand::{SeedableRng, StdRng};
|
||||
use tokio::process::{Child, Command};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AllReduceTest {
|
||||
shape: [usize; TENSOR_RANK],
|
||||
op: ReduceOperation,
|
||||
local_strategy: AllReduceStrategy,
|
||||
global_strategy: AllReduceStrategy,
|
||||
}
|
||||
|
||||
impl Display for AllReduceTest {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let op_str = match self.op {
|
||||
ReduceOperation::Sum => "sum",
|
||||
ReduceOperation::Mean => "mean",
|
||||
};
|
||||
let local_strategy_str = match self.local_strategy {
|
||||
AllReduceStrategy::Centralized => "local_centralized",
|
||||
AllReduceStrategy::Tree(n) => &format!("local_tree_{n}"),
|
||||
AllReduceStrategy::Ring => "local_ring",
|
||||
};
|
||||
let global_strategy_str = match self.global_strategy {
|
||||
AllReduceStrategy::Centralized => "global_centralized",
|
||||
AllReduceStrategy::Tree(n) => &format!("global_tree_{n}"),
|
||||
AllReduceStrategy::Ring => "global_ring",
|
||||
};
|
||||
|
||||
write!(f, "{op_str}_{local_strategy_str}_{global_strategy_str}")
|
||||
}
|
||||
}
|
||||
|
||||
/// Framed TCP connection for sending tests and receiving results
|
||||
type TestChannel = tokio_serde::Framed<
|
||||
tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,
|
||||
NodeTestResult,
|
||||
NodeTest,
|
||||
MessagePack<NodeTestResult, NodeTest>,
|
||||
>;
|
||||
|
||||
/// Handle for each node process
|
||||
struct NodeProcessHandle {
|
||||
process: Child,
|
||||
channel: TestChannel,
|
||||
}
|
||||
|
||||
/// Main function to run the multi-node all-reduce test.
|
||||
/// Launches a orchestrator and multiple nodes based on the provided topology.
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
|
||||
async fn main() {
|
||||
let all_reduce_tests = vec![
|
||||
AllReduceTest {
|
||||
shape: [4, 64, 512],
|
||||
op: ReduceOperation::Mean,
|
||||
local_strategy: AllReduceStrategy::Tree(2),
|
||||
global_strategy: AllReduceStrategy::Tree(2),
|
||||
},
|
||||
AllReduceTest {
|
||||
shape: [4, 64, 512],
|
||||
op: ReduceOperation::Mean,
|
||||
local_strategy: AllReduceStrategy::Tree(2),
|
||||
global_strategy: AllReduceStrategy::Ring,
|
||||
},
|
||||
AllReduceTest {
|
||||
shape: [4, 64, 512],
|
||||
op: ReduceOperation::Mean,
|
||||
local_strategy: AllReduceStrategy::Centralized,
|
||||
global_strategy: AllReduceStrategy::Centralized,
|
||||
},
|
||||
];
|
||||
|
||||
let test_files_dir = "target/test_files";
|
||||
fs::create_dir_all(test_files_dir).expect("Couldn't create test_files directory");
|
||||
|
||||
let topology: Vec<usize> = vec![4; 4];
|
||||
|
||||
let mut orchestrator = launch_orchestrator(test_files_dir);
|
||||
|
||||
let launcher_endpoint = "127.0.0.1:4000";
|
||||
|
||||
// Build and run node processes
|
||||
let mut all_tests_durations = vec![];
|
||||
if let Ok(mut nodes) = launch_nodes(&topology, launcher_endpoint).await {
|
||||
// Run one test
|
||||
for test in all_reduce_tests.clone() {
|
||||
let test_name = test.to_string();
|
||||
|
||||
let time =
|
||||
test_all_reduce_centralized_no_collective::<NdArray>(&topology, test.clone());
|
||||
println!(
|
||||
"{test_name}: Benchmark (no collective, centralized, single-threaded): {} secs",
|
||||
time.as_secs_f32()
|
||||
);
|
||||
|
||||
match test_all_reduce(&topology, test, &mut nodes).await {
|
||||
Err(node_idx) => {
|
||||
println!("{test_name}: Node with index {node_idx} failed!");
|
||||
// Kill other node processes
|
||||
for mut node in nodes.drain(..) {
|
||||
node.process.kill().await.unwrap();
|
||||
node.process.wait().await.unwrap();
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(durations) => {
|
||||
all_tests_durations.append(&mut durations.clone());
|
||||
let avg = durations.iter().map(|dur| dur.as_secs_f32()).sum::<f32>()
|
||||
/ durations.len() as f32;
|
||||
println!("{test_name}: Success in {avg} secs");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !all_tests_durations.is_empty() {
|
||||
let avg = all_tests_durations
|
||||
.iter()
|
||||
.map(|dur| dur.as_secs_f32())
|
||||
.sum::<f32>()
|
||||
/ all_tests_durations.len() as f32;
|
||||
println!("Average for all tests: {avg} secs");
|
||||
}
|
||||
|
||||
// Shutdown orchestrator
|
||||
orchestrator.kill().await.unwrap();
|
||||
orchestrator.wait().await.unwrap();
|
||||
}
|
||||
|
||||
/// Launch a global orchestrator with an output file in the given directory.
|
||||
/// Necessary for global collective operations
|
||||
///
|
||||
/// Server listens on localhost port 3000
|
||||
fn launch_orchestrator(test_files_dir: &str) -> Child {
|
||||
let out_path = format!("{test_files_dir}/orchestrator_out.txt");
|
||||
let out = File::create(out_path).expect("Could't create orchestrator output file");
|
||||
|
||||
Command::new("cargo")
|
||||
.args(["run", "--bin", "global", "--", "3000"])
|
||||
.stdout(out.try_clone().unwrap())
|
||||
.stderr(out)
|
||||
.spawn()
|
||||
.expect("failed to launch orchestrator")
|
||||
}
|
||||
|
||||
/// Launch nodes for a all_reduce test
|
||||
/// Each node will connect to the global orchestrator and run an all-reduce operation.
|
||||
/// The topology is a vector where each element represents the number of devices in that node.
|
||||
async fn launch_nodes(
|
||||
topology: &[usize],
|
||||
launcher_endpoint: &str,
|
||||
) -> Result<Vec<NodeProcessHandle>, ()> {
|
||||
println!(
|
||||
"Launching {} nodes with topology: {:?}",
|
||||
topology.len(),
|
||||
topology
|
||||
);
|
||||
|
||||
// Listen for node connections
|
||||
let listener = TcpListener::bind(launcher_endpoint).await.unwrap();
|
||||
println!("Server listening on {launcher_endpoint}");
|
||||
|
||||
let mut nodes = vec![];
|
||||
|
||||
for node_idx in 0..topology.len() {
|
||||
// Create log file
|
||||
let output_filename = format!("target/test_files/node_{}_log.txt", node_idx + 1);
|
||||
let out = File::create(output_filename).expect("Could't open node log file");
|
||||
|
||||
// Start a process for each node. Pass on our feature flags
|
||||
let node_process: Child = Command::new("cargo")
|
||||
.args([
|
||||
"run",
|
||||
"--release",
|
||||
"--features",
|
||||
#[cfg(feature = "ndarray")]
|
||||
"ndarray",
|
||||
"--bin",
|
||||
"node",
|
||||
"--",
|
||||
launcher_endpoint,
|
||||
&node_idx.to_string(),
|
||||
])
|
||||
.stdout(out.try_clone().unwrap())
|
||||
.stderr(out)
|
||||
.spawn()
|
||||
.expect("node failed");
|
||||
|
||||
// Wait for child to connect for io
|
||||
let (socket, _peer_addr) = listener.accept().await.unwrap();
|
||||
let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());
|
||||
let channel: TestChannel = tokio_serde::Framed::new(
|
||||
length_delimited,
|
||||
MessagePack::<NodeTestResult, NodeTest>::default(),
|
||||
);
|
||||
|
||||
nodes.push(NodeProcessHandle {
|
||||
process: node_process,
|
||||
channel,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
async fn test_all_reduce(
|
||||
topology: &[usize],
|
||||
test: AllReduceTest,
|
||||
nodes: &mut [NodeProcessHandle],
|
||||
) -> Result<Vec<Duration>, usize> {
|
||||
dispatch_all_reduce_test(topology, test, nodes).await;
|
||||
|
||||
let mut all_durations = vec![];
|
||||
for (idx, handle) in nodes.iter_mut().enumerate() {
|
||||
match handle.channel.next().await {
|
||||
Some(Ok(mut result)) => {
|
||||
if !result.success {
|
||||
return Err(idx);
|
||||
}
|
||||
all_durations.append(&mut result.durations);
|
||||
}
|
||||
_ => {
|
||||
return Err(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_durations)
|
||||
}
|
||||
|
||||
async fn dispatch_all_reduce_test(
|
||||
topology: &[usize],
|
||||
test: AllReduceTest,
|
||||
nodes: &mut [NodeProcessHandle],
|
||||
) {
|
||||
let total_device_count: usize = topology.iter().sum();
|
||||
let (mut all_inputs, expected) =
|
||||
generate_random_input(test.shape, test.op, total_device_count, 42);
|
||||
|
||||
// URL for the global orchestrator on port 3000
|
||||
let global_url = "ws://localhost:3000";
|
||||
let global_address = Address::from_str(global_url).unwrap();
|
||||
|
||||
for (node_idx, &device_count) in topology.iter().enumerate() {
|
||||
// Construct URL for node
|
||||
// Ports 3001... are for each node
|
||||
let data_service_port = node_idx as u16 + 3001;
|
||||
let node_url = format!("ws://localhost:{data_service_port}");
|
||||
let node_address = Address::from_str(&node_url).unwrap();
|
||||
|
||||
// take input tensors for each device
|
||||
let inputs = all_inputs[0..device_count].to_vec();
|
||||
all_inputs = all_inputs[device_count..].to_vec();
|
||||
|
||||
let test = NodeTest {
|
||||
device_count,
|
||||
node_id: node_idx.into(),
|
||||
node_count: topology.len() as u32,
|
||||
global_address: global_address.clone(),
|
||||
node_address,
|
||||
data_service_port,
|
||||
all_reduce_op: test.op,
|
||||
global_strategy: test.global_strategy,
|
||||
local_strategy: test.local_strategy,
|
||||
inputs,
|
||||
expected: expected.clone(),
|
||||
};
|
||||
let handle = &mut nodes[node_idx];
|
||||
|
||||
handle.channel.send(test).await.unwrap();
|
||||
}
|
||||
|
||||
assert!(
|
||||
all_inputs.is_empty(),
|
||||
"Not all inputs have been sent to tests"
|
||||
);
|
||||
}
|
||||
|
||||
/// Run the test sequentially with no collective operations to get the optimal single-threaded speed
|
||||
fn test_all_reduce_centralized_no_collective<B: Backend>(
|
||||
topology: &[usize],
|
||||
test: AllReduceTest,
|
||||
) -> Duration {
|
||||
let total_device_count: usize = topology.iter().sum();
|
||||
let (all_inputs, _expected) =
|
||||
generate_random_input(test.shape, test.op, total_device_count, 42);
|
||||
|
||||
let mut all_inputs = all_inputs
|
||||
.into_iter()
|
||||
.map(|data| Tensor::<B, 3>::from_data(data, &B::Device::default()));
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Sequential test
|
||||
let mut result = all_inputs.next().unwrap();
|
||||
for other in all_inputs {
|
||||
result = result.add(other);
|
||||
}
|
||||
if test.op == ReduceOperation::Mean {
|
||||
result.div_scalar(total_device_count as u32);
|
||||
}
|
||||
|
||||
start.elapsed()
|
||||
}
|
||||
|
||||
/// Generates random input tensors and expected output based on the provided shape and reduce kind.
|
||||
fn generate_random_input(
|
||||
shape: [usize; 3],
|
||||
reduce_kind: ReduceOperation,
|
||||
input_count: usize,
|
||||
seed: u64,
|
||||
) -> (Vec<TensorData>, TensorData) {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
|
||||
// A random tensor for each device
|
||||
let input: Vec<TensorData> = (0..input_count)
|
||||
.map(|_| {
|
||||
TensorData::random::<f32, _, _>(shape, burn::tensor::Distribution::Default, &mut rng)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sum up the inputs
|
||||
let device = <NdArray as Backend>::Device::default();
|
||||
let mut expected_tensor = Tensor::<NdArray, TENSOR_RANK>::zeros(shape, &device);
|
||||
for item in input.iter().take(input_count) {
|
||||
let input_tensor = Tensor::<NdArray, TENSOR_RANK>::from_data(item.clone(), &device);
|
||||
expected_tensor = expected_tensor.add(input_tensor);
|
||||
}
|
||||
|
||||
if reduce_kind == ReduceOperation::Mean {
|
||||
expected_tensor = expected_tensor.div_scalar(input_count as u32);
|
||||
}
|
||||
|
||||
// All-Reduce results should have this value
|
||||
let expected = expected_tensor.to_data();
|
||||
|
||||
(input, expected)
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
pub mod shared;
|
||||
@@ -0,0 +1,43 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use burn::tensor::TensorData;
|
||||
use burn_collective::{AllReduceStrategy, NodeId, ReduceOperation};
|
||||
use burn_communication::Address;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Ranks of inputs and outputs for all testing
|
||||
pub const TENSOR_RANK: usize = 3;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeTest {
|
||||
/// How many threads to start on this node
|
||||
pub device_count: usize,
|
||||
/// ID for this node
|
||||
pub node_id: NodeId,
|
||||
/// How many nodes in the cluster
|
||||
pub node_count: u32,
|
||||
/// Global server address
|
||||
pub global_address: Address,
|
||||
/// Node address
|
||||
pub node_address: Address,
|
||||
/// Node's data service port, for initializing the p2p tensor data service
|
||||
pub data_service_port: u16,
|
||||
/// What kind of all-reduce
|
||||
pub all_reduce_op: ReduceOperation,
|
||||
/// Node's data service port, for initializing the p2p tensor data service
|
||||
pub global_strategy: AllReduceStrategy,
|
||||
/// What kind of aggregation
|
||||
pub local_strategy: AllReduceStrategy,
|
||||
|
||||
/// Input data for test: all tensors are D=3
|
||||
pub inputs: Vec<TensorData>,
|
||||
/// Expected output for test
|
||||
pub expected: TensorData,
|
||||
}
|
||||
|
||||
/// Result sent back from each node for each test
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeTestResult {
|
||||
pub success: bool,
|
||||
pub durations: Vec<Duration>,
|
||||
}
|
||||
Reference in New Issue
Block a user