feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,82 @@
use std::path::PathBuf;
use burn_tensor::{Shape, Tensor, TensorData, backend::Backend};
use image::{DynamicImage, ImageBuffer, Luma, Rgb};
use burn_tensor::{Bool, Int};
#[cfg(all(
any(feature = "test-cpu", feature = "ndarray"),
not(any(feature = "test-wgpu", feature = "test-cuda"))
))]
pub type TestBackend = burn_ndarray::NdArray<f32, i32>;
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;
#[allow(unused)]
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
#[allow(unused)]
pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;
#[allow(unused)]
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
#[allow(missing_docs)]
#[macro_export]
macro_rules! as_type {
($ty:ident: [$($elem:tt),*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: [$($elem:tt,)*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: $elem:expr) => {
{
use cubecl::prelude::*;
$ty::new($elem)
}
};
}
#[allow(unused)]
pub fn test_image<B: Backend>(name: &str, device: &B::Device, luma: bool) -> Tensor<B, 3> {
let file = PathBuf::from("tests/images").join(name);
let image = image::open(file).unwrap();
if luma {
let image = image.to_luma32f();
let h = image.height() as usize;
let w = image.width() as usize;
let data = TensorData::new(image.into_vec(), Shape::new([h, w, 1]));
Tensor::from_data(data, device)
} else {
let image = image.to_rgb32f();
let h = image.height() as usize;
let w = image.width() as usize;
let data = TensorData::new(image.into_vec(), Shape::new([h, w, 3]));
Tensor::from_data(data, device)
}
}
#[allow(unused)]
pub fn save_test_image<B: Backend>(name: &str, tensor: Tensor<B, 3>, luma: bool) {
let file = PathBuf::from("tests/images").join(name);
let [h, w, _] = tensor.shape().dims();
let data = tensor
.into_data()
.convert::<f32>()
.into_vec::<f32>()
.unwrap();
if luma {
let image = ImageBuffer::<Luma<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();
DynamicImage::from(image).to_luma8().save(file).unwrap();
} else {
let image = ImageBuffer::<Rgb<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();
DynamicImage::from(image).to_rgb8().save(file).unwrap();
}
}

View File

@@ -0,0 +1,144 @@
#![allow(clippy::single_range_in_vec_init)]
use std::collections::HashMap;
use burn_tensor::TensorData;
use burn_vision::{ConnectedComponents, ConnectedStatsOptions, Connectivity};
mod common;
use common::*;
fn space_invader() -> [[IntType; 14]; 9] {
as_type!(IntType: [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1],
[1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
])
}
#[test]
fn should_support_8_connectivity() {
let tensor = TestTensorBool::<2>::from(space_invader());
let output = tensor.connected_components(Connectivity::Eight);
let expected = space_invader(); // All pixels are in the same group for 8-connected
let expected = TestTensorInt::<2>::from(expected);
normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);
}
#[test]
fn should_support_8_connectivity_with_stats() {
let tensor = TestTensorBool::<2>::from(space_invader());
let (output, stats) =
tensor.connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all());
let expected = space_invader(); // All pixels are in the same group for 8-connected
let expected = TestTensorInt::<2>::from(expected);
let (area, left, top, right, bottom) = (
stats.area.slice([1..2]).into_data(),
stats.left.slice([1..2]).into_data(),
stats.top.slice([1..2]).into_data(),
stats.right.slice([1..2]).into_data(),
stats.bottom.slice([1..2]).into_data(),
);
output.into_data().assert_eq(&expected.into_data(), false);
area.assert_eq(&TensorData::from([58]), false);
left.assert_eq(&TensorData::from([0]), false);
top.assert_eq(&TensorData::from([0]), false);
right.assert_eq(&TensorData::from([13]), false);
bottom.assert_eq(&TensorData::from([8]), false);
stats
.max_label
.into_data()
.assert_eq(&TensorData::from([1]), false);
}
#[test]
fn should_support_4_connectivity() {
let tensor = TestTensorBool::<2>::from(space_invader());
let output = tensor.connected_components(Connectivity::Four);
let expected = as_type!(IntType: [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],
[0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],
[0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],
[0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],
[4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],
[4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],
[4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],
[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
]);
let expected = TestTensorInt::<2>::from(expected);
normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);
}
#[test]
fn should_support_4_connectivity_with_stats() {
let tensor = TestTensorBool::<2>::from(space_invader());
let (output, stats) =
tensor.connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all());
let expected = as_type!(IntType: [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],
[0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],
[0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],
[0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],
[4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],
[4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],
[4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],
[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
]);
let expected = TestTensorInt::<2>::from(expected);
// Slice off background and limit to compacted labels
let (area, left, top, right, bottom) = (
stats.area.slice([1..6]).into_data(),
stats.left.slice([1..6]).into_data(),
stats.top.slice([1..6]).into_data(),
stats.right.slice([1..6]).into_data(),
stats.bottom.slice([1..6]).into_data(),
);
output.into_data().assert_eq(&expected.into_data(), false);
area.assert_eq(&TensorData::from([1, 1, 46, 5, 5]), false);
left.assert_eq(&TensorData::from([3, 10, 1, 0, 12]), false);
top.assert_eq(&TensorData::from([0, 0, 1, 5, 5]), false);
right.assert_eq(&TensorData::from([3, 10, 12, 1, 13]), false);
bottom.assert_eq(&TensorData::from([0, 0, 8, 7, 7]), false);
stats
.max_label
.into_data()
.assert_eq(&TensorData::from([5]), false);
}
/// Normalize labels to sequential since actual labels aren't required to be contiguous and
/// different algorithms can return different numbers even if correct
fn normalize_labels(mut labels: TensorData) -> TensorData {
let mut next_label = 0;
let mut mappings = HashMap::<i32, i32>::default();
let data = labels.as_mut_slice::<i32>().unwrap();
for label in data {
if *label != 0 {
let relabel = mappings.entry(*label).or_insert_with(|| {
next_label += 1;
next_label
});
*label = *relabel;
}
}
labels
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 396 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 370 B

View File

@@ -0,0 +1,638 @@
use burn_tensor::{Tolerance, ops::FloatElem};
use burn_vision::{
BorderType, KernelShape, MorphOptions, Morphology, Point, Size, create_structuring_element,
};
type FT = FloatElem<TestBackend>;
mod common;
use common::*;
#[test]
fn should_support_dilate_luma() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Rect.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Cross.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_ellipse() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Ellipse,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Ellipse.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_non_square_rect() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_3x5_Rect.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_non_square_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_3x5_Cross.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_2_3x5_Rect.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_2_3x5_Cross.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect101_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect101)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT101.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect101_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect101)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT101.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_replicate_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Replicate)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Rect_BORDER_REPLICATE.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_replicate_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Replicate)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Cross_BORDER_REPLICATE.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_anchor_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 7),
Some(Point::new(1, 2)),
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder().anchor(Point::new(2, 1)).build(),
);
let expected = test_image(
"morphology/Dilate_2_5x7_Rect_ANCHOR.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_anchor_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 7),
Some(Point::new(1, 2)),
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.anchor(Point::new(2, 1))
.border_type(BorderType::Replicate)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_5x7_Cross_ANCHOR_BORDER_REPLICATE.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_boolean_rect() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true).greater_elem(0);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Rect.png",
&Default::default(),
true,
)
.greater_elem(0);
let expected = TestTensorBool::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_boolean_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true).greater_elem(0);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Cross.png",
&Default::default(),
true,
)
.greater_elem(0);
let expected = TestTensorBool::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_int_rect() {
let tensor = (test_image("morphology/Base_1.png", &Default::default(), true) * 255.0).int();
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = (test_image(
"morphology/Dilate_1_5x5_Rect.png",
&Default::default(),
true,
) * 255.0)
.int();
let expected = TestTensorInt::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_int_cross() {
let tensor = (test_image("morphology/Base_1.png", &Default::default(), true) * 255.0).int();
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = (test_image(
"morphology/Dilate_1_5x5_Cross.png",
&Default::default(),
true,
) * 255.0)
.int();
let expected = TestTensorInt::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_erode_luma() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = TestTensorBool::<2>::from([
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
]);
let output = tensor.erode(kernel, MorphOptions::default());
let expected = test_image("morphology/Erode_1_5x5_Rect.png", &Default::default(), true);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_erode_luma_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.erode(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Erode_1_5x5_Cross.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_erode_luma_ellipse() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Ellipse,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.erode(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Erode_1_5x5_Ellipse.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn create_structuring_element_should_match_manual_rect() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
let kernel_manual = TestTensorBool::<2>::from([
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
]);
let output = tensor.clone().dilate(kernel, MorphOptions::default());
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
output
.into_data()
.assert_eq(&output_manual.into_data(), false);
}
#[test]
fn create_structuring_element_should_match_manual_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
let kernel_manual = TestTensorBool::<2>::from([
[false, false, true, false, false],
[false, false, true, false, false],
[true, true, true, true, true],
[false, false, true, false, false],
[false, false, true, false, false],
]);
let output = tensor.clone().dilate(kernel, MorphOptions::default());
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
output
.into_data()
.assert_eq(&output_manual.into_data(), false);
}
#[test]
fn create_structuring_element_should_match_manual_ellipse() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Ellipse,
Size::new(5, 5),
None,
&Default::default(),
);
let kernel_manual = TestTensorBool::<2>::from([
[false, false, true, false, false],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[false, false, true, false, false],
]);
let output = tensor.clone().dilate(kernel, MorphOptions::default());
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
output
.into_data()
.assert_eq(&output_manual.into_data(), false);
}

View File

@@ -0,0 +1,92 @@
use burn_vision::{Nms, NmsOptions};
mod common;
use common::*;
#[test]
fn should_suppress_non_maximum() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.5,
score_threshold: 0.0,
max_output_boxes: 0,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4, 2, 1]);
output.into_data().assert_eq(&expected.into_data(), true);
}
#[test]
fn should_apply_score_threshold() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.5,
score_threshold: 0.3,
max_output_boxes: 0,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4, 2]);
output.into_data().assert_eq(&expected.into_data(), true);
}
#[test]
fn should_apply_iou_threshold() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.1,
score_threshold: 0.0,
max_output_boxes: 0,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4, 1]);
output.into_data().assert_eq(&expected.into_data(), true);
}
#[test]
fn should_apply_max_output_boxes() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.5,
score_threshold: 0.0,
max_output_boxes: 1,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4]);
output.into_data().assert_eq(&expected.into_data(), true);
}