Files
RustyUI/crates/stable-diffusion-burn/burn-crates/burn-collective/multinode-tests/src/shared.rs
Ben_Kosytorz 3a67c0979c feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
2026-03-05 19:39:14 +01:00

44 lines
1.3 KiB
Rust

use std::time::Duration;
use burn::tensor::TensorData;
use burn_collective::{AllReduceStrategy, NodeId, ReduceOperation};
use burn_communication::Address;
use serde::{Deserialize, Serialize};
/// Rank (number of dimensions) of every input and output tensor used in the tests.
pub const TENSOR_RANK: usize = 3;
/// Description of a test for a single node.
///
/// Holds the node's identity and addresses, the all-reduce settings, and the
/// input and expected tensor data. Serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTest {
/// How many threads to start on this node
pub device_count: usize,
/// ID for this node
pub node_id: NodeId,
/// How many nodes in the cluster
pub node_count: u32,
/// Global server address
pub global_address: Address,
/// Node address
pub node_address: Address,
/// Node's data service port, for initializing the p2p tensor data service
pub data_service_port: u16,
/// Reduce operation applied by the all-reduce
pub all_reduce_op: ReduceOperation,
/// All-reduce strategy used at the global level
/// NOTE(review): presumably the strategy between nodes; confirm against the test runner
pub global_strategy: AllReduceStrategy,
/// All-reduce strategy used at the local level
/// NOTE(review): presumably the strategy between devices on this node; confirm against the test runner
pub local_strategy: AllReduceStrategy,
/// Input data for test: all tensors are D=3 (see [`TENSOR_RANK`])
pub inputs: Vec<TensorData>,
/// Expected output for test
pub expected: TensorData,
}
/// Result sent back from each node for each test
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTestResult {
/// Whether the test succeeded on the reporting node
/// NOTE(review): presumably true when the output matched the expected data; confirm against the test runner
pub success: bool,
/// Durations recorded for the test
/// NOTE(review): what each entry measures is not visible here; confirm against the test runner
pub durations: Vec<Duration>,
}