// Re-export use super::FloatElemType; // Default #[cfg(feature = "ndarray")] pub type TestBackend = burn_ndarray::NdArray; #[cfg(feature = "tch")] pub type TestBackend = burn_tch::LibTorch; #[cfg(feature = "cuda")] pub type TestBackend = burn_cuda::Cuda; #[cfg(feature = "rocm")] pub type TestBackend = burn_rocm::Rocm; #[cfg(feature = "wgpu")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "cpu")] pub type TestBackend = burn_cpu::Cpu; #[cfg(feature = "router")] pub type TestBackend = burn_router::BackendRouter< burn_router::DirectByteChannel<(burn_ndarray::NdArray, burn_wgpu::Wgpu)>, >; /// Collection of types used across tests #[allow(unused)] pub mod prelude { pub use burn_autodiff::Autodiff; pub use burn_tensor::Tensor; use super::*; pub type TestTensor = Tensor; pub type TestTensorInt = Tensor; pub type TestTensorBool = Tensor; pub type FloatElem = burn_tensor::ops::FloatElem; pub type IntElem = burn_tensor::ops::IntElem; pub type TestAutodiffBackend = Autodiff; pub type TestAutodiffTensor = Tensor; } #[allow(unused)] pub use prelude::*;