- Updated the `stablediffusion` crate path from `../stable-diffusion-burn` to `./crates/stable-diffusion-burn` so it resolves correctly within the workspace.
- Extended `.gitignore` to cover generated model files (`.mpk`, `.pt`, `.bin`, `.safetensors`, `.ckpt`) and the `user_data` directory.
- Added `Cargo.lock` to `.gitignore`, with a comment explaining why.
- Reorganized the IDE-files section of `.gitignore` for clarity.
- Added a trailing newline at the end of the file.
43 lines
1.2 KiB
Rust
43 lines
1.2 KiB
Rust
use super::*;
|
|
use burn_tensor::Tolerance;
|
|
use burn_tensor::{Distribution, Tensor, backend::Backend};
|
|
|
|
/// Concatenating two `[6, 256]` tensors along dimension 0 must produce the
/// same values as the reference backend.
#[test]
fn cat_should_match_reference_backend_dim0() {
    let shape = [6, 256];
    let (num_tensors, dim) = (2, 0);
    test_same_as_reference(shape, num_tensors, dim);
}
|
|
|
|
/// Concatenating two `[6, 256]` tensors along dimension 1 must produce the
/// same values as the reference backend.
#[test]
fn cat_should_match_reference_backend_dim1() {
    let shape = [6, 256];
    let (num_tensors, dim) = (2, 1);
    test_same_as_reference(shape, num_tensors, dim);
}
|
|
|
|
/// Concatenating two `[1, 137]` tensors along dimension 0 must produce the
/// same values as the reference backend.
///
/// NOTE(review): the odd width 137 is presumably chosen so the element count
/// does not divide evenly into the backend's kernel launch size — confirm.
#[test]
fn cat_should_support_uneven_launch() {
    let shape = [1, 137];
    let (num_tensors, dim) = (2, 0);
    test_same_as_reference(shape, num_tensors, dim);
}
|
|
|
|
fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) {
|
|
let device = Default::default();
|
|
TestBackend::seed(&device, 0);
|
|
|
|
let tensors = (0..num_tensors)
|
|
.map(|_| {
|
|
Tensor::<TestBackend, 2>::random(shape, Distribution::Default, &Default::default())
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let tensors_ref = tensors
|
|
.iter()
|
|
.map(|tensor| {
|
|
Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default())
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let tensor = Tensor::<TestBackend, 2>::cat(tensors, dim);
|
|
let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(tensors_ref, dim);
|
|
|
|
tensor
|
|
.into_data()
|
|
.assert_approx_eq::<FloatElem>(&tensor_ref.into_data(), Tolerance::default());
|
|
}
|