feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,71 @@
[package]
authors = [
"nathanielsimard <nathaniel.simard.42@gmail.com>",
"wingertge <wingertge@gmail.com>",
]
categories = ["science"]
description = "Vision processing operations for burn tensors"
documentation = "https://docs.rs/burn-vision"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu"]
license.workspace = true
name = "burn-vision"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-vision"
version.workspace = true
[lints]
workspace = true
[features]
default = ["ndarray", "cubecl-backend", "fusion", "std"]
std = ["aligned-vec/std"]
tracing = [
"burn-cubecl?/tracing",
"burn-fusion?/tracing",
"burn-ir/tracing",
"burn-ndarray?/tracing",
"burn-tch?/tracing",
"burn-tensor/tracing",
"cubecl/tracing",
]
cubecl-backend = ["cubecl", "burn-cubecl"]
fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
# Test features
test-cpu = []
test-cuda = ["cubecl-backend", ]
test-wgpu = ["cubecl-backend", ]
test-vulkan = ["burn-wgpu/vulkan", "test-wgpu"]
test-metal = ["burn-wgpu/metal", "test-wgpu"]
[dependencies]
aligned-vec = { version = "0.6", default-features = false }
bon = { workspace = true }
burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true }
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2" }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2" }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2", optional = true }
bytemuck = { workspace = true }
cubecl = { workspace = true, optional = true }
derive-new = { workspace = true }
half = { workspace = true }
image = { version = "0.25" }
macerator = { workspace = true }
ndarray = { workspace = true }
num-traits = { workspace = true }
paste = { workspace = true }
serde = { workspace = true }
[dev-dependencies]
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", default-features = false }
cubecl = { workspace = true }

View File

@@ -0,0 +1,42 @@
pub trait MinMax {
fn min(self, other: Self) -> Self;
fn max(self, other: Self) -> Self;
}
macro_rules! impl_minmax {
($ty: ty) => {
impl MinMax for $ty {
fn min(self, other: Self) -> Self {
Ord::min(self, other)
}
fn max(self, other: Self) -> Self {
Ord::max(self, other)
}
}
};
($($ty: ty),*) => {
$(impl_minmax!($ty);)*
}
}
impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64);
impl MinMax for f32 {
fn min(self, other: Self) -> Self {
self.min(other)
}
fn max(self, other: Self) -> Self {
self.max(other)
}
}
impl MinMax for f64 {
fn min(self, other: Self) -> Self {
self.min(other)
}
fn max(self, other: Self) -> Self {
self.max(other)
}
}

View File

@@ -0,0 +1,279 @@
use std::{cmp::Ordering, marker::PhantomData};
use alloc::vec::Vec;
use burn_tensor::{
Bool, DType, Element, ElementConversion, ElementOrdered, Int, Shape, Tensor, TensorData,
backend::Backend,
ops::{BoolTensor, IntTensor},
};
use ndarray::Array2;
use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity};
mod spaghetti;
mod spaghetti_4c;
/// Dispatches connected components based on `B::IntElem::dtype()`, binding a concrete
/// integer type to enable generic instantiations without extra trait bounds (after removing
/// `ElementComparison` from `Element`).
macro_rules! dispatch_int_dtype {
(|$ty:ident| $body:expr) => {
match B::IntElem::dtype() {
DType::I64 => {
type $ty = i64;
$body
}
DType::I32 => {
type $ty = i32;
$body
}
DType::I16 => {
type $ty = i16;
$body
}
DType::I8 => {
type $ty = i8;
$body
}
DType::U64 => {
type $ty = u64;
$body
}
DType::U32 => {
type $ty = u32;
$body
}
DType::U16 => {
type $ty = u16;
$body
}
DType::U8 => {
type $ty = u8;
$body
}
_ => unreachable!("Unsupported dtype"),
}
};
}
pub fn connected_components<B: Backend>(
img: BoolTensor<B>,
connectivity: Connectivity,
) -> IntTensor<B> {
dispatch_int_dtype!(|I| run::<B, I, NoOp<_>>(img, connectivity, NoOp::default).0)
}
pub fn connected_components_with_stats<B: Backend>(
img: BoolTensor<B>,
connectivity: Connectivity,
_options: ConnectedStatsOptions,
) -> (IntTensor<B>, ConnectedStatsPrimitive<B>) {
let device = B::bool_device(&img);
dispatch_int_dtype!(|I| {
let (labels, stats) =
run::<B, I, ConnectedStatsOp<I>>(img, connectivity, ConnectedStatsOp::default);
let stats = finalize_stats(&device, stats);
(labels, stats)
})
}
fn run<B: Backend, I: ElementOrdered, Stats: StatsOp<Label = I>>(
img: BoolTensor<B>,
connectivity: Connectivity,
stats: impl Fn() -> Stats,
) -> (IntTensor<B>, Stats) {
let device = B::bool_device(&img);
let img = Tensor::<B, 2, Bool>::from_primitive(img);
let [height, width] = img.shape().dims();
let img = img.into_data();
let img = img.into_vec::<B::BoolElem>().unwrap();
let mut stats = stats();
let out = match connectivity {
Connectivity::Four => {
spaghetti_4c::process::<B::BoolElem, UnionFind<_>>(img, height, width, &mut stats)
}
Connectivity::Eight => {
// SAFETY: This is validated by `TensorData`
let img = unsafe { Array2::from_shape_vec_unchecked((height, width), img) };
spaghetti::process::<B::BoolElem, UnionFind<_>>(img, &mut stats)
}
};
let (data, _) = out.into_raw_vec_and_offset();
let data = TensorData::new(data, Shape::new([height, width]));
let labels = Tensor::<B, 2, Int>::from_data(data, &device).into_primitive();
(labels, stats)
}
pub trait Solver {
type Label: ElementOrdered;
fn init(max_labels: usize) -> Self;
/// Hack to get around mutable borrow limitations on methods
fn merge(label_1: Self::Label, label_2: Self::Label, solver: &mut Self) -> Self::Label;
fn new_label(&mut self) -> Self::Label;
fn flatten(&mut self) -> Self::Label;
fn get_label(&self, i_label: Self::Label) -> Self::Label;
}
pub(crate) struct UnionFind<I: Element> {
labels: Vec<I>,
}
impl<I: ElementOrdered> Solver for UnionFind<I> {
type Label = I;
fn init(max_labels: usize) -> Self {
let mut labels = Vec::with_capacity(max_labels);
labels.push(0.elem());
Self { labels }
}
fn merge(mut label_1: I, mut label_2: I, solver: &mut Self) -> I {
use Ordering::Less;
while matches!(solver.labels[label_1.to_usize()].cmp(&label_1), Less) {
label_1 = solver.labels[label_1.to_usize()];
}
while matches!(solver.labels[label_2.to_usize()].cmp(&label_2), Less) {
label_2 = solver.labels[label_2.to_usize()];
}
if matches!(label_1.cmp(&label_2), Less) {
solver.labels[label_2.to_usize()] = label_1;
label_1
} else {
solver.labels[label_1.to_usize()] = label_2;
label_2
}
}
fn new_label(&mut self) -> I {
let len = I::from_elem(self.labels.len());
self.labels.push(len);
len
}
fn flatten(&mut self) -> I {
let mut k = 1;
for i in 1..self.labels.len() {
if matches!(self.labels[i].cmp(&I::from_elem(i)), Ordering::Less) {
self.labels[i] = self.labels[self.labels[i].to_usize()];
} else {
self.labels[i] = k.elem();
k += 1;
}
}
k.elem()
}
fn get_label(&self, i_label: I) -> I {
self.labels[i_label.to_usize()]
}
}
pub trait StatsOp {
type Label;
fn init(&mut self, num_labels: usize);
fn update(&mut self, row: usize, column: usize, label: Self::Label);
fn finish(&mut self);
}
#[derive(Default)]
struct NoOp<I: Element> {
_i: PhantomData<I>,
}
impl<I: Element> StatsOp for NoOp<I> {
type Label = I; // placeholder still required
fn init(&mut self, _num_labels: usize) {}
fn update(&mut self, _row: usize, _column: usize, _label: Self::Label) {}
fn finish(&mut self) {}
}
#[derive(Default, Debug)]
struct ConnectedStatsOp<I: Element> {
pub area: Vec<I>,
pub left: Vec<I>,
pub top: Vec<I>,
pub right: Vec<I>,
pub bottom: Vec<I>,
}
impl<I: Element> StatsOp for ConnectedStatsOp<I> {
type Label = I;
fn init(&mut self, num_labels: usize) {
self.area = vec![0.elem(); num_labels];
self.left = vec![I::MAX; num_labels];
self.top = vec![I::MAX; num_labels];
self.right = vec![0.elem(); num_labels];
self.bottom = vec![0.elem(); num_labels];
}
fn update(&mut self, row: usize, column: usize, label: I) {
let l = label.to_usize();
unsafe {
*self.area.get_unchecked_mut(l) =
I::from_elem((*self.area.get_unchecked(l)).to_usize() + 1);
*self.left.get_unchecked_mut(l) =
I::from_elem((*self.left.get_unchecked(l)).to_usize().min(column));
*self.top.get_unchecked_mut(l) =
I::from_elem((*self.top.get_unchecked(l)).to_usize().min(row));
*self.right.get_unchecked_mut(l) =
I::from_elem((*self.right.get_unchecked(l)).to_usize().max(column));
*self.bottom.get_unchecked_mut(l) =
I::from_elem((*self.bottom.get_unchecked(l)).to_usize().max(row));
}
}
fn finish(&mut self) {
// Background shouldn't have stats
self.area[0] = 0.elem();
self.left[0] = 0.elem();
self.right[0] = 0.elem();
self.top[0] = 0.elem();
self.bottom[0] = 0.elem();
}
}
fn finalize_stats<B: Backend, I: Element>(
device: &B::Device,
stats: ConnectedStatsOp<I>,
) -> ConnectedStatsPrimitive<B> {
let labels = stats.area.len();
let into_prim = |data: Vec<I>| {
let data = TensorData::new(data, Shape::new([labels]));
Tensor::<B, 1, Int>::from_data(data, device).into_primitive()
};
let max_label = {
let data = TensorData::new(vec![I::from_elem(labels - 1)], Shape::new([1]));
Tensor::<B, 1, Int>::from_data(data, device).into_primitive()
};
ConnectedStatsPrimitive {
area: into_prim(stats.area),
left: into_prim(stats.left),
top: into_prim(stats.top),
right: into_prim(stats.right),
bottom: into_prim(stats.bottom),
max_label,
}
}
pub fn max_labels(h: usize, w: usize, conn: Connectivity) -> usize {
match conn {
Connectivity::Four => (h * w).div_ceil(2) + 1,
Connectivity::Eight => h.div_ceil(2) * w.div_ceil(2) + 1,
}
}

View File

@@ -0,0 +1,223 @@
no_analyze!{{
use firstLabels::*;let mut label = entry;
while let Some(next) = (|label| -> Option<firstLabels> { match label {
NODE_72=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_2);
}
}
NODE_73=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(fl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(fl_tree_2);
}
}
NODE_74=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_1);
}
else {
if (*img_row01.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(fl_tree_0);
}
}
}
fl_tree_0 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_0); } else { return Some(fl_break_1_0); } }
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_72);
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
return Some(NODE_72);
}
else {
return Some(NODE_74);
}
}
}
fl_tree_1 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_1); } else { return Some(fl_break_1_1); } }
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_73);
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
return Some(NODE_73);
}
else {
return Some(NODE_74);
}
}
}
fl_tree_2 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_2); } else { return Some(fl_break_1_2); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row01.add((c - 1) as usize)).to_bool() {
return Some(NODE_73);
}
else {
return Some(NODE_72);
}
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row01.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(fl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_1);
}
}
else {
if (*img_row01.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(fl_tree_2);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_2);
}
}
}
else {
return Some(NODE_74);
}
}
}
NODE_75=> {
if (*img_row01.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
fl_break_0_0 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
return None;}
fl_break_0_1 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
return None;}
fl_break_0_2 => {
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_75);
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
return Some(NODE_75);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
return None;}
NODE_76=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
if (*img_row01.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
}
NODE_77=> {
if (*img_row01.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
fl_break_1_0 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
return Some(NODE_76);
}
}
return None;}
fl_break_1_1 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
return Some(NODE_76);
}
}
return None;}
fl_break_1_2 => {
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_77);
}
else {
if (*img_row01.add((c) as usize)).to_bool() {
if (*img_row00.add((c + 1) as usize)).to_bool() {
return Some(NODE_77);
}
else {
return Some(NODE_77);
}
}
else {
return Some(NODE_76);
}
}
return None;}
fl_ => {},
}; None})(label)
{
label = next;
}
}}

View File

@@ -0,0 +1,191 @@
/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`.
macro_rules! no_analyze {
($tokens:tt) => {
$tokens
};
}
pub(crate) use no_analyze;
#[allow(non_snake_case, non_camel_case_types, unused)]
pub enum centerLabels {
NODE_1,
NODE_2,
NODE_3,
NODE_4,
NODE_5,
NODE_6,
NODE_7,
NODE_8,
NODE_9,
NODE_10,
NODE_11,
NODE_12,
NODE_13,
NODE_14,
NODE_15,
NODE_16,
NODE_17,
NODE_18,
NODE_19,
NODE_20,
NODE_21,
NODE_22,
NODE_23,
NODE_24,
NODE_25,
NODE_26,
NODE_27,
NODE_28,
NODE_29,
NODE_30,
NODE_31,
NODE_32,
NODE_33,
NODE_34,
NODE_35,
NODE_36,
NODE_37,
NODE_38,
NODE_39,
NODE_40,
NODE_41,
NODE_42,
NODE_43,
NODE_44,
NODE_45,
NODE_46,
NODE_47,
NODE_48,
NODE_49,
NODE_50,
NODE_51,
NODE_52,
NODE_53,
NODE_54,
NODE_55,
NODE_56,
NODE_57,
NODE_58,
NODE_59,
NODE_60,
NODE_61,
NODE_62,
NODE_63,
NODE_64,
NODE_65,
NODE_66,
NODE_67,
NODE_68,
NODE_69,
NODE_70,
NODE_71,
cl_tree_0,
cl_tree_1,
cl_tree_2,
cl_tree_3,
cl_tree_4,
cl_tree_5,
cl_tree_6,
cl_tree_7,
cl_tree_8,
cl_tree_9,
cl_tree_10,
cl_tree_11,
cl_tree_12,
cl_break_0_0,
cl_break_0_1,
cl_break_0_2,
cl_break_0_3,
cl_break_0_4,
cl_break_0_5,
cl_break_0_6,
cl_break_0_7,
cl_break_0_8,
cl_break_1_0,
cl_break_1_1,
cl_break_1_2,
cl_break_1_3,
cl_break_1_4,
cl_break_1_5,
cl_break_1_6,
cl_break_1_7,
cl_break_1_8,
cl_break_1_9,
cl_break_1_10,
cl_break_1_11,
cl_break_1_12,
}
#[allow(non_snake_case, non_camel_case_types, unused)]
pub enum firstLabels {
NODE_72,
NODE_73,
NODE_74,
NODE_75,
NODE_76,
NODE_77,
fl_tree_0,
fl_tree_1,
fl_tree_2,
fl_break_0_0,
fl_break_0_1,
fl_break_0_2,
fl_break_1_0,
fl_break_1_1,
fl_break_1_2,
fl_,
}
#[allow(non_snake_case, non_camel_case_types, unused)]
pub enum lastLabels {
NODE_78,
NODE_79,
NODE_80,
NODE_81,
NODE_82,
NODE_83,
NODE_84,
NODE_85,
NODE_86,
NODE_87,
NODE_88,
NODE_89,
NODE_90,
NODE_91,
NODE_92,
ll_tree_0,
ll_tree_1,
ll_tree_2,
ll_tree_3,
ll_tree_4,
ll_tree_5,
ll_tree_6,
ll_tree_7,
ll_break_0_0,
ll_break_0_1,
ll_break_0_2,
ll_break_0_3,
ll_break_1_0,
ll_break_1_1,
ll_break_1_2,
ll_break_1_3,
ll_break_1_4,
ll_break_1_5,
ll_break_1_6,
ll_break_1_7,
ll_,
}
#[allow(non_snake_case, non_camel_case_types, unused)]
pub enum singleLabels {
NODE_93,
NODE_94,
sl_tree_0,
sl_tree_1,
sl_break_0_0,
sl_break_0_1,
sl_break_1_0,
sl_break_1_1,
sl_,
}

View File

@@ -0,0 +1,787 @@
no_analyze!{{
use lastLabels::*;let mut label = entry;
while let Some(next) = (|label| -> Option<lastLabels> { match label {
NODE_78=> {
if (*img_row12.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);
return Some(ll_tree_4);
}
}
NODE_79=> {
if (*img_row12.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
return Some(ll_tree_6);
}
}
NODE_80=> {
if (*img_row12.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
return Some(ll_tree_6);
}
}
NODE_81=> {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
return Some(NODE_82);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_3);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(ll_tree_2);
}
}
}
NODE_83=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_5);
}
else {
if (*img_row11.add((c + 2) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(ll_tree_2);
}
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(ll_tree_1);
}
}
NODE_84=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_5);
}
else {
return Some(NODE_81);
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(ll_tree_1);
}
}
NODE_82=> {
if (*img_row12.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver);
return Some(ll_tree_4);
}
}
NODE_85=> {
if (*img_row12.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);
return Some(ll_tree_4);
}
}
NODE_86=> {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
if (*img_row12.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
return Some(ll_tree_6);
}
}
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row12.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
if (*img_row12.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_7);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_0);
}
}
}
ll_tree_0 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_0); } else { return Some(ll_break_1_0); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
return Some(NODE_81);
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_0);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(ll_tree_0);
}
}
}
}
else {
return Some(NODE_84);
}
}
ll_tree_1 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_1); } else { return Some(ll_break_1_1); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
return Some(NODE_80);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
}
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
return Some(NODE_82);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
return Some(NODE_85);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_3);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
return Some(ll_tree_2);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(ll_tree_2);
}
}
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_0);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
return Some(ll_tree_0);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(ll_tree_0);
}
}
}
}
}
else {
return Some(NODE_84);
}
}
ll_tree_2 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_2); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
return Some(ll_tree_6);
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_7);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_0);
}
}
}
else {
return Some(NODE_83);
}
}
ll_tree_3 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_3); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
return Some(NODE_79);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
return Some(ll_tree_6);
}
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row12.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
return Some(NODE_78);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_7);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_0);
}
}
}
else {
return Some(NODE_83);
}
}
ll_tree_4 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_4); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row12.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
return Some(ll_tree_4);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_7);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_0);
}
}
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_5);
}
else {
if (*img_row11.add((c + 2) as usize)).to_bool() {
return Some(NODE_82);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_3);
}
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(ll_tree_1);
}
}
}
ll_tree_5 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_5); } }
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_86);
}
else {
return Some(NODE_84);
}
}
ll_tree_6 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_3); } else { return Some(ll_break_1_6); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row00.add((c - 1) as usize)).to_bool() {
return Some(NODE_86);
}
else {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_6);
}
else {
return Some(NODE_80);
}
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
return Some(NODE_82);
}
else {
return Some(NODE_85);
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_3);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
return Some(ll_tree_2);
}
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
return Some(ll_tree_0);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
return Some(ll_tree_0);
}
}
}
}
}
else {
return Some(NODE_84);
}
}
ll_tree_7 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_7); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
if (*img_row11.add((c - 2) as usize)).to_bool() {
return Some(NODE_79);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
return Some(ll_tree_6);
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
return Some(ll_tree_6);
}
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 2) as usize)).to_bool() {
if (*img_row12.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
if (*img_row11.add((c - 2) as usize)).to_bool() {
return Some(NODE_78);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
return Some(ll_tree_4);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_7);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(ll_tree_0);
}
}
}
else {
return Some(NODE_83);
}
}
ll_break_0_0 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
return None;}
ll_break_0_1 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
return None;}
ll_break_0_2 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
return None;}
ll_break_0_3 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row00.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
}
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
return None;}
NODE_87=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
return Some(NODE_88);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
NODE_88=> {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
}
NODE_89=> {
if (*img_row12.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
}
}
NODE_90=> {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
if (*img_row12.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
}
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
}
NODE_91=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
NODE_92=> {
if (*img_row12.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
}
}
ll_break_1_0 => {
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_88);
}
else {
return Some(NODE_87);
}
return None;}
ll_break_1_1 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
return Some(NODE_92);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
if (*img_row11.add((c - 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
}
}
}
}
else {
return Some(NODE_87);
}
return None;}
ll_break_1_2 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
}
else {
return Some(NODE_91);
}
return None;}
ll_break_1_3 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
return Some(NODE_89);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
}
else {
return Some(NODE_91);
}
return None;}
ll_break_1_4 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
return None;}
ll_break_1_5 => {
if (*img_row00.add((c) as usize)).to_bool() {
return Some(NODE_90);
}
else {
return Some(NODE_87);
}
return None;}
ll_break_1_6 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row00.add((c - 1) as usize)).to_bool() {
return Some(NODE_90);
}
else {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
return Some(NODE_92);
}
}
else {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
}
}
}
}
else {
return Some(NODE_87);
}
return None;}
ll_break_1_7 => {
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c + 1) as usize)).to_bool() {
if (*img_row12.add((c) as usize)).to_bool() {
if (*img_row11.add((c - 2) as usize)).to_bool() {
return Some(NODE_89);
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
}
}
else {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
}
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
}
else {
return Some(NODE_91);
}
return None;}
ll_ => {},
}; None})(label)
{
label = next;
}
}}

View File

@@ -0,0 +1,91 @@
no_analyze!{{
use singleLabels::*;let mut label = entry;
while let Some(next) = (|label| -> Option<singleLabels> { match label {
NODE_93=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(sl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(sl_tree_0);
}
}
sl_tree_0 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_0); } else { return Some(sl_break_1_0); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(sl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(sl_tree_0);
}
}
else {
return Some(NODE_93);
}
}
sl_tree_1 => {
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_1); } else { return Some(sl_break_1_1); } }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(sl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
return Some(sl_tree_0);
}
}
else {
return Some(NODE_93);
}
}
sl_break_0_0 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
return None;}
sl_break_0_1 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
return None;}
NODE_94=> {
if (*img_row00.add((c + 1) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
}
}
sl_break_1_0 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
}
else {
return Some(NODE_94);
}
return None;}
sl_break_1_1 => {
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
}
else {
return Some(NODE_94);
}
return None;}
sl_ => {},
}; None})(label)
{
label = next;
}
}}

View File

@@ -0,0 +1,273 @@
//! Spaghetti algorithm for connected component labeling
//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana,
//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling,"
//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019.
//!
//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN)
//! as described in
//!
//! F. Bolelli, S. Allegretti, C. Grana.
//! "One DAG to Rule Them All."
//! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021
#![allow(
unreachable_code,
clippy::collapsible_else_if,
clippy::if_same_then_else
)]
use std::cmp::Ordering;
use burn_tensor::{Element, ElementComparison, ElementConversion, cast::ToElement};
use ndarray::{Array2, Axis, s};
#[allow(non_snake_case)]
mod Spaghetti_forest_labels;
pub(crate) use Spaghetti_forest_labels::*;
use crate::Connectivity;
use super::{Solver, StatsOp, max_labels};
pub fn process<B: Element, LabelsSolver: Solver>(
img_arr: Array2<B>,
stats: &mut impl StatsOp<Label = LabelsSolver::Label>,
) -> Array2<LabelsSolver::Label> {
let (h, w) = img_arr.dim();
let mut img_labels_arr = Array2::<LabelsSolver::Label>::default(img_arr.raw_dim());
let img = img_arr.as_ptr();
let e_rows = h as u32 & 0xfffffffe;
let o_rows = h % 2 == 1;
let e_cols = w as u32 & 0xfffffffe;
let o_cols = w % 2 == 1;
let img_labels = img_labels_arr.as_mut_ptr();
let mut solver = LabelsSolver::init(max_labels(h, w, Connectivity::Eight));
let solver = &mut solver;
let w = w as i32;
// SAFETY:
// Generated code includes mathematically proven bounds checks, so raw pointers are a safe speed
// boost.
unsafe {
if h == 1 {
// Single line
let r = 0;
//Pointers:
// Row pointers for the input image
let img_row00 = img.add(r * w as usize);
// Row pointers for the output image
let img_labels_row00 = img_labels.add(r * w as usize);
let mut c = -2i32;
let entry = singleLabels::sl_tree_0;
include!("Spaghetti_single_line_forest_code.rs");
} else {
// More than one line
// First couple of lines
{
let r = 0;
//Pointers:
// Row pointers for the input image
let img_row00 = img.add(r * w as usize);
let img_row01 = img.add((r + 1) * w as usize);
// Row pointers for the output image
let img_labels_row00 = img_labels.add(r * w as usize);
let mut c = -2i32;
let entry = firstLabels::fl_tree_0;
include!("Spaghetti_first_line_forest_code.rs");
}
// Every other line but the last one if image has an odd number of rows
for r in (2..e_rows as usize).step_by(2) {
//Pointers:
// Row pointers for the input image
let img_row00 = img.add(r * w as usize);
let img_row12 = img.add((r - 2) * w as usize);
let img_row11 = img.add((r - 1) * w as usize);
let img_row01 = img.add((r + 1) * w as usize);
// Row pointers for the output image
let img_labels_row00 = img_labels.add(r * w as usize);
let img_labels_row12 = img_labels.add((r - 2) * w as usize);
let mut c = -2;
let entry = centerLabels::cl_tree_0;
include!("Spaghetti_center_line_forest_code.rs");
}
if o_rows {
let r = h - 1;
//Pointers:
// Row pointers for the input image
let img_row00 = img.add(r * w as usize);
let img_row12 = img.add((r - 2) * w as usize);
let img_row11 = img.add((r - 1) * w as usize);
// Row pointers for the output image
let img_labels_row00 = img_labels.add(r * w as usize);
let img_labels_row12 = img_labels.add((r - 2) * w as usize);
let mut c = -2;
let entry = lastLabels::ll_tree_0;
include!("Spaghetti_last_line_forest_code.rs");
}
}
}
let n_labels = solver.flatten();
stats.init(n_labels.to_usize());
let img = img_arr;
let mut img_labels = img_labels_arr;
for r in (0..e_rows as usize).step_by(2) {
//Pointers:
// Row pointers for the input image
let img_row00 = img.index_axis(Axis(0), r);
let img_row01 = img.index_axis(Axis(0), r + 1);
// Row pointers for the output image
let (mut img_labels_row00, mut img_labels_row01) =
img_labels.multi_slice_mut((s![r, ..], s![r + 1, ..]));
for c in (0..e_cols as usize).step_by(2) {
let mut i_label = img_labels_row00[c];
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
i_label = solver.get_label(i_label);
if img_row00[c].to_u8() > 0 {
img_labels_row00[c] = i_label;
stats.update(r, c, i_label);
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
}
if img_row00[c + 1].to_u8() > 0 {
img_labels_row00[c + 1] = i_label;
stats.update(r, c + 1, i_label);
} else {
img_labels_row00[c + 1] = 0.elem();
stats.update(r, c + 1, 0.elem());
}
if img_row01[c].to_u8() > 0 {
img_labels_row01[c] = i_label;
stats.update(r + 1, c, i_label);
} else {
img_labels_row01[c] = 0.elem();
stats.update(r + 1, c, 0.elem());
}
if img_row01[c + 1].to_u8() > 0 {
img_labels_row01[c + 1] = i_label;
stats.update(r + 1, c + 1, i_label);
} else {
img_labels_row01[c + 1] = 0.elem();
stats.update(r + 1, c + 1, 0.elem());
}
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
img_labels_row00[c + 1] = 0.elem();
stats.update(r, c + 1, 0.elem());
img_labels_row01[c] = 0.elem();
stats.update(r + 1, c, 0.elem());
img_labels_row01[c + 1] = 0.elem();
stats.update(r + 1, c + 1, 0.elem());
}
}
if o_cols {
let c = e_cols as usize;
let mut i_label = img_labels_row00[c];
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
i_label = solver.get_label(i_label);
if img_row00[c].to_u8() > 0 {
img_labels_row00[c] = i_label;
stats.update(r, c, i_label);
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
}
if img_row01[c].to_u8() > 0 {
img_labels_row01[c] = i_label;
stats.update(r + 1, c, i_label);
} else {
img_labels_row01[c] = 0.elem();
stats.update(r + 1, c, 0.elem());
}
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
img_labels_row01[c] = 0.elem();
stats.update(r + 1, c, 0.elem());
}
}
}
if o_rows {
let r = e_rows as usize;
// Row pointers for the input image
let img_row00 = img.index_axis(Axis(0), r);
// Row pointers for the output image
let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]);
for c in (0..e_cols as usize).step_by(2) {
let mut i_label = img_labels_row00[c];
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
i_label = solver.get_label(i_label);
if img_row00[c].to_u8() > 0 {
img_labels_row00[c] = i_label;
stats.update(r, c, i_label);
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
}
if img_row00[c + 1].to_u8() > 0 {
img_labels_row00[c + 1] = i_label;
stats.update(r, c + 1, i_label);
} else {
img_labels_row00[c + 1] = 0.elem();
stats.update(r, c + 1, 0.elem());
}
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
img_labels_row00[c + 1] = 0.elem();
stats.update(r, c + 1, 0.elem());
}
}
if o_cols {
let c = e_cols as usize;
let mut i_label = img_labels_row00[c];
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
i_label = solver.get_label(i_label);
if img_row00[c].to_u8() > 0 {
img_labels_row00[c] = i_label;
stats.update(r, c, i_label);
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, 0.elem());
}
} else {
img_labels_row00[c] = 0.elem();
stats.update(r, c, i_label);
}
}
}
stats.finish();
img_labels
}

View File

@@ -0,0 +1,42 @@
no_analyze!{{
use centerLabels::*;let mut label = entry;
while let Some(next) = (|label| -> Option<centerLabels> { match label {
cl_tree_0 => {
if ({c+=1; c} >= w) { return None; }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row11.add((c) as usize);
return Some(cl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(cl_tree_1);
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(cl_tree_0);
}
}
cl_tree_1 => {
if ({c+=1; c} >= w) { return None; }
if (*img_row00.add((c) as usize)).to_bool() {
if (*img_row11.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 1) as usize), *img_labels_row11.add((c) as usize), solver);
return Some(cl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 1) as usize);
return Some(cl_tree_1);
}
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(cl_tree_0);
}
}
}; None})(label)
{
label = next;
}
}}

View File

@@ -0,0 +1,31 @@
no_analyze!{{
use firstLabels::*;let mut label = entry;
while let Some(next) = (|label| -> Option<firstLabels> { match label {
fl_tree_0 => {
if ({c+=1; c} >= w) { return None; }
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = solver.new_label();
return Some(fl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(fl_tree_0);
}
}
fl_tree_1 => {
if ({c+=1; c} >= w) { return None; }
if (*img_row00.add((c) as usize)).to_bool() {
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 1) as usize);
return Some(fl_tree_1);
}
else {
*img_labels_row00.add(c as usize) = 0.elem();
return Some(fl_tree_0);
}
}
fl_ => {},
}; None})(label)
{
label = next;
}
}}

View File

@@ -0,0 +1,21 @@
/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`.
macro_rules! no_analyze {
($tokens:tt) => {
$tokens
};
}
pub(crate) use no_analyze;
#[allow(non_snake_case, non_camel_case_types, unused)]
pub enum centerLabels {
cl_tree_0,
cl_tree_1,
}
#[allow(non_snake_case, non_camel_case_types, unused)]
pub enum firstLabels {
fl_tree_0,
fl_tree_1,
fl_,
}

View File

@@ -0,0 +1,102 @@
//! Spaghetti algorithm for connected component labeling, modified for 4-connectivity using the
//! 4-connected Rosenfeld mask.
//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana,
//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling,"
//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019.
//!
//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN)
//! as described in
//!
//! F. Bolelli, S. Allegretti, C. Grana.
//! "One DAG to Rule Them All."
//! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021
#![allow(unreachable_code)]
use burn_tensor::{Element, ElementConversion, cast::ToElement};
use ndarray::Array2;
use crate::Connectivity;
use super::{Solver, StatsOp, max_labels};
#[allow(non_snake_case)]
mod Spaghetti4C_forest_labels;
pub(crate) use Spaghetti4C_forest_labels::*;
pub fn process<B: Element, LabelsSolver: Solver>(
img: Vec<B>,
h: usize,
w: usize,
stats: &mut impl StatsOp<Label = LabelsSolver::Label>,
) -> Array2<LabelsSolver::Label> {
let img = img.as_ptr();
let mut img_labels: Vec<LabelsSolver::Label> = vec![0.elem(); h * w];
// A quick and dirty upper bound for the maximum number of labels.
// Following formula comes from the fact that a 2x2 block in 4-connectivity case
// can never have more than 2 new labels and 1 label for background.
// Worst case image example pattern:
// 1 0 1 0 1...
// 0 1 0 1 0...
// 1 0 1 0 1...
// ............
let max_labels = max_labels(h, w, Connectivity::Four);
let mut solver = LabelsSolver::init(max_labels);
let solver = &mut solver;
let w = w as i32;
// SAFETY:
// This code is generated from constraints and includes manual bounds checks, so unchecked pointer
// indexes are always safe.
unsafe {
// First row
{
let r = 0;
//Pointers:
// Row pointers for the input image
let img_row00 = img.add(r * w as usize);
// Row pointers for the output image
let img_labels_row00 = img_labels.as_mut_ptr().add(r * w as usize);
let mut c = -1i32;
let entry = firstLabels::fl_tree_0;
include!("Spaghetti4C_first_line_forest_code.rs");
}
for r in 1..h {
//Pointers:
// Row pointers for the input image
let img_row00 = img.add(r * w as usize);
let img_row11 = img.add((r - 1) * w as usize);
// Row pointers for the output image
let img_labels_row00 = img_labels.as_mut_ptr().add(r * w as usize);
let img_labels_row11 = img_labels.as_mut_ptr().add((r - 1) * w as usize);
let mut c = -1i32;
let entry = centerLabels::cl_tree_0;
include!("Spaghetti4C_center_line_forest_code.rs");
}
}
let n_labels = solver.flatten();
stats.init(n_labels.to_usize());
// SAFETY: This is always valid
let mut img_labels = unsafe { Array2::from_shape_vec_unchecked((h, w as usize), img_labels) };
img_labels.indexed_iter_mut().for_each(|((r, c), label)| {
*label = solver.get_label(*label);
stats.update(r, c, *label);
});
stats.finish();
img_labels
}

View File

@@ -0,0 +1,10 @@
mod base;
mod connected_components;
mod morphology;
mod nms;
mod ops;
pub use base::*;
pub use connected_components::*;
pub use morphology::*;
pub use nms::*;

View File

@@ -0,0 +1,720 @@
use core::slice;
use std::{marker::PhantomData, ptr::null};
use burn_tensor::Element;
use macerator::{
Scalar, Simd, VOrd, Vector, vload, vload_low, vload_unaligned, vstore, vstore_low,
vstore_unaligned,
};
use crate::{Point, Size, backends::cpu::MinMax};
pub trait MorphOperator<T> {
fn apply(a: T, b: T) -> T;
}
pub trait VecMorphOperator<T: Scalar> {
fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T>;
}
pub struct MinOp;
pub struct MaxOp;
impl<T: MinMax> MorphOperator<T> for MinOp {
fn apply(a: T, b: T) -> T {
MinMax::min(a, b)
}
}
impl<T: VOrd> VecMorphOperator<T> for MinOp {
fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T> {
T::vmin(a, b)
}
}
impl<T: MinMax> MorphOperator<T> for MaxOp {
fn apply(a: T, b: T) -> T {
MinMax::max(a, b)
}
}
impl<T: VOrd> VecMorphOperator<T> for MaxOp {
fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T> {
T::vmax(a, b)
}
}
pub struct MorphRowFilter<T: Scalar, S: MorphOperator<T>, Vec: VecRow<T>> {
pub ksize: usize,
pub anchor: usize,
vec: Vec,
_t: PhantomData<T>,
_scalar: PhantomData<S>,
}
impl<T: Scalar, SOp: MorphOperator<T>, Vec: VecRow<T>> MorphRowFilter<T, SOp, Vec> {
pub fn new(ksize: usize, anchor: usize) -> Self {
let vec = Vec::new(ksize, anchor);
Self {
ksize,
anchor,
vec,
_t: PhantomData,
_scalar: PhantomData,
}
}
pub fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, ch: usize) {
let k_size = self.ksize * ch;
if k_size == ch {
let width = width * ch;
dst[..width].copy_from_slice(&src[..width]);
return;
}
let i0 = self.vec.apply::<S>(src, dst, width, ch);
let width = width * ch;
for k in 0..ch {
let mut last_i = i0;
for i in (i0..width.saturating_sub(ch * 2)).step_by(ch * 2) {
let mut m = src[k + i + ch];
let mut last_j = ch * 2;
for j in (ch * 2..k_size).step_by(ch) {
m = SOp::apply(m, src[k + i + j]);
last_j = j + ch;
}
dst[k + i] = SOp::apply(m, src[k + i]);
dst[k + i + ch] = SOp::apply(m, src[k + i + last_j]);
last_i = i + ch * 2;
}
for i in (last_i..width).step_by(ch) {
let mut m = src[k + i];
for j in (ch..k_size).step_by(ch) {
m = SOp::apply(m, src[k + i + j]);
}
dst[k + i] = m;
}
}
}
}
pub struct MorphRowVec<T: Scalar, Op: VecMorphOperator<T>> {
k_size: usize,
_t: PhantomData<T>,
_op: PhantomData<Op>,
}
pub trait VecRow<T: Scalar> {
fn new(ksize: usize, anchor: usize) -> Self;
fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, channels: usize) -> usize;
}
impl<T: Scalar, Op: VecMorphOperator<T>> VecRow<T> for MorphRowVec<T, Op> {
fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, ch: usize) -> usize {
let src = src.as_ptr();
let dst = dst.as_mut_ptr();
let k_size = self.k_size * ch;
let width = (width * ch) as isize;
let lanes = T::lanes::<S>();
// Safety: everything here is unsafe. Test thoroughly.
unsafe {
let mut x = 0;
while x as isize <= width - 4 * lanes as isize {
let mut s0 = vload(src.add(x));
let mut s1 = vload(src.add(x + lanes));
let mut s2 = vload(src.add(x + 2 * lanes));
let mut s3 = vload(src.add(x + 3 * lanes));
for k in (ch..k_size).step_by(ch) {
let x = x + k;
s0 = Op::apply::<S>(s0, vload_unaligned(src.add(x)));
s1 = Op::apply::<S>(s1, vload_unaligned(src.add(x + lanes)));
s2 = Op::apply::<S>(s2, vload_unaligned(src.add(x + 2 * lanes)));
s3 = Op::apply::<S>(s3, vload_unaligned(src.add(x + 3 * lanes)));
}
vstore(dst.add(x), s0);
vstore(dst.add(x + lanes), s1);
vstore(dst.add(x + 2 * lanes), s2);
vstore(dst.add(x + 3 * lanes), s3);
x += 4 * lanes;
}
if x as isize <= width - 2 * lanes as isize {
let mut s0 = vload(src.add(x));
let mut s1 = vload(src.add(x + lanes));
for k in (ch..k_size).step_by(ch) {
s0 = Op::apply::<S>(s0, vload_unaligned(src.add(x + k)));
s1 = Op::apply::<S>(s1, vload_unaligned(src.add(x + k + lanes)));
}
vstore(dst.add(x), s0);
vstore(dst.add(x + lanes), s1);
x += 2 * lanes;
}
if x as isize <= width - lanes as isize {
let mut s = vload(src.add(x));
for k in (ch..k_size).step_by(ch) {
s = Op::apply::<S>(s, vload_unaligned(src.add(x + k)));
}
vstore(dst.add(x), s);
x += lanes;
}
if x as isize <= width - lanes as isize / 2 {
let mut s = vload_low(src.add(x));
for k in (ch..k_size).step_by(ch) {
s = Op::apply::<S>(s, vload_low(src.add(x + k)));
}
vstore_low(dst.add(x), s);
x += lanes / 2;
}
x - x % ch
}
}
fn new(k_size: usize, _anchor: usize) -> Self {
Self {
k_size,
_t: PhantomData,
_op: PhantomData,
}
}
}
pub trait VecColumn<T: Scalar> {
fn new(ksize: usize, anchor: usize) -> Self;
fn apply<S: Simd>(
&self,
src: &[*const T],
dst: &mut [T],
dst_step: usize,
height: usize,
width: usize,
) -> usize;
}
pub struct MorphColumnVec<T: Scalar, Op: VecMorphOperator<T>> {
k_size: usize,
_t: PhantomData<T>,
_op: PhantomData<Op>,
}
impl<T: VOrd, Op: VecMorphOperator<T>> VecColumn<T> for MorphColumnVec<T, Op> {
fn new(k_size: usize, _anchor: usize) -> Self {
Self {
k_size,
_t: PhantomData,
_op: PhantomData,
}
}
fn apply<S: Simd>(
&self,
src: &[*const T],
dst: &mut [T],
dst_step: usize,
mut count: usize,
width: usize,
) -> usize {
let ksize = self.k_size;
let width = width as isize;
let mut dst = dst.as_mut_ptr();
let lanes = T::lanes::<S>();
let mut y = 0;
let mut x = 0;
// Safety: everything here is unsafe. Test thoroughly.
unsafe {
while count > 1 && ksize > 1 {
x = 0;
while x as isize <= width - 4 * lanes as isize {
let sptr = src[y + 1].add(x);
let mut s0 = vload(sptr);
let mut s1 = vload(sptr.add(lanes));
let mut s2 = vload(sptr.add(2 * lanes));
let mut s3 = vload(sptr.add(3 * lanes));
for k in 2..ksize {
let sptr = src[y + k].add(x);
s0 = Op::apply::<S>(s0, vload(sptr));
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
s2 = Op::apply::<S>(s2, vload(sptr.add(2 * lanes)));
s3 = Op::apply::<S>(s3, vload(sptr.add(3 * lanes)));
}
// Row 1
{
let sptr = src[y].add(x);
let s0 = Op::apply(s0, vload(sptr));
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
let s2 = Op::apply(s2, vload(sptr.add(2 * lanes)));
let s3 = Op::apply(s3, vload(sptr.add(3 * lanes)));
vstore_unaligned(dst.add(x), s0);
vstore_unaligned(dst.add(x + lanes), s1);
vstore_unaligned(dst.add(x + 2 * lanes), s2);
vstore_unaligned(dst.add(x + 3 * lanes), s3);
}
// Row 2
{
let sptr = src[y + ksize].add(x);
let s0 = Op::apply(s0, vload(sptr));
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
let s2 = Op::apply(s2, vload(sptr.add(2 * lanes)));
let s3 = Op::apply(s3, vload(sptr.add(3 * lanes)));
vstore_unaligned(dst.add(dst_step + x), s0);
vstore_unaligned(dst.add(dst_step + x + lanes), s1);
vstore_unaligned(dst.add(dst_step + x + 2 * lanes), s2);
vstore_unaligned(dst.add(dst_step + x + 3 * lanes), s3);
}
x += 4 * lanes;
}
if x as isize <= width - 2 * lanes as isize {
let sptr = src[y + 1].add(x);
let mut s0 = vload(sptr);
let mut s1 = vload(sptr.add(lanes));
for k in 2..ksize {
let sptr = src[y + k].add(x);
s0 = Op::apply::<S>(s0, vload(sptr));
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
}
// Row 1
{
let sptr = src[y].add(x);
let s0 = Op::apply(s0, vload(sptr));
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
vstore_unaligned(dst.add(x), s0);
vstore_unaligned(dst.add(x + lanes), s1);
}
// Row 2
{
let sptr = src[y + ksize].add(x);
let s0 = Op::apply(s0, vload(sptr));
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
vstore_unaligned(dst.add(dst_step + x), s0);
vstore_unaligned(dst.add(dst_step + x + lanes), s1);
}
x += 2 * lanes;
}
if x as isize <= width - lanes as isize {
let mut s0 = vload(src[y + 1].add(x));
for k in 2..ksize {
s0 = Op::apply::<S>(s0, vload(src[y + k].add(x)));
}
// Row 1
{
let sptr = src[y].add(x);
vstore_unaligned(dst.add(x), Op::apply(s0, vload(sptr)));
}
// Row 2
{
let sptr = src[y + ksize].add(x);
let s0 = Op::apply(s0, vload(sptr));
vstore_unaligned(dst.add(dst_step + x), s0);
}
x += lanes;
}
if x as isize <= width - lanes as isize / 2 {
let mut s0 = vload_low(src[y + 1].add(x));
for k in 2..ksize {
s0 = Op::apply::<S>(s0, vload_low(src[y + k].add(x)));
}
// Row 1
{
let sptr = src[y].add(x);
let s0 = Op::apply(s0, vload_low(sptr));
vstore_low(dst.add(x), s0);
}
// Row 2
{
let sptr = src[y + ksize].add(x);
let s0 = Op::apply(s0, vload_low(sptr));
vstore_low(dst.add(dst_step + x), s0);
}
x += lanes / 2;
}
count -= 2;
dst = dst.add(dst_step * 2);
y += 2;
}
while count > 0 {
x = 0;
while x as isize <= width - 4 * lanes as isize {
let sptr = src[y].add(x);
let mut s0 = vload(sptr);
let mut s1 = vload(sptr.add(lanes));
let mut s2 = vload(sptr.add(2 * lanes));
let mut s3 = vload(sptr.add(3 * lanes));
for k in 1..ksize {
let sptr = src[y + k].add(x);
s0 = Op::apply::<S>(s0, vload(sptr));
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
s2 = Op::apply::<S>(s2, vload(sptr.add(2 * lanes)));
s3 = Op::apply::<S>(s3, vload(sptr.add(3 * lanes)));
}
vstore_unaligned(dst.add(x), s0);
vstore_unaligned(dst.add(x + lanes), s1);
vstore_unaligned(dst.add(x + 2 * lanes), s2);
vstore_unaligned(dst.add(x + 3 * lanes), s3);
x += 4 * lanes;
}
if x as isize <= width - 2 * lanes as isize {
let sptr = src[y].add(x);
let mut s0 = vload(sptr);
let mut s1 = vload(sptr.add(lanes));
for k in 1..ksize {
let sptr = src[y + k].add(x);
s0 = Op::apply::<S>(s0, vload(sptr));
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
}
vstore_unaligned(dst.add(x), s0);
vstore_unaligned(dst.add(x + lanes), s1);
x += 2 * lanes;
}
if x as isize <= width - lanes as isize {
let mut s0 = vload(src[y].add(x));
for k in 1..ksize {
s0 = Op::apply::<S>(s0, vload(src[y + k].add(x)));
}
vstore_unaligned(dst.add(x), s0);
x += lanes;
}
if x as isize <= width - lanes as isize / 2 {
let mut s0 = vload_low(src[y].add(x));
for k in 1..ksize {
s0 = Op::apply::<S>(s0, vload_low(src[y + k].add(x)));
}
vstore_low(dst.add(x), s0);
x += lanes / 2;
}
count -= 1;
dst = dst.add(dst_step);
y += 1;
}
}
x
}
}
pub struct MorphColumnFilter<T: Scalar, Op: MorphOperator<T>, VecOp: VecColumn<T>> {
pub ksize: usize,
pub anchor: usize,
vec: VecOp,
_t: PhantomData<T>,
_op: PhantomData<Op>,
}
impl<T: Scalar, Op: MorphOperator<T>, VecOp: VecColumn<T>> MorphColumnFilter<T, Op, VecOp> {
pub fn new(ksize: usize, anchor: usize) -> Self {
let vec = VecOp::new(ksize, anchor);
Self {
ksize,
anchor,
vec,
_t: PhantomData,
_op: PhantomData,
}
}
pub fn apply<S: Simd>(
&self,
src: &[*const T],
dst: &mut [T],
dst_step: usize,
mut count: usize,
width: usize,
) {
let ksize = self.ksize;
let x0 = self.vec.apply::<S>(src, dst, dst_step, count, width);
let width = width as isize;
let mut d = 0;
let mut x;
let mut y = 0;
let slice = |row: *const T| unsafe { slice::from_raw_parts(row, width as usize) };
while ksize > 1 && count > 1 {
x = x0;
while x as isize <= width - 4 {
let row = slice(src[y + 1]);
let mut s0 = row[x];
let mut s1 = row[x + 1];
let mut s2 = row[x + 2];
let mut s3 = row[x + 3];
for k in 2..ksize {
let row = slice(src[y + k]);
s0 = Op::apply(s0, row[x]);
s1 = Op::apply(s1, row[x + 1]);
s2 = Op::apply(s2, row[x + 2]);
s3 = Op::apply(s3, row[x + 3]);
}
let row = slice(src[y]);
dst[d + x] = Op::apply(s0, row[x]);
dst[d + x + 1] = Op::apply(s1, row[x + 1]);
dst[d + x + 2] = Op::apply(s2, row[x + 2]);
dst[d + x + 3] = Op::apply(s3, row[x + 3]);
let row = slice(src[y + ksize]);
dst[d + dst_step + x] = Op::apply(s0, row[x]);
dst[d + dst_step + x + 1] = Op::apply(s1, row[x + 1]);
dst[d + dst_step + x + 2] = Op::apply(s2, row[x + 2]);
dst[d + dst_step + x + 3] = Op::apply(s3, row[x + 3]);
x += 4;
}
while (x as isize) < width {
let mut s0 = slice(src[y + 1])[x];
for k in 2..ksize {
s0 = Op::apply(s0, slice(src[y + k])[x]);
}
dst[d + x] = Op::apply(s0, slice(src[y])[x]);
dst[d + dst_step + x] = Op::apply(s0, slice(src[y + ksize])[x]);
x += 1;
}
count -= 2;
d += 2 * dst_step;
y += 2;
}
while count > 0 {
x = x0;
while x as isize <= width - 4 {
let row = slice(src[y]);
let mut s0 = row[x];
let mut s1 = row[x + 1];
let mut s2 = row[x + 2];
let mut s3 = row[x + 3];
for k in 1..ksize {
let row = slice(src[y + k]);
s0 = Op::apply(s0, row[x]);
s1 = Op::apply(s1, row[x + 1]);
s2 = Op::apply(s2, row[x + 2]);
s3 = Op::apply(s3, row[x + 3]);
}
dst[d + x] = s0;
dst[d + x + 1] = s1;
dst[d + x + 2] = s2;
dst[d + x + 3] = s3;
x += 4;
}
while (x as isize) < width {
let mut s0 = slice(src[y])[x];
for k in 1..ksize {
s0 = Op::apply(s0, slice(src[y + k])[x]);
}
dst[d + x] = s0;
x += 1;
}
count -= 1;
d += dst_step;
y += 1;
}
}
}
pub trait VecFilter<T: Scalar> {
fn apply<S: Simd>(src: &[*const T], nz: usize, dst: &mut [T], width: usize) -> usize;
}
pub struct MorphVec<T: Scalar, Op: VecMorphOperator<T>>(PhantomData<(T, Op)>);
impl<T: Scalar, Op: VecMorphOperator<T>> VecFilter<T> for MorphVec<T, Op> {
fn apply<S: Simd>(src: &[*const T], nz: usize, dst: &mut [T], width: usize) -> usize {
let dst = dst.as_mut_ptr();
let mut i = 0;
let lanes = T::lanes::<S>();
let width = width as isize;
// Safety: everything here is unsafe. Test thoroughly.
unsafe {
while i as isize <= width - 4 * lanes as isize {
let sptr = src[0].add(i);
let mut s0 = vload_unaligned(sptr);
let mut s1 = vload_unaligned(sptr.add(lanes));
let mut s2 = vload_unaligned(sptr.add(2 * lanes));
let mut s3 = vload_unaligned(sptr.add(3 * lanes));
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
s0 = Op::apply::<S>(s0, vload_unaligned(sptr));
s1 = Op::apply::<S>(s1, vload_unaligned(sptr.add(lanes)));
s2 = Op::apply::<S>(s2, vload_unaligned(sptr.add(2 * lanes)));
s3 = Op::apply::<S>(s3, vload_unaligned(sptr.add(3 * lanes)));
}
vstore_unaligned(dst.add(i), s0);
vstore_unaligned(dst.add(i + lanes), s1);
vstore_unaligned(dst.add(i + 2 * lanes), s2);
vstore_unaligned(dst.add(i + 3 * lanes), s3);
i += 4 * lanes;
}
if i as isize <= width - 2 * lanes as isize {
let sptr = src[0].add(i);
let mut s0 = vload_unaligned(sptr);
let mut s1 = vload_unaligned(sptr.add(lanes));
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
s0 = Op::apply::<S>(s0, vload_unaligned(sptr));
s1 = Op::apply::<S>(s1, vload_unaligned(sptr.add(lanes)));
}
vstore_unaligned(dst.add(i), s0);
vstore_unaligned(dst.add(i + lanes), s1);
i += 2 * lanes;
}
if i as isize <= width - lanes as isize {
let mut s0 = vload_unaligned(src[0].add(i));
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
s0 = Op::apply::<S>(s0, vload_unaligned(sptr));
}
vstore_unaligned(dst.add(i), s0);
i += lanes;
}
if i as isize <= width - lanes as isize / 2 {
let mut s = vload_low(src[0].add(i));
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
s = Op::apply::<S>(s, vload_low(sptr));
}
vstore_low(dst.add(i), s);
i += lanes / 2;
}
}
i
}
}
pub struct MorphFilter<T: Scalar, Op: MorphOperator<T>, VecOp: VecFilter<T>> {
pub ksize: Size,
pub anchor: Point,
coords: Vec<Point>,
ptrs: Vec<*const T>,
_op: PhantomData<(Op, VecOp)>,
}
impl<T: Scalar, Op: MorphOperator<T>, VecOp: VecFilter<T>> MorphFilter<T, Op, VecOp> {
pub fn new<B: Element>(kernel: &[B], ksize: Size, anchor: Point) -> Self {
let coords = process_2d_kernel(kernel, ksize);
let ptrs = vec![null(); coords.len()];
Self {
ksize,
anchor,
coords,
ptrs,
_op: PhantomData,
}
}
#[allow(clippy::too_many_arguments)]
pub fn apply<S: Simd>(
&mut self,
src: &[*const T],
dst: &mut [T],
dst_step: usize,
mut count: usize,
width: usize,
ch: usize,
) {
let nz = self.coords.len();
let width = (width * ch) as isize;
let pt = &self.coords;
let kp = &mut self.ptrs;
let mut dst_off = 0;
let mut src_off = 0;
let slice = |ptr: *const T| unsafe { slice::from_raw_parts(ptr, width as usize) };
unsafe {
while count > 0 {
for k in 0..nz {
kp[k] = src[src_off + pt[k].y].add(pt[k].x * ch);
}
let mut i = VecOp::apply::<S>(kp, nz, &mut dst[dst_off..], width as usize);
while i as isize <= width - 4 {
let sptr = slice(kp[0].add(i));
let mut s0 = sptr[0];
let mut s1 = sptr[1];
let mut s2 = sptr[2];
let mut s3 = sptr[3];
for sptr in kp[1..nz].iter().map(|sptr| slice(sptr.add(i))) {
s0 = Op::apply(s0, sptr[0]);
s1 = Op::apply(s1, sptr[1]);
s2 = Op::apply(s2, sptr[2]);
s3 = Op::apply(s3, sptr[3]);
}
dst[dst_off + i] = s0;
dst[dst_off + i + 1] = s1;
dst[dst_off + i + 2] = s2;
dst[dst_off + i + 3] = s3;
i += 4;
}
for i in i..width as usize {
let mut s0 = *kp[0].add(i);
for v in kp[1..nz].iter().map(|sptr| *sptr.add(i)) {
s0 = Op::apply(s0, v);
}
dst[dst_off + i] = s0;
}
count -= 1;
dst_off += dst_step;
src_off += 1;
}
}
}
}
fn process_2d_kernel<B: Element>(kernel: &[B], ksize: Size) -> Vec<Point> {
let Size { width, height } = ksize;
let mut nz = kernel.iter().filter(|it| it.to_bool()).count();
if nz == 0 {
nz = 1;
}
let mut coords = vec![Point::new(0, 0); nz];
let mut k = 0;
for y in 0..height {
let krow = &kernel[y * width..];
for (x, _) in krow[..width].iter().enumerate().filter(|it| it.1.to_bool()) {
coords[k] = Point::new(x, y);
k += 1;
}
}
coords
}

View File

@@ -0,0 +1,415 @@
use std::{fmt::Debug, ptr::null_mut};
use burn_tensor::Shape;
use bytemuck::{Zeroable, cast_slice, cast_slice_mut};
use macerator::{Simd, VOrd, Vector};
use crate::{BorderType, Point, Size};
use super::filter::{
MorphColumnFilter, MorphColumnVec, MorphFilter, MorphOperator, MorphRowFilter, MorphRowVec,
MorphVec, VecMorphOperator,
};
pub type RowFilter<T, Op> = MorphRowFilter<T, Op, MorphRowVec<T, Op>>;
pub type ColFilter<T, Op> = MorphColumnFilter<T, Op, MorphColumnVec<T, Op>>;
pub type Filter2D<T, Op> = MorphFilter<T, Op, MorphVec<T, Op>>;
pub enum Filter<T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> {
Separable {
row_filter: RowFilter<T, Op>,
col_filter: ColFilter<T, Op>,
},
Fallback(Filter2D<T, Op>),
}
pub struct FilterEngine<S: Simd, T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> {
/// Vector aligned ring buffer to serve as intermediate, since image isn't always aligned
ring_buf: Vec<Vector<S, T>>,
/// Vector aligned row buffer to serve as intermediate, since image isn't always aligned
src_row: Vec<Vector<S, T>>,
const_border_value: Vec<T>,
const_border_row: Vec<Vector<S, T>>,
border_table: Vec<usize>,
/// Pointers to each row offset in the ring buffer
rows: Vec<*const T>,
filter: Filter<T, Op>,
ksize: Size,
anchor: Point,
dx1: usize,
dx2: usize,
row_count: usize,
dst_y: usize,
start_y: usize,
start_y_0: usize,
end_y: usize,
max_width: usize,
buf_step: usize,
width: usize,
height: usize,
border_type: BorderType,
}
impl<S: Simd, T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> FilterEngine<S, T, Op> {
fn resize_ring_buf(&mut self, size: usize) {
let actual = size.div_ceil(T::lanes::<S>());
self.ring_buf.resize(actual, Zeroable::zeroed());
}
fn resize_src_row(&mut self, size: usize) {
let actual = size.div_ceil(T::lanes::<S>());
self.src_row.resize(actual, Zeroable::zeroed());
}
fn is_separable(&self) -> bool {
matches!(self.filter, Filter::Separable { .. })
}
}
impl<S: Simd, T: VOrd + Debug, Op: MorphOperator<T> + VecMorphOperator<T>> FilterEngine<S, T, Op> {
pub fn new(
filter: Filter<T, Op>,
border_type: BorderType,
border_value: &[T],
ch: usize,
) -> Self {
let (ksize, anchor) = match &filter {
Filter::Separable {
row_filter,
col_filter,
} => {
let ksize = Size::new(row_filter.ksize, col_filter.ksize);
let anchor = Point::new(row_filter.anchor, col_filter.anchor);
(ksize, anchor)
}
Filter::Fallback(f) => (f.ksize, f.anchor),
};
let mut border_table = Vec::new();
let border_length = (ksize.width - 1).max(1);
let mut const_border_value = Vec::new();
if matches!(border_type, BorderType::Constant) {
const_border_value = vec![Zeroable::zeroed(); border_length * ch];
for elem in cast_slice_mut::<_, T>(&mut const_border_value).chunks_exact_mut(ch) {
elem.copy_from_slice(border_value);
}
} else {
border_table = vec![0; border_length * ch];
}
Self {
ring_buf: Default::default(),
src_row: Default::default(),
rows: Default::default(),
border_type,
const_border_row: Default::default(),
const_border_value,
border_table,
ksize,
anchor,
filter,
max_width: 0,
buf_step: 0,
dx1: 0,
dx2: 0,
row_count: 0,
dst_y: 0,
start_y: 0,
start_y_0: 0,
end_y: 0,
width: 0,
height: 0,
}
}
pub fn apply(&mut self, tensor: &mut [T], src_shape: Shape) {
let [_, w, ch] = src_shape.dims();
let src_step = w * ch;
self.start(src_shape);
let y = self.start_y;
self.proceed(
&mut tensor[y * src_step..],
src_step,
self.end_y - self.start_y,
ch,
);
}
pub fn start(&mut self, shape: Shape) -> usize {
let [height, width, ch] = shape.dims();
let max_buf_rows = (self.ksize.height + 3)
.max(self.anchor.y)
.max((self.ksize.height - self.anchor.y - 1) * 2 + 1);
let k_offs = if !self.is_separable() {
self.ksize.width - 1
} else {
0
};
let is_sep = self.is_separable();
if self.max_width < width || max_buf_rows != self.rows.len() {
self.rows.resize(max_buf_rows, null_mut());
self.max_width = self.max_width.max(width);
self.resize_src_row((self.max_width + self.ksize.width - 1) * ch);
if matches!(self.border_type, BorderType::Constant) {
self.const_border_row.resize(
((self.max_width + self.ksize.width - 1) * ch).div_ceil(T::lanes::<S>()),
Zeroable::zeroed(),
);
let mut n = self.const_border_value.len();
let n1 = (self.max_width + self.ksize.width - 1) * ch;
let const_val = &self.const_border_value;
let dst = cast_slice_mut(&mut self.const_border_row);
let t_dst = if is_sep {
cast_slice_mut::<_, T>(&mut self.src_row)
} else {
alias_slice_mut(dst)
};
for i in (0..n1).step_by(n) {
n = n.min(n1 - i);
t_dst[i..i + n].copy_from_slice(&const_val[..n]);
}
if let Filter::Separable { row_filter, .. } = &self.filter {
row_filter.apply::<S>(cast_slice(&self.src_row), dst, self.max_width, ch);
}
}
let max_buf_step =
(self.max_width + k_offs).next_multiple_of(align_of::<Vector<S, T>>()) * ch;
self.resize_ring_buf(max_buf_step * self.rows.len());
}
let const_val = &self.const_border_value;
self.buf_step = (width + k_offs).next_multiple_of(align_of::<Vector<S, T>>()) * ch;
self.dx1 = self.anchor.x;
self.dx2 = self.ksize.width - self.anchor.x - 1;
if self.dx1 > 0 || self.dx2 > 0 {
if matches!(self.border_type, BorderType::Constant) {
let nr = if self.is_separable() {
1
} else {
self.rows.len()
};
for i in 0..nr {
let dst = if self.is_separable() {
cast_slice_mut::<_, T>(&mut self.src_row)
} else {
&mut cast_slice_mut::<_, T>(&mut self.ring_buf)[self.buf_step * i..]
};
memcpy(dst, const_val, self.dx1 * ch);
let right = (width + self.ksize.width - 1 - self.dx2) * ch;
memcpy(&mut dst[right..], const_val, self.dx2 * ch);
}
} else {
for i in 0..self.dx1 as isize {
let p0 = border_interpolate(i - self.dx1 as isize, width, self.border_type);
let p0 = p0 as usize * ch;
for j in 0..ch {
self.border_table[i as usize * ch + j] = p0 + j;
}
}
for i in 0..self.dx2 {
let p0 = border_interpolate((width + i) as isize, width, self.border_type)
as usize
* ch;
for j in 0..ch {
self.border_table[(i + self.dx1) * ch + j] = p0 + j;
}
}
}
}
self.row_count = 0;
self.dst_y = 0;
self.start_y = 0;
self.start_y_0 = 0;
self.end_y = height;
self.width = width;
self.height = height;
self.start_y
}
#[allow(clippy::too_many_arguments)]
pub fn proceed(
&mut self,
src: &mut [T],
src_step: usize,
mut count: usize,
ch: usize,
) -> usize {
let buf_rows = self.rows.len();
let kheight = self.ksize.height;
let kwidth = self.ksize.width;
let ay = self.anchor.y as isize;
let dx1 = self.dx1;
let dx2 = self.dx2;
let width1 = self.width + kwidth - 1;
let btab = &self.border_table;
let make_border = (dx1 > 0 || dx2 > 0) && !matches!(self.border_type, BorderType::Constant);
let is_sep = self.is_separable();
count = count.min(self.remaining_input_rows());
let mut dst_off = 0;
let mut src_off = 0;
let mut dy = 0;
let mut i;
let brows = &mut self.rows;
let src_row = cast_slice_mut::<_, T>(&mut self.src_row);
let ring_buf = cast_slice_mut::<_, T>(&mut self.ring_buf);
loop {
let dcount = buf_rows as isize - ay - self.start_y as isize - self.row_count as isize;
let mut dcount = if dcount > 0 {
dcount as usize
} else {
buf_rows + 1 - kheight
};
dcount = dcount.min(count);
count -= dcount;
while dcount > 0 {
let bi = (self.start_y - self.start_y_0 + self.row_count) % buf_rows;
let brow = &mut ring_buf[bi * self.buf_step..];
let row = if is_sep {
&mut src_row[..]
} else {
alias_slice_mut(brow)
};
if self.row_count + 1 > buf_rows {
self.row_count -= 1;
self.start_y += 1;
}
self.row_count += 1;
memcpy(
&mut row[dx1 * ch..],
&src[src_off..],
(width1 - dx2 - dx1) * ch,
);
if make_border {
for i in 0..dx1 * ch {
row[i] = src[src_off + btab[i]];
}
for i in 0..dx2 * ch {
row[i + (width1 - dx2) * ch] = src[src_off + btab[i + dx1 * ch]];
}
}
if let Filter::Separable { row_filter, .. } = &self.filter {
row_filter.apply::<S>(row, brow, self.width, ch);
}
dcount -= 1;
src_off += src_step;
}
let max_i = buf_rows.min(self.height - (self.dst_y + dy) + (kheight - 1));
i = 0;
while i < max_i {
let src_y = border_interpolate(
(self.dst_y + dy + i) as isize - ay,
self.height,
self.border_type,
);
if src_y < 0 {
brows[i] = self.const_border_row.as_ptr() as _;
} else {
if src_y as usize >= self.start_y + self.row_count {
break;
}
let bi = (src_y as usize - self.start_y_0) % buf_rows;
brows[i] = unsafe { ring_buf.as_ptr().add(bi * self.buf_step) };
}
i += 1;
}
if i < kheight {
break;
}
i -= kheight - 1;
match &mut self.filter {
Filter::Separable { col_filter, .. } => {
col_filter.apply::<S>(brows, &mut src[dst_off..], src_step, i, self.width * ch)
}
Filter::Fallback(filter) => {
filter.apply::<S>(brows, &mut src[dst_off..], src_step, i, self.width, ch)
}
}
dst_off += src_step * i;
dy += i;
}
self.dst_y += dy;
dy
}
fn remaining_input_rows(&self) -> usize {
self.end_y - self.start_y - self.row_count
}
}
#[track_caller]
fn memcpy<T: Copy>(to: &mut [T], from: &[T], len: usize) {
to[..len].copy_from_slice(&from[..len]);
}
/// Unsafely alias slice. Needed for the conditional slice targets that depend on the filter. The
/// same slice shouldn't be used multiple times at once
fn alias_slice_mut<'b, T>(slice: &mut [T]) -> &'b mut [T] {
let ptr = slice.as_mut_ptr();
let len = slice.len();
unsafe { core::slice::from_raw_parts_mut(ptr, len) }
}
fn border_interpolate(mut p: isize, len: usize, btype: BorderType) -> isize {
let len = len as isize;
if p < len && p >= 0 {
return p;
}
match btype {
BorderType::Constant => -1,
BorderType::Replicate if p < 0 => 0,
BorderType::Replicate => len - 1,
BorderType::Reflect | BorderType::Reflect101 => {
let delta = matches!(btype, BorderType::Reflect101) as isize;
if len == 1 {
return 0;
}
loop {
if p < 0 {
p = -p - 1 + delta;
} else {
p = len - 1 - (p - len) - delta;
}
if p < len && p >= 0 {
break;
}
}
p
}
BorderType::Wrap => {
if p < 0 {
p -= ((p - len + 1) / len) * len;
}
if p >= len {
p %= len;
}
p
}
}
}

View File

@@ -0,0 +1,300 @@
use std::fmt::Debug;
use burn_tensor::{
BasicOps, Bool, DType, Element, Shape, Tensor, TensorData, backend::Backend, cast::ToElement,
ops::BoolTensor,
};
use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};
use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};
use macerator::{Simd, VOrd};
use crate::{BorderType, MorphOptions, Point, Size};
use super::MinMax;
mod filter;
mod filter_engine;
/// A morphology operation.
/// TODO: Implement composite ops
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MorphOp {
Erode,
Dilate,
}
pub enum MorphKernel<B: Element> {
Rect {
size: Size,
anchor: Point,
},
Other {
kernel: Vec<B>,
size: Size,
anchor: Point,
},
}
pub fn morph<B: Backend, K: BasicOps<B>>(
input: Tensor<B, 3, K>,
kernel: BoolTensor<B>,
op: MorphOp,
opts: MorphOptions<B, K>,
) -> Tensor<B, 3, K> {
let device = input.device();
let kernel = Tensor::<B, 2, Bool>::new(kernel);
let kshape = kernel.shape().dims();
let [kh, kw] = kshape;
let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();
let is_rect = kernel.iter().all(|it| it.to_bool());
let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));
let iter = opts.iterations;
let btype = opts.border_type;
let bvalue = opts.border_value.map(|it| it.into_data());
let size = Size::new(kw, kh);
let kernel = if is_rect {
MorphKernel::Rect { size, anchor }
} else {
MorphKernel::Other {
kernel,
size,
anchor,
}
};
let shape = input.shape();
let data = input.into_data();
match data.dtype {
DType::F64 => {
morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::F32 | DType::Flex32 => {
morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(
data.convert::<f32>(),
shape,
kernel,
op,
iter,
btype,
bvalue,
&device,
),
DType::I64 => {
morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::I32 => {
morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::I16 => {
morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),
DType::U64 => {
morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::U32 => {
morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::U16 => {
morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)
}
DType::U8 => morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
DType::Bool => morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device),
DType::QFloat(_) => unimplemented!(),
}
}
#[allow(clippy::too_many_arguments)]
fn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(
mut input: TensorData,
shape: Shape,
kernel: MorphKernel<B::BoolElem>,
op: MorphOp,
iter: usize,
btype: BorderType,
bvalue: Option<TensorData>,
device: &B::Device,
) -> Tensor<B, 3, K> {
let data = input.as_mut_slice::<T>().unwrap();
let bvalue = border_value(btype, bvalue, op, &shape);
run_morph(data, shape, kernel, op, iter, btype, &bvalue);
Tensor::from_data(input, device)
}
#[allow(clippy::too_many_arguments)]
fn morph_bool<B: Backend, K: BasicOps<B>>(
mut input: TensorData,
shape: Shape,
kernel: MorphKernel<B::BoolElem>,
op: MorphOp,
iter: usize,
btype: BorderType,
bvalue: Option<TensorData>,
device: &B::Device,
) -> Tensor<B, 3, K> {
let data = input.as_mut_slice::<bool>().unwrap();
// SAFETY: Morph can't produce invalid boolean values
let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };
let bvalue = border_value(btype, bvalue, op, &shape);
run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);
Tensor::from_data(input, device)
}
fn border_value<T: Element>(
btype: BorderType,
bvalue: Option<TensorData>,
op: MorphOp,
shape: &Shape,
) -> Vec<T> {
let [_, _, ch] = shape.dims();
match (btype, bvalue) {
(BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),
(BorderType::Constant, None) => match op {
MorphOp::Erode => vec![T::MAX; ch],
MorphOp::Dilate => vec![T::MIN; ch],
},
_ => vec![],
}
}
fn run_morph<T: VOrd + MinMax + Element, B: Element>(
input: &mut [T],
shape: Shape,
kernel: MorphKernel<B>,
op: MorphOp,
iter: usize,
btype: BorderType,
bvalue: &[T],
) {
match op {
MorphOp::Erode => {
let filter = filter::<T, MinOp, B>(kernel);
dispatch_morph(input, shape, filter, btype, bvalue, iter);
}
MorphOp::Dilate => {
let filter = filter::<T, MaxOp, B>(kernel);
dispatch_morph(input, shape, filter, btype, bvalue, iter);
}
};
}
fn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(
kernel: MorphKernel<B>,
) -> Filter<T, Op> {
match kernel {
MorphKernel::Rect { size, anchor } => {
let row_filter = RowFilter::new(size.width, anchor.x);
let col_filter = ColFilter::new(size.height, anchor.y);
Filter::Separable {
row_filter,
col_filter,
}
}
MorphKernel::Other {
kernel,
size,
anchor,
} => {
let filter = Filter2D::new(&kernel, size, anchor);
Filter::Fallback(filter)
}
}
}
#[inline(always)]
#[allow(clippy::too_many_arguments)]
#[macerator::with_simd]
fn dispatch_morph<
'a,
S: Simd,
T: VOrd + MinMax + Debug,
Op: MorphOperator<T> + VecMorphOperator<T>,
>(
buffer: &'a mut [T],
buffer_shape: Shape,
filter: filter_engine::Filter<T, Op>,
border_type: BorderType,
border_value: &'a [T],
iterations: usize,
) where
'a: 'a,
{
let [_, _, ch] = buffer_shape.dims();
let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);
engine.apply(buffer, buffer_shape.clone());
for _ in 1..iterations {
engine.apply(buffer, buffer_shape.clone());
}
}
/// Shape of the structuring element
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelShape {
/// Rectangular kernel
Rect,
/// Cross shaped kernel
Cross,
/// Ellipse shaped kernel
Ellipse,
}
/// Create a structuring element tensor for use with morphology ops
pub fn create_structuring_element<B: Backend>(
shape: KernelShape,
ksize: Size,
anchor: Option<Point>,
device: &B::Device,
) -> Tensor<B, 2, Bool> {
fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {
let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));
let mut r = 0;
let mut c = 0;
let mut inv_r2 = 0.0;
if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {
return vec![true; ksize.height * ksize.width];
}
if shape == KernelShape::Ellipse {
r = ksize.height / 2;
c = ksize.width / 2;
inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }
}
let mut elem = vec![false; ksize.height * ksize.width];
for i in 0..ksize.height {
let mut j1 = 0;
let mut j2 = 0;
if shape == KernelShape::Cross && i == anchor.y {
j2 = ksize.width;
} else if shape == KernelShape::Cross {
j1 = anchor.x;
j2 = j1 + 1;
} else {
let dy = i as isize - r as isize;
if dy.abs() <= r as isize {
let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())
.round() as isize;
j1 = (c as isize - dx).max(0) as usize;
j2 = (c + dx as usize + 1).min(ksize.width);
}
}
for j in j1..j2 {
elem[i * ksize.width + j] = true;
}
}
elem
}
let elem = create_kernel(shape, ksize, anchor);
let data = TensorData::new(elem, [ksize.height, ksize.width]);
Tensor::from_data(data, device)
}

View File

@@ -0,0 +1,212 @@
use crate::NmsOptions;
use aligned_vec::{AVec, ConstAlign};
use alloc::vec::Vec;
use burn_tensor::{Int, Shape, Tensor, TensorData, backend::Backend};
use macerator::{Scalar, Simd, Vector, vload};
/// Perform NMS on CPU using SIMD acceleration.
///
/// This implementation:
/// 1. Sorts boxes by score (descending)
/// 2. Iteratively selects the highest-scoring non-suppressed box
/// 3. Suppresses all boxes with IoU > threshold using SIMD
pub fn nms<B: Backend>(
boxes: Tensor<B, 2>,
scores: Tensor<B, 1>,
options: NmsOptions,
) -> Tensor<B, 1, Int> {
let device = boxes.device();
let [n_boxes, _] = boxes.shape().dims();
if n_boxes == 0 {
return Tensor::<B, 1, Int>::empty([0], &device);
}
// Get raw data
let boxes_data = boxes.to_data();
let boxes_vec: Vec<f32> = boxes_data.to_vec().unwrap();
let scores_data = scores.to_data();
let scores_vec: Vec<f32> = scores_data.to_vec().unwrap();
let keep = nms_vec(boxes_vec, scores_vec, options);
let n_kept = keep.len();
let indices_data = TensorData::new(keep, Shape::new([n_kept]));
Tensor::<B, 1, Int>::from_data(indices_data, &device)
}
/// Perform NMS on CPU using SIMD acceleration.
fn nms_vec(boxes_vec: Vec<f32>, scores_vec: Vec<f32>, options: NmsOptions) -> Vec<i32> {
let n_boxes = scores_vec.len();
if n_boxes == 0 {
return vec![];
}
// Filter by score threshold first
let mut filtered_indices = Vec::with_capacity(n_boxes);
for (i, &score) in scores_vec.iter().enumerate() {
if score >= options.score_threshold {
filtered_indices.push(i); // original index
}
}
let n_filtered = filtered_indices.len();
if n_filtered == 0 {
return vec![];
}
// Sort by score descending
filtered_indices.sort_by(|&a, &b| scores_vec[b].total_cmp(&scores_vec[a]));
const ALIGN: usize = 64;
const FLOATS_PER_ALIGN: usize = ALIGN / size_of::<f32>(); // 16
let stride = n_filtered.div_ceil(FLOATS_PER_ALIGN) * FLOATS_PER_ALIGN;
let mut buf: AVec<f32, ConstAlign<64>> = AVec::with_capacity(ALIGN, stride * 5);
buf.resize(stride * 5, 0.0);
let (x1s, rest) = buf.split_at_mut(stride);
let (y1s, rest) = rest.split_at_mut(stride);
let (x2s, rest) = rest.split_at_mut(stride);
let (y2s, areas) = rest.split_at_mut(stride);
// Convert filtered boxes to SoA format
for (j, &orig_idx) in filtered_indices.iter().enumerate() {
let x1 = boxes_vec[orig_idx * 4];
let y1 = boxes_vec[orig_idx * 4 + 1];
let x2 = boxes_vec[orig_idx * 4 + 2];
let y2 = boxes_vec[orig_idx * 4 + 3];
x1s[j] = x1;
y1s[j] = y1;
x2s[j] = x2;
y2s[j] = y2;
areas[j] = (x2 - x1) * (y2 - y1);
}
// Apply NMS with SIMD dispatch
let mut suppressed = vec![false; stride];
let mut keep = Vec::new();
for i in 0..n_filtered {
if suppressed[i] {
continue;
}
// Optimization to reduce inner loop comparisons
suppressed[i] = true;
keep.push(filtered_indices[i] as i32); // original index
if options.max_output_boxes > 0 && keep.len() >= options.max_output_boxes {
break;
}
// Suppress overlapping boxes using SIMD
suppress_overlapping(
x1s[i],
y1s[i],
x2s[i],
y2s[i],
areas[i],
x1s,
y1s,
x2s,
y2s,
areas,
&mut suppressed,
stride,
options.iou_threshold,
);
}
keep
}
/// SIMD-accelerated suppression of overlapping boxes.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
#[macerator::with_simd]
fn suppress_overlapping<'a, S: Simd>(
ref_x1: f32,
ref_y1: f32,
ref_x2: f32,
ref_y2: f32,
ref_area: f32,
x1s: &'a [f32],
y1s: &'a [f32],
x2s: &'a [f32],
y2s: &'a [f32],
areas: &'a [f32],
suppressed: &'a mut [bool],
n_boxes: usize, // stride, always multiple of lanes
threshold: f32,
) where
'a: 'a,
{
let lanes = f32::lanes::<S>();
// Splat reference values
let ref_x1_v: Vector<S, f32> = ref_x1.splat();
let ref_y1_v: Vector<S, f32> = ref_y1.splat();
let ref_x2_v: Vector<S, f32> = ref_x2.splat();
let ref_y2_v: Vector<S, f32> = ref_y2.splat();
let ref_area_v: Vector<S, f32> = ref_area.splat();
let thresh_v: Vector<S, f32> = threshold.splat();
let zero_v: Vector<S, f32> = 0.0f32.splat();
let mut i = 0;
let mut mask_buf = core::mem::MaybeUninit::<[bool; 16]>::uninit();
// Process lanes boxes at a time with SIMD
while i + lanes <= n_boxes {
// Skip if all boxes in this chunk are already suppressed
let all_suppressed = unsafe {
match lanes {
4 => *(suppressed.as_ptr().add(i) as *const u32) == 0x01010101,
8 => *(suppressed.as_ptr().add(i) as *const u64) == 0x0101010101010101,
16 => {
*(suppressed.as_ptr().add(i) as *const u128)
== 0x01010101010101010101010101010101
}
_ => unreachable!(),
}
};
if !all_suppressed {
let x1_v: Vector<S, f32> = unsafe { vload(x1s.as_ptr().add(i)) };
let y1_v: Vector<S, f32> = unsafe { vload(y1s.as_ptr().add(i)) };
let x2_v: Vector<S, f32> = unsafe { vload(x2s.as_ptr().add(i)) };
let y2_v: Vector<S, f32> = unsafe { vload(y2s.as_ptr().add(i)) };
let area_v: Vector<S, f32> = unsafe { vload(areas.as_ptr().add(i)) };
// Compute intersection coordinates
let xx1 = ref_x1_v.max(x1_v);
let yy1 = ref_y1_v.max(y1_v);
let xx2 = ref_x2_v.min(x2_v);
let yy2 = ref_y2_v.min(y2_v);
// Compute intersection area (clamp to 0 for non-overlapping)
let w = (xx2 - xx1).max(zero_v);
let h = (yy2 - yy1).max(zero_v);
let inter = w * h;
// Compute IoU
let union = ref_area_v + area_v - inter;
let iou = inter / union;
// Get suppression mask (IoU > threshold)
let suppress_mask = iou.gt(thresh_v);
// Extract mask to bool array and apply to suppressed
// SAFETY: mask_store_as_bool writes exactly `lanes` bools, we only read 0..lanes
unsafe { f32::mask_store_as_bool::<S>(mask_buf.as_mut_ptr().cast(), suppress_mask) };
let mask_buf = unsafe { mask_buf.assume_init() };
for k in 0..lanes {
if mask_buf[k] {
suppressed[i + k] = true;
}
}
}
i += lanes;
}
}

View File

@@ -0,0 +1,54 @@
#[cfg(feature = "ndarray")]
mod ndarray {
use crate::{BoolVisionOps, FloatVisionOps, IntVisionOps, QVisionOps, VisionBackend};
use burn_ndarray::{
FloatNdArrayElement, IntNdArrayElement, NdArray, NdArrayTensor, QuantElement, SharedArray,
};
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolVisionOps
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntVisionOps
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatVisionOps
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QVisionOps for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> VisionBackend
for NdArray<E, I, Q>
where
NdArrayTensor: From<SharedArray<E>>,
NdArrayTensor: From<SharedArray<I>>,
{
}
}
#[cfg(feature = "tch")]
mod tch {
use crate::{BoolVisionOps, FloatVisionOps, IntVisionOps, QVisionOps, VisionBackend};
use burn_tch::{LibTorch, TchElement};
impl<E: TchElement, Q: burn_tch::QuantElement> BoolVisionOps for LibTorch<E, Q> {}
impl<E: TchElement, Q: burn_tch::QuantElement> IntVisionOps for LibTorch<E, Q> {}
impl<E: TchElement, Q: burn_tch::QuantElement> FloatVisionOps for LibTorch<E, Q> {}
impl<E: TchElement, Q: burn_tch::QuantElement> QVisionOps for LibTorch<E, Q> {}
impl<E: TchElement, Q: burn_tch::QuantElement> VisionBackend for LibTorch<E, Q> {}
}

View File

@@ -0,0 +1,624 @@
//! Hardware Accelerated 4-connected, adapted from
//! A. Hennequin, L. Lacassagne, L. Cabaret, Q. Meunier,
//! "A new Direct Connected Component Labeling and Analysis Algorithms for GPUs",
//! DASIP, 2018
use crate::{
ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity,
backends::cube::connected_components::stats_from_opts,
};
use burn_cubecl::{
BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement, kernel,
ops::{into_data_sync, numeric::zeros_client},
tensor::CubeTensor,
};
use burn_tensor::{Shape, TensorMetadata, cast::ToElement, ops::IntTensorOps};
use cubecl::{features::Plane, prelude::*};
use super::prefix_sum::prefix_sum;
const BLOCK_H: usize = 4;
#[cube]
fn merge<I: Int>(labels: &Tensor<Atomic<I>>, label_1: u32, label_2: u32) {
let mut label_1 = label_1 as usize;
let mut label_2 = label_2 as usize;
while label_1 != label_2 && (label_1 != usize::cast_from(labels[label_1].load()) - 1) {
label_1 = usize::cast_from(labels[label_1].load()) - 1;
}
while label_1 != label_2 && (label_2 != usize::cast_from(labels[label_2].load()) - 1) {
label_2 = usize::cast_from(labels[label_2].load()) - 1;
}
while label_1 != label_2 {
#[allow(clippy::manual_swap)]
if label_1 < label_2 {
let tmp = label_1;
label_1 = label_2;
label_2 = tmp;
}
let label_3 = usize::cast_from(labels[label_1].fetch_min(I::cast_from(label_2 + 1))) - 1;
if label_1 == label_3 {
label_1 = label_2;
} else {
label_1 = label_3;
}
}
}
#[cube]
fn start_distance(pixels: u32, tx: u32) -> u32 {
(!(pixels << (32 - tx))).leading_zeros()
}
#[cube]
fn end_distance(pixels: u32, tx: u32) -> u32 {
(!(pixels >> (tx + 1))).find_first_set()
}
#[cube]
#[allow(unconditional_panic, reason = "clippy thinks PLANE_DIM is always 2")]
fn ballot_dyn(y: u32, pred: bool) -> u32 {
let index = y % (PLANE_DIM / 32);
plane_ballot(pred)[index as usize]
}
#[cube(launch_unchecked)]
fn strip_labeling<I: Int, BT: CubePrimitive>(
img: &Tensor<BT>,
labels: &Tensor<Atomic<I>>,
#[comptime] connectivity: Connectivity,
) {
let mut shared_pixels = SharedMemory::<u32>::new(BLOCK_H);
let y = ABSOLUTE_POS_Y;
let rows = labels.shape(0) as u32;
let cols = labels.shape(1) as u32;
if y >= rows {
terminate!();
}
let img_stride = img.stride(0) as u32;
let labels_stride = labels.stride(0) as u32;
let img_line_base = y * img_stride + UNIT_POS_X;
let labels_line_base = y * labels_stride + UNIT_POS_X;
let mut distance_y = 0u32;
let mut distance_y_1 = 0;
for i in range_stepped(0, img.shape(1) as u32, PLANE_DIM) {
let x = UNIT_POS_X + i;
if x < cols {
let mut mask = 0xffffffffu32;
let involved_cols = cols - i;
if involved_cols < 32 {
mask >>= 32 - involved_cols;
}
let img_index = img_line_base + i;
let labels_index = labels_line_base + i;
let p_y = bool::cast_from(img[img_index as usize]);
let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask;
let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X);
if p_y && s_dist_y == 0 {
labels[labels_index as usize].store(I::cast_from(
labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1,
));
}
// Only needed pre-Volta, but we can't check that at present
sync_cube();
if UNIT_POS_X == 0 {
shared_pixels[UNIT_POS_Y as usize] = pixels_y;
}
sync_cube();
// Requires if and not select, because `select` may execute the then branch even if the
// condition is false (on non-CUDA backends), which can lead to OOB reads.
let pixels_y_1 = if UNIT_POS_Y > 0 {
shared_pixels[(UNIT_POS_Y - 1) as usize]
} else {
0u32.runtime()
};
let p_y_1 = (pixels_y_1 >> UNIT_POS_X) & 1 != 0;
let mut s_dist_y_1 = start_distance(pixels_y_1, UNIT_POS_X);
if UNIT_POS_X == 0 {
s_dist_y = distance_y;
s_dist_y_1 = distance_y_1;
}
match connectivity {
Connectivity::Four => {
if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {
let label_1 = labels_index - s_dist_y;
let label_2 = labels_index - s_dist_y_1 - labels_stride;
merge(labels, label_1, label_2);
}
}
Connectivity::Eight => {
let pixels_y_shifted = (pixels_y << 1) | (distance_y > 0) as u32;
let pixels_y_1_shifted = (pixels_y_1 << 1) | (distance_y_1 > 0) as u32;
if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {
let label_1 = labels_index - s_dist_y;
let label_2 = labels_index - s_dist_y_1 - labels_stride;
merge(labels, label_1, label_2);
} else if p_y && s_dist_y == 0 && (pixels_y_1_shifted >> UNIT_POS_X) & 1 != 0 {
let s_dist_y_1_prev = select(
UNIT_POS_X == 0,
distance_y_1 - 1,
start_distance(pixels_y_1, UNIT_POS_X - 1),
);
let label_1 = labels_index;
let label_2 = labels_index - labels_stride - 1 - s_dist_y_1_prev;
merge(labels, label_1, label_2);
} else if p_y_1 && s_dist_y_1 == 0 && (pixels_y_shifted >> UNIT_POS_X) & 1 != 0
{
let s_dist_y_prev = select(
UNIT_POS_X == 0,
distance_y - 1,
start_distance(pixels_y, UNIT_POS_X - 1),
);
let label_1 = labels_index - 1 - s_dist_y_prev;
let label_2 = labels_index - labels_stride;
merge(labels, label_1, label_2);
}
}
}
if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {
let label_1 = labels_index - s_dist_y;
let label_2 = labels_index - s_dist_y_1 - labels_stride;
merge(labels, label_1, label_2);
}
let mut d = start_distance(pixels_y_1, 32);
distance_y_1 = d + select(d == 32, distance_y_1, 0);
d = start_distance(pixels_y, 32);
distance_y = d + select(d == 32, distance_y, 0);
}
}
}
#[cube(launch_unchecked)]
fn strip_merge<I: Int, BT: CubePrimitive>(
img: &Tensor<BT>,
labels: &Tensor<Atomic<I>>,
#[comptime] connectivity: Connectivity,
) {
let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM;
let y = (CUBE_POS_Y + 1) * BLOCK_H as u32;
let x = plane_start_x + UNIT_POS_X;
let img_step = img.stride(0) as u32;
let labels_step = labels.stride(0) as u32;
let cols = img.shape(1) as u32;
if y < labels.shape(0) as u32 && x < labels.shape(1) as u32 {
let mut mask = 0xffffffffu32;
if cols - plane_start_x < 32 {
mask >>= 32 - (cols - plane_start_x);
}
let img_index = y * img_step + x;
let labels_index = y * labels_step + x;
let img_index_up = img_index - img_step;
let labels_index_up = labels_index - labels_step;
let p = bool::cast_from(img[img_index as usize]);
let p_up = bool::cast_from(img[img_index_up as usize]);
let pixels = ballot_dyn(UNIT_POS_Z, p) & mask;
let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask;
match connectivity {
Connectivity::Four => {
if p && p_up {
let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
if s_dist == 0 || s_dist_up == 0 {
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
}
}
}
Connectivity::Eight => {
let mut last_dist_vec = SharedMemory::<u32>::new(32usize);
let mut last_dist_up_vec = SharedMemory::<u32>::new(32usize);
let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
if UNIT_POS_PLANE == PLANE_DIM - 1 {
last_dist_vec[UNIT_POS_Z as usize] = start_distance(pixels, 32);
last_dist_up_vec[UNIT_POS_Z as usize] = start_distance(pixels_up, 32);
}
sync_cube();
if CUBE_POS_X == 0 || UNIT_POS_Z > 0 {
let last_dist = if UNIT_POS_Z > 0 {
last_dist_vec[(UNIT_POS_Z - 1) as usize]
} else {
0u32.runtime()
};
let last_dist_up = if UNIT_POS_Z > 0 {
last_dist_up_vec[(UNIT_POS_Z - 1) as usize]
} else {
0u32.runtime()
};
let p_prev =
select(UNIT_POS_X > 0, (pixels >> (UNIT_POS_X - 1)) & 1, last_dist) != 0;
let p_up_prev = select(
UNIT_POS_X > 0,
(pixels_up >> (UNIT_POS_X - 1)) & 1,
last_dist_up,
) != 0;
if p && p_up {
let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
if s_dist == 0 || s_dist_up == 0 {
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
}
} else if p && p_up_prev && s_dist == 0 {
let s_dist_up_prev = select(
UNIT_POS_X == 0,
last_dist_up - 1,
start_distance(pixels_up, UNIT_POS_X - 1),
);
merge(labels, labels_index, labels_index_up - 1 - s_dist_up_prev);
} else if p_prev && p_up && s_dist_up == 0 {
let s_dist_prev = select(
UNIT_POS_X == 0,
last_dist - 1,
start_distance(pixels, UNIT_POS_X - 1),
);
merge(labels, labels_index - 1 - s_dist_prev, labels_index_up);
}
}
}
}
}
}
#[cube(launch_unchecked)]
fn relabeling<I: Int, BT: CubePrimitive>(img: &Tensor<BT>, labels: &mut Tensor<I>) {
let plane_start_x = CUBE_POS_X * CUBE_DIM_X;
let y = ABSOLUTE_POS_Y;
let x = plane_start_x + UNIT_POS_X;
let cols = labels.shape(1) as u32;
let rows = labels.shape(0) as u32;
let img_step = img.stride(0) as u32;
let labels_step = labels.stride(0) as u32;
if x < cols && y < rows {
let mut mask = 0xffffffffu32;
if cols - plane_start_x < 32 {
mask >>= 32 - (cols - plane_start_x);
}
let img_index = y * img_step + x;
let labels_index = y * labels_step + x;
let p = bool::cast_from(img[img_index as usize]);
let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
let s_dist = start_distance(pixels, UNIT_POS_X);
let mut label = 0u32;
if p && s_dist == 0 {
label = u32::cast_from(labels[labels_index as usize]) - 1;
while label != u32::cast_from(labels[label as usize]) - 1 {
label = u32::cast_from(labels[label as usize]) - 1;
}
}
label = plane_shuffle(label, UNIT_POS_X - s_dist);
if p {
labels[labels_index as usize] = I::cast_from(label + 1);
}
}
}
#[cube(launch_unchecked)]
fn analysis<I: Int, BT: CubePrimitive>(
img: &Tensor<BT>,
labels: &mut Tensor<I>,
area: &mut Tensor<Atomic<I>>,
top: &mut Tensor<Atomic<I>>,
left: &mut Tensor<Atomic<I>>,
right: &mut Tensor<Atomic<I>>,
bottom: &mut Tensor<Atomic<I>>,
max_label: &mut Tensor<Atomic<I>>,
#[comptime] opts: ConnectedStatsOptions,
) {
let y = ABSOLUTE_POS_Y;
let x = ABSOLUTE_POS_X;
let cols = labels.shape(1) as u32;
let rows = labels.shape(0) as u32;
let img_step = img.stride(0) as u32;
let labels_step = labels.stride(0) as u32;
if x < cols && y < rows {
let mut mask = 0xffffffffu32;
if cols - CUBE_POS_X * CUBE_DIM_X < 32 {
mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X);
}
let img_index = y * img_step + x;
let labels_index = y * labels_step + x;
let p = bool::cast_from(img[img_index as usize]);
let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
let s_dist = start_distance(pixels, UNIT_POS_X);
let count = end_distance(pixels, UNIT_POS_X);
let max_x = x + count - 1;
let mut label = 0u32;
if p && s_dist == 0 {
label = u32::cast_from(labels[labels_index as usize]) - 1;
while label != u32::cast_from(labels[label as usize]) - 1 {
label = u32::cast_from(labels[label as usize]) - 1;
}
label += 1;
area[label as usize].fetch_add(I::cast_from(count));
if opts.bounds_enabled {
left[label as usize].fetch_min(I::cast_from(x));
top[label as usize].fetch_min(I::cast_from(y));
right[label as usize].fetch_max(I::cast_from(max_x));
bottom[label as usize].fetch_max(I::cast_from(y));
}
if comptime!(opts.max_label_enabled || opts.compact_labels) {
max_label[0].fetch_max(I::cast_from(label));
}
}
label = plane_shuffle(label, UNIT_POS_X - s_dist);
if p {
labels[labels_index as usize] = I::cast_from(label);
}
}
}
#[cube(launch_unchecked)]
fn compact_labels<I: Int>(
labels: &mut Tensor<I>,
remap: &Tensor<I>,
max_label: &Tensor<Atomic<I>>,
) {
let x = ABSOLUTE_POS_X;
let y = ABSOLUTE_POS_Y;
let labels_pos = y * labels.stride(0) as u32 + x;
if labels_pos as usize >= labels.len() {
terminate!();
}
let label = u32::cast_from(labels[labels_pos as usize]);
if label != 0 {
let new_label = remap[label as usize];
labels[labels_pos as usize] = new_label;
max_label[0].fetch_max(new_label);
}
}
#[cube(launch_unchecked)]
fn compact_stats<I: Int>(
area: &Tensor<I>,
area_new: &mut Tensor<I>,
top: &Tensor<I>,
top_new: &mut Tensor<I>,
left: &Tensor<I>,
left_new: &mut Tensor<I>,
right: &Tensor<I>,
right_new: &mut Tensor<I>,
bottom: &Tensor<I>,
bottom_new: &mut Tensor<I>,
remap: &Tensor<I>,
) {
let label = ABSOLUTE_POS_X;
if label as usize >= remap.len() {
terminate!();
}
let area = area[label as usize];
if area == I::new(0) {
terminate!();
}
let new_label = u32::cast_from(remap[label as usize]);
area_new[new_label as usize] = area;
// This should be gated but there's a problem with the Eq bound only being implemented for tuples
// up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work.
top_new[new_label as usize] = top[label as usize];
left_new[new_label as usize] = left[label as usize];
right_new[new_label as usize] = right[label as usize];
bottom_new[new_label as usize] = bottom[label as usize];
}
#[allow(clippy::type_complexity)]
pub fn hardware_accelerated<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
img: CubeTensor<R>,
stats_opt: ConnectedStatsOptions,
connectivity: Connectivity,
) -> Result<
(
CubeTensor<R>,
ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>,
),
String,
> {
let client = img.client.clone();
let device = img.device.clone();
if !client.properties().features.plane.contains(Plane::Ops) {
return Err("Requires plane instructions".into());
}
let props = &client.properties().hardware;
if props.plane_size_min == 32 && props.plane_size_max == 32 {
return Err("Requires plane size of at least 32".into());
}
// Somehow the kernel doesn't work on AMD and Apple Silicon.
//
// The check invalidates those, but probably not for the right reason.
if props.plane_size_max != 32 {
return Err("Requires plane size of at least 32".into());
}
let [rows, cols] = img.meta.shape().dims();
let labels = zeros_client::<R>(client.clone(), device.clone(), img.shape(), I::dtype());
// Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32.
// This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp
// size at compile time. `REQUIRE_FULL_SUBGROUPS` or subgroup size controls are not supported
// in wgpu.
let warp_size = 32;
let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H as u32);
let cube_count = CubeCount::new_2d(1, (rows as u32).div_ceil(cube_dim.y));
unsafe {
strip_labeling::launch_unchecked::<I, BT, R>(
&client,
cube_count,
cube_dim,
img.as_tensor_arg(1),
labels.as_tensor_arg(1),
connectivity,
)
.expect("Kernel to never fail");
};
let horizontal_warps = Ord::min((cols as u32).div_ceil(warp_size), 32);
let cube_dim_merge = CubeDim::new_3d(warp_size, 1, horizontal_warps);
let cube_count = CubeCount::new_2d(
Ord::max((cols as u32 + warp_size * 30 - 1) / (warp_size * 31), 1),
(rows as u32 - 1) / BLOCK_H as u32,
);
unsafe {
strip_merge::launch_unchecked::<I, BT, R>(
&client,
cube_count,
cube_dim_merge,
img.as_tensor_arg(1),
labels.as_tensor_arg(1),
connectivity,
)
.expect("Kernel to never fail");
};
let cube_count = CubeCount::new_2d(
(cols as u32).div_ceil(cube_dim.x),
(rows as u32).div_ceil(cube_dim.y),
);
let mut stats = stats_from_opts(labels.clone(), stats_opt);
if stats_opt == ConnectedStatsOptions::none() {
unsafe {
relabeling::launch_unchecked::<I, BT, R>(
&client,
cube_count,
cube_dim,
img.as_tensor_arg(1),
labels.as_tensor_arg(1),
)
.expect("Kernel to never fail");
};
} else {
unsafe {
analysis::launch_unchecked::<I, BT, R>(
&client,
cube_count,
cube_dim,
img.as_tensor_arg(1),
labels.as_tensor_arg(1),
stats.area.as_tensor_arg(1),
stats.top.as_tensor_arg(1),
stats.left.as_tensor_arg(1),
stats.right.as_tensor_arg(1),
stats.bottom.as_tensor_arg(1),
stats.max_label.as_tensor_arg(1),
stats_opt,
)
.expect("Kernel to never fail");
};
if stats_opt.compact_labels {
let max_label = CubeBackend::<R, F, I, BT>::int_max(stats.max_label);
let max_label = into_data_sync::<R>(max_label);
let max_label = ToElement::to_usize(&max_label.as_slice::<I>().unwrap()[0]);
let sliced = kernel::slice::<R>(
stats.area.clone(),
#[allow(clippy::single_range_in_vec_init)]
&[0..(max_label + 1).next_multiple_of(4)],
);
let relabel = prefix_sum::<R, I>(sliced);
let cube_dim = CubeDim::new_2d(32, 8);
let cube_count = CubeCount::new_2d(
(cols as u32).div_ceil(cube_dim.x),
(rows as u32).div_ceil(cube_dim.y),
);
stats.max_label =
zeros_client::<R>(client.clone(), device.clone(), Shape::new([1]), I::dtype());
unsafe {
compact_labels::launch_unchecked::<I, R>(
&client,
cube_count,
cube_dim,
labels.as_tensor_arg(1),
relabel.as_tensor_arg(1),
stats.max_label.as_tensor_arg(1),
)
.expect("Kernel to never fail");
};
let cube_dim = CubeDim::new_1d(256);
let cube_count = CubeCount::new_1d((rows * cols).div_ceil(256) as u32);
unsafe {
compact_stats::launch_unchecked::<I, R>(
&client,
cube_count,
cube_dim,
stats.area.copy().as_tensor_arg(1),
stats.area.as_tensor_arg(1),
stats.top.copy().as_tensor_arg(1),
stats.top.as_tensor_arg(1),
stats.left.copy().as_tensor_arg(1),
stats.left.as_tensor_arg(1),
stats.right.copy().as_tensor_arg(1),
stats.right.as_tensor_arg(1),
stats.bottom.copy().as_tensor_arg(1),
stats.bottom.as_tensor_arg(1),
relabel.as_tensor_arg(1),
)
.expect("Kernel to never fail");
};
}
}
Ok((labels, stats))
}

View File

@@ -0,0 +1,63 @@
mod hardware_accelerated;
/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops
/// to really use it in a general case. Needs more work to use as a normal tensor method.
mod prefix_sum;
use burn_cubecl::{
BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement,
ops::numeric::{full_client, zeros_client},
tensor::CubeTensor,
};
use burn_tensor::Shape;
pub use hardware_accelerated::*;
use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};
pub(crate) fn stats_from_opts<R, F, I, BT>(
l: CubeTensor<R>,
opts: ConnectedStatsOptions,
) -> ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
let [height, width] = l.meta.shape().dims();
let shape = Shape::new([height * width]);
let zeros = || {
zeros_client::<R>(
l.client.clone(),
l.device.clone(),
shape.clone(),
I::dtype(),
)
};
let max = I::max_value();
let max = || full_client::<R, I>(l.client.clone(), shape.clone(), l.device.clone(), max);
let dummy = || {
CubeTensor::new_contiguous(
l.client.clone(),
l.device.clone(),
shape.clone(),
l.handle.clone(),
l.dtype,
)
};
ConnectedStatsPrimitive {
area: (opts != ConnectedStatsOptions::none())
.then(zeros)
.unwrap_or_else(dummy),
left: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
top: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
max_label: zeros_client::<R>(
l.client.clone(),
l.device.clone(),
Shape::new([1]),
I::dtype(),
),
}
}

View File

@@ -0,0 +1,262 @@
use burn_tensor::{Shape, TensorMetadata};
use cubecl::prelude::*;
use burn_cubecl::{
CubeRuntime, IntElement,
ops::{
numeric::{empty_device, zeros_client},
reshape,
},
tensor::CubeTensor,
};
const CUBE_SIZE: usize = 256;
const MIN_SUBGROUP_SIZE: usize = 4;
const MAX_REDUCE_SIZE: usize = CUBE_SIZE / MIN_SUBGROUP_SIZE;
const PART_SIZE: usize = 4096;
#[cube(launch_unchecked)]
fn prefix_sum_kernel<I: Int>(
scan_in: &Tensor<Line<I>>,
scan_out: &mut Tensor<Line<I>>,
scan_bump: &Tensor<Atomic<I>>,
reduction: &Tensor<Atomic<I>>,
cube_count_x: usize,
) {
let mut broadcast = SharedMemory::<I>::new(1usize);
let mut reduce = SharedMemory::<I>::new(MAX_REDUCE_SIZE);
let batch = CUBE_POS_Z as usize;
let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size());
let nums_per_cube = CUBE_SIZE * line_spt;
let v_last = comptime!(scan_in.line_size() - 1);
//acquire partition index
if UNIT_POS_X == 0 {
broadcast[0] = scan_bump[batch].fetch_add(I::new(1));
}
sync_cube();
let part_id = usize::cast_from(broadcast[0]);
let plane_id = UNIT_POS_X / PLANE_DIM;
let dev_offs = part_id * nums_per_cube;
let plane_offs = UNIT_POS_X as usize * line_spt;
// Exit if full plane is out of bounds
if dev_offs + plane_offs >= scan_in.shape(1) {
terminate!();
}
let zero = I::new(0);
let flag_reduction = I::new(1);
let flag_inclusive = I::new(2);
let flag_mask = I::new(3);
let red_offs = batch * reduction.stride(0);
let scan_offs = batch * scan_in.stride(0);
let mut t_scan = Array::<Line<I>>::lined(line_spt, scan_in.line_size());
{
let mut i = dev_offs + plane_offs + UNIT_POS_PLANE as usize;
if part_id < cube_count_x - 1 {
for k in 0..line_spt {
// Manually fuse not_equal and cast
let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
#[unroll]
for v in 1..scan_in.line_size() {
let prev = scan[v - 1];
scan[v] += prev;
}
t_scan[k] = scan;
i += PLANE_DIM as usize;
}
}
if part_id == cube_count_x - 1 {
for k in 0..line_spt {
if i < scan_in.shape(1) {
// Manually fuse not_equal and cast
let mut scan =
Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
#[unroll]
for v in 1..scan_in.line_size() {
let prev = scan[v - 1];
scan[v] += prev;
}
t_scan[k] = scan;
}
i += PLANE_DIM as usize;
}
}
let mut prev = zero;
let plane_mask = PLANE_DIM - 1;
let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask;
for k in 0..line_spt {
let t = plane_shuffle(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev);
prev += plane_broadcast(t, 0u32);
}
if UNIT_POS_PLANE == 0 {
reduce[plane_id as usize] = prev;
}
}
sync_cube();
//Non-divergent subgroup agnostic inclusive scan across subgroup reductions
let lane_log = count_trailing_zeros(PLANE_DIM);
let spine_size = CUBE_DIM >> lane_log;
{
let mut offset_0 = 0;
let mut offset_1 = 0;
let aligned_size =
1 << ((count_trailing_zeros(spine_size) + lane_log + 1) / lane_log * lane_log);
let mut j = PLANE_DIM;
while j <= aligned_size {
let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0;
let pred_0 = i_0 < spine_size;
let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0 as usize], zero));
if pred_0 {
reduce[i_0 as usize] = t_0;
}
sync_cube();
if j != PLANE_DIM {
let rshift = j >> lane_log;
let i_1 = UNIT_POS_X + rshift;
if (i_1 & (j - 1)) >= rshift {
let pred_1 = i_1 < spine_size;
let t_1 = select(
pred_1,
reduce[(((i_1 >> offset_1) << offset_1) - 1) as usize],
zero,
);
if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 {
reduce[i_1 as usize] += t_1;
}
}
} else {
offset_0 += 1;
}
offset_1 += lane_log;
j <<= lane_log;
}
}
sync_cube();
//Device broadcast
if UNIT_POS_X == 0 {
reduction[part_id + red_offs].store(
(reduce[(spine_size - 1) as usize] << I::new(2))
| select(part_id != 0, flag_reduction, flag_inclusive),
)
}
//Lookback, single thread
if part_id != 0 {
if UNIT_POS_X == 0 {
let mut lookback_id = part_id - 1;
let mut prev_reduction = zero;
loop {
let flag_payload = reduction[lookback_id + red_offs].load();
if (flag_payload & flag_mask) == flag_inclusive {
prev_reduction += flag_payload >> I::new(2);
reduction[part_id + red_offs].store(
((prev_reduction + reduce[(spine_size - 1) as usize]) << I::new(2))
| flag_inclusive,
);
broadcast[0] = prev_reduction;
break;
}
if (flag_payload & flag_mask) == flag_reduction {
prev_reduction += flag_payload >> I::new(2);
lookback_id -= 1;
}
}
}
sync_cube();
}
{
let prev = if plane_id != 0 {
reduce[(plane_id - 1) as usize]
} else {
zero
};
let prev = Line::cast_from(broadcast[0] + prev);
let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt as u32;
let dev_offset = part_id * nums_per_cube;
let mut i = s_offset as usize + dev_offset;
if part_id < cube_count_x - 1 {
for k in 0..line_spt {
scan_out[i + scan_offs] = t_scan[k] + prev;
i += PLANE_DIM as usize;
}
}
if part_id == cube_count_x - 1 {
for k in 0..line_spt {
if i < scan_out.shape(1) {
scan_out[i + scan_offs] = t_scan[k] + prev;
}
i += PLANE_DIM as usize;
}
}
}
}
#[cube]
fn count_trailing_zeros(num: u32) -> u32 {
u32::find_first_set(num) - 1
}
/// Compute the prefix sum of a tensor
pub fn prefix_sum<R: CubeRuntime, I: IntElement>(input: CubeTensor<R>) -> CubeTensor<R> {
let client = input.client.clone();
let device = input.device.clone();
let num_elems = input.meta.num_elements();
let numbers = *input.meta.shape().last().unwrap();
let batches = num_elems / numbers;
let input = reshape(input, Shape::new([batches, numbers]));
let out = empty_device::<R, I>(client.clone(), device.clone(), input.shape());
let cubes = numbers.div_ceil(PART_SIZE);
let cube_dim = CubeDim::new_1d(CUBE_SIZE as u32);
let cube_count = CubeCount::new_3d(cubes as u32, 1, batches as u32);
let bump = zeros_client::<R>(
client.clone(),
device.clone(),
Shape::new([batches]),
I::dtype(),
);
let reduction = zeros_client::<R>(
client.clone(),
device.clone(),
Shape::new([batches, cubes]),
I::dtype(),
);
unsafe {
prefix_sum_kernel::launch_unchecked::<I, R>(
&input.client,
cube_count,
cube_dim,
input.as_tensor_arg(4),
out.as_tensor_arg(4),
bump.as_tensor_arg(1),
reduction.as_tensor_arg(1),
ScalarArg::new(cubes),
)
.expect("Kernel to never fail");
};
out
}

View File

@@ -0,0 +1,2 @@
mod connected_components;
mod ops;

View File

@@ -0,0 +1,211 @@
use crate::{
BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
};
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
use burn_tensor::{
Element,
ops::{BoolTensor, IntTensor},
};
use super::connected_components::hardware_accelerated;
impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
hardware_accelerated::<R, F, I, BT>(
img.clone(),
ConnectedStatsOptions::none(),
connectivity,
)
.map(|it| it.0)
.unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
}
fn connected_components_with_stats(
img: BoolTensor<Self>,
connectivity: Connectivity,
opts: ConnectedStatsOptions,
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
})
}
}
impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
}
impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
}
impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
}
impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
where
R: CubeRuntime,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
}
#[cfg(feature = "fusion")]
mod fusion {
use super::*;
use burn_fusion::{
Fusion, FusionBackend, FusionRuntime,
stream::{Operation, OperationStreams},
};
use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr};
use burn_tensor::Shape;
impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {
fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
let height = img.shape[0];
let width = img.shape[1];
let client = img.client.clone();
#[derive(derive_new::new, Clone, Debug)]
struct ConnComp<B> {
desc: CustomOpIr,
conn: Connectivity,
_b: core::marker::PhantomData<B>,
}
impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {
fn execute(
&self,
handles: &mut HandleContainer<
<B1::FusionRuntime as FusionRuntime>::FusionHandle,
>,
) {
let ([img], [labels]) = self.desc.as_fixed();
let input = handles.get_bool_tensor::<B1>(img);
let output = B1::connected_components(input, self.conn);
handles.register_int_tensor::<B1>(&labels.id, output);
}
}
let streams = OperationStreams::with_inputs([&img]);
let out = TensorIr::uninit(
client.create_empty_handle(),
Shape::new([height, width]),
B::IntElem::dtype(),
);
let desc = CustomOpIr::new("connected_components", &[img.into_ir()], &[out]);
client
.register(
streams,
OperationIr::Custom(desc.clone()),
ConnComp::<B>::new(desc, conn),
)
.output()
}
fn connected_components_with_stats(
img: BoolTensor<Self>,
conn: Connectivity,
opts: ConnectedStatsOptions,
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
let height = img.shape[0];
let width = img.shape[1];
let client = img.client.clone();
#[derive(derive_new::new, Clone, Debug)]
struct ConnCompStats<B> {
desc: CustomOpIr,
conn: Connectivity,
opts: ConnectedStatsOptions,
_b: core::marker::PhantomData<B>,
}
impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnCompStats<B1> {
fn execute(
&self,
handles: &mut HandleContainer<
<B1::FusionRuntime as FusionRuntime>::FusionHandle,
>,
) {
let ([img], [labels, area, left, top, right, bottom, max_label]) =
self.desc.as_fixed();
let input = handles.get_bool_tensor::<B1>(img);
let (output, stats) =
B1::connected_components_with_stats(input, self.conn, self.opts);
handles.register_int_tensor::<B1>(&labels.id, output);
handles.register_int_tensor::<B1>(&area.id, stats.area);
handles.register_int_tensor::<B1>(&left.id, stats.left);
handles.register_int_tensor::<B1>(&top.id, stats.top);
handles.register_int_tensor::<B1>(&right.id, stats.right);
handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);
handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);
}
}
let dtype = B::IntElem::dtype();
let shape = Shape::new([height, width]);
let shape_flat = shape.clone().flatten();
let streams = OperationStreams::with_inputs([&img]);
let out = TensorIr::uninit(client.create_empty_handle(), shape.clone(), dtype);
let area = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
let left = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
let top = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
let right = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
let bottom = TensorIr::uninit(client.create_empty_handle(), shape_flat, dtype);
let max_label = TensorIr::uninit(client.create_empty_handle(), [1].into(), dtype);
let desc = CustomOpIr::new(
"connected_components",
&[img.into_ir()],
&[out, area, left, top, right, bottom, max_label],
);
let [out, area, left, top, right, bottom, max_label] = client
.register(
streams,
OperationIr::Custom(desc.clone()),
ConnCompStats::<B>::new(desc, conn, opts),
)
.try_into()
.unwrap();
let stats = ConnectedStatsPrimitive {
area,
left,
top,
right,
bottom,
max_label,
};
(out, stats)
}
}
impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}
impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}
impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}
impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}
}

View File

@@ -0,0 +1,5 @@
pub(crate) mod cpu;
#[cfg(feature = "cubecl-backend")]
mod cube;
pub use cpu::{KernelShape, create_structuring_element};

View File

@@ -0,0 +1,19 @@
use derive_new::new;
/// 2D size used for vision ops.
#[derive(new, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Size {
/// Width of the element
pub width: usize,
/// Height of the element
pub height: usize,
}
/// 2D Point used for vision ops. Coordinates start at the top left.
#[derive(new, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Point {
/// X (horizontal) coordinate
pub x: usize,
/// Y (vertical) coordinate
pub y: usize,
}

View File

@@ -0,0 +1,31 @@
//! Vision ops for burn, with GPU acceleration where possible.
//!
//! # Operations
//! Operation names are based on `opencv` wherever applicable.
//!
//! Currently implemented are:
//! - `connected_components`
//! - `connected_components_with_stats`
//! - `nms` (Non-Maximum Suppression)
//!
#![warn(missing_docs)]
extern crate alloc;
/// Backend implementations for JIT and CPU
pub mod backends;
mod base;
mod ops;
mod tensor;
mod transform;
pub use base::*;
pub use ops::*;
pub use tensor::*;
pub use transform::*;
/// Module for vision/image utilities
pub mod utils;
pub use backends::{KernelShape, create_structuring_element};

View File

@@ -0,0 +1,340 @@
use crate::{
Point,
backends::cpu::{self, MorphOp, morph},
};
use bon::Builder;
use burn_tensor::{
Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
backend::Backend,
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
};
/// Connected components connectivity
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Connectivity {
/// Four-connected (only connected in cardinal directions)
Four,
/// Eight-connected (connected if any of the surrounding 8 pixels are in the foreground)
Eight,
}
/// Which stats should be enabled for `connected_components_with_stats`.
/// Currently only used by the GPU implementation to save on atomic operations for unneeded stats.
///
/// Disabled stats are aliased to the labels tensor
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectedStatsOptions {
/// Whether to enable bounding boxes
pub bounds_enabled: bool,
/// Whether to enable the max label
pub max_label_enabled: bool,
/// Whether labels must be contiguous starting at 1
pub compact_labels: bool,
}
/// Options for morphology ops
#[derive(Clone, Debug, Builder)]
pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
/// Anchor position within the kernel. Defaults to the center.
pub anchor: Option<Point>,
/// Number of iterations to apply
#[builder(default = 1)]
pub iterations: usize,
/// Border type. Default: constant based on operation
#[builder(default)]
pub border_type: BorderType,
/// Value of each channel for constant border type
pub border_value: Option<Tensor<B, 1, K>>,
}
impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
fn default() -> Self {
Self {
anchor: Default::default(),
iterations: 1,
border_type: Default::default(),
border_value: Default::default(),
}
}
}
/// Morphology border type
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum BorderType {
/// Constant border with per-channel value. If no value is provided, the value is picked based
/// on the morph op.
#[default]
Constant,
/// Replicate first/last element
Replicate,
/// Reflect start/end elements
Reflect,
/// Reflect start/end elements, ignoring the first/last element
Reflect101,
/// Not supported for erode/dilate
Wrap,
}
/// Stats collected by the connected components analysis
///
/// Disabled analyses may be aliased to labels
#[derive(Clone, Debug)]
pub struct ConnectedStats<B: Backend> {
/// Total area of each component
pub area: Tensor<B, 1, Int>,
/// Topmost y coordinate in the component
pub top: Tensor<B, 1, Int>,
/// Leftmost x coordinate in the component
pub left: Tensor<B, 1, Int>,
/// Rightmost x coordinate in the component
pub right: Tensor<B, 1, Int>,
/// Bottommost y coordinate in the component
pub bottom: Tensor<B, 1, Int>,
/// Scalar tensor of the max label
pub max_label: Tensor<B, 1, Int>,
}
/// Primitive version of [`ConnectedStats`], to be returned by the backend
pub struct ConnectedStatsPrimitive<B: Backend> {
/// Total area of each component
pub area: IntTensor<B>,
/// Leftmost x coordinate in the component
pub left: IntTensor<B>,
/// Topmost y coordinate in the component
pub top: IntTensor<B>,
/// Rightmost x coordinate in the component
pub right: IntTensor<B>,
/// Bottommost y coordinate in the component
pub bottom: IntTensor<B>,
/// Scalar tensor of the max label
pub max_label: IntTensor<B>,
}
impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
fn from(value: ConnectedStatsPrimitive<B>) -> Self {
ConnectedStats {
area: Tensor::from_primitive(value.area),
top: Tensor::from_primitive(value.top),
left: Tensor::from_primitive(value.left),
right: Tensor::from_primitive(value.right),
bottom: Tensor::from_primitive(value.bottom),
max_label: Tensor::from_primitive(value.max_label),
}
}
}
impl<B: Backend> ConnectedStats<B> {
/// Convert a connected stats into the corresponding primitive
pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
ConnectedStatsPrimitive {
area: self.area.into_primitive(),
top: self.top.into_primitive(),
left: self.left.into_primitive(),
right: self.right.into_primitive(),
bottom: self.bottom.into_primitive(),
max_label: self.max_label.into_primitive(),
}
}
}
impl Default for ConnectedStatsOptions {
fn default() -> Self {
Self::all()
}
}
impl ConnectedStatsOptions {
/// Don't collect any stats
pub fn none() -> Self {
Self {
bounds_enabled: false,
max_label_enabled: false,
compact_labels: false,
}
}
/// Collect all stats
pub fn all() -> Self {
Self {
bounds_enabled: true,
max_label_enabled: true,
compact_labels: true,
}
}
}
/// Non-Maximum Suppression options.
#[derive(Clone, Copy, Debug)]
pub struct NmsOptions {
/// IoU threshold for suppression (default: 0.5).
/// Boxes with IoU > threshold with a higher-scoring box are suppressed.
pub iou_threshold: f32,
/// Score threshold to filter boxes before NMS (default: 0.0, i.e., no filtering).
/// Boxes with score < score_threshold are discarded.
pub score_threshold: f32,
/// Maximum number of boxes to keep (0 = unlimited).
pub max_output_boxes: usize,
}
impl Default for NmsOptions {
fn default() -> Self {
Self {
iou_threshold: 0.5,
score_threshold: 0.0,
max_output_boxes: 0,
}
}
}
/// Vision capable backend, implemented by each backend
pub trait VisionBackend:
BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
{
}
/// Vision ops on bool tensors
pub trait BoolVisionOps: Backend {
/// Computes the connected components labeled image of boolean image with 4 or 8 way
/// connectivity - returns a tensor of the component label of each pixel.
///
/// `img`- The boolean image tensor in the format [batches, height, width]
fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
cpu::connected_components::<Self>(img, connectivity)
}
/// Computes the connected components labeled image of boolean image with 4 or 8 way
/// connectivity and collects statistics on each component - returns a tensor of the component
/// label of each pixel, along with stats collected for each component.
///
/// `img`- The boolean image tensor in the format [batches, height, width]
fn connected_components_with_stats(
img: BoolTensor<Self>,
connectivity: Connectivity,
opts: ConnectedStatsOptions,
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
cpu::connected_components_with_stats(img, connectivity, opts)
}
/// Erodes an input tensor with the specified kernel.
fn bool_erode(
input: BoolTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Bool>,
) -> BoolTensor<Self> {
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
}
/// Dilates an input tensor with the specified kernel.
fn bool_dilate(
input: BoolTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Bool>,
) -> BoolTensor<Self> {
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
}
}
/// Vision ops on int tensors
pub trait IntVisionOps: Backend {
/// Erodes an input tensor with the specified kernel.
fn int_erode(
input: IntTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Int>,
) -> IntTensor<Self> {
let input = Tensor::<Self, 3, Int>::from_primitive(input);
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
}
/// Dilates an input tensor with the specified kernel.
fn int_dilate(
input: IntTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Int>,
) -> IntTensor<Self> {
let input = Tensor::<Self, 3, Int>::from_primitive(input);
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
}
}
/// Vision ops on float tensors
pub trait FloatVisionOps: Backend {
/// Erodes an input tensor with the specified kernel.
fn float_erode(
input: FloatTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> FloatTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
morph(input, kernel, MorphOp::Erode, opts)
.into_primitive()
.tensor()
}
/// Dilates an input tensor with the specified kernel.
fn float_dilate(
input: FloatTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> FloatTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
morph(input, kernel, MorphOp::Dilate, opts)
.into_primitive()
.tensor()
}
/// Perform Non-Maximum Suppression on bounding boxes.
///
/// Returns indices of kept boxes after suppressing overlapping detections.
/// Boxes are processed in descending score order; a box suppresses all
/// lower-scoring boxes with IoU > threshold.
///
/// # Arguments
/// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
/// * `scores` - Confidence scores as \[N\] tensor
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
///
/// # Returns
/// Indices of kept boxes as \[M\] tensor where M <= N
fn nms(
boxes: FloatTensor<Self>,
scores: FloatTensor<Self>,
options: NmsOptions,
) -> IntTensor<Self> {
let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
cpu::nms::<Self>(boxes, scores, options).into_primitive()
}
}
/// Vision ops on quantized float tensors
pub trait QVisionOps: Backend {
/// Erodes an input tensor with the specified kernel.
fn q_erode(
input: QuantizedTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> QuantizedTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
TensorPrimitive::QFloat(tensor) => tensor,
_ => unreachable!(),
}
}
/// Dilates an input tensor with the specified kernel.
fn q_dilate(
input: QuantizedTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> QuantizedTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
TensorPrimitive::QFloat(tensor) => tensor,
_ => unreachable!(),
}
}
}

View File

@@ -0,0 +1,3 @@
mod base;
pub use base::*;

View File

@@ -0,0 +1,186 @@
use burn_tensor::{
BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
ops::BoolTensor,
};
use crate::{
BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
VisionBackend,
};
/// Connected components tensor extensions
pub trait ConnectedComponents<B: Backend> {
/// Computes the connected components labeled image of boolean image with 4 or 8 way
/// connectivity - returns a tensor of the component label of each pixel.
///
/// `img`- The boolean image tensor in the format [batches, height, width]
fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;
/// Computes the connected components labeled image of boolean image with 4 or 8 way
/// connectivity and collects statistics on each component - returns a tensor of the component
/// label of each pixel, along with stats collected for each component.
///
/// `img`- The boolean image tensor in the format [batches, height, width]
fn connected_components_with_stats(
self,
connectivity: Connectivity,
options: ConnectedStatsOptions,
) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
}
/// Morphology tensor operations
pub trait Morphology<B: Backend, K: TensorKind<B>> {
/// Erodes this tensor using the specified kernel.
/// Assumes NHWC layout.
fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
/// Dilates this tensor using the specified kernel.
/// Assumes NHWC layout.
fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
}
/// Morphology tensor operations
pub trait MorphologyKind<B: Backend>: BasicOps<B> {
/// Erodes this tensor using the specified kernel
fn erode(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive;
/// Dilates this tensor using the specified kernel
fn dilate(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive;
}
/// Non-maximum suppression tensor operations
pub trait Nms<B: Backend> {
/// Perform Non-Maximum Suppression on this tensor of bounding boxes.
///
/// Returns indices of kept boxes after suppressing overlapping detections.
/// Boxes are processed in descending score order; a box suppresses all
/// lower-scoring boxes with IoU > threshold.
///
/// # Arguments
/// * `self` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
/// * `scores` - Confidence scores as \[N\] tensor
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
///
/// # Returns
/// Indices of kept boxes as \[M\] tensor where M <= N
fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;
}
impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
}
fn connected_components_with_stats(
self,
connectivity: Connectivity,
options: ConnectedStatsOptions,
) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
let (labels, stats) =
B::connected_components_with_stats(self.into_primitive(), connectivity, options);
(Tensor::from_primitive(labels), stats.into())
}
}
impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
Tensor::new(K::erode(
self.into_primitive(),
kernel.into_primitive(),
opts,
))
}
fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
Tensor::new(K::dilate(
self.into_primitive(),
kernel.into_primitive(),
opts,
))
}
}
impl<B: VisionBackend> MorphologyKind<B> for Float {
fn erode(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
}
TensorPrimitive::QFloat(tensor) => {
TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
}
}
}
fn dilate(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive {
match tensor {
TensorPrimitive::Float(tensor) => {
TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
}
TensorPrimitive::QFloat(tensor) => {
TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
}
}
}
}
impl<B: VisionBackend> MorphologyKind<B> for Int {
fn erode(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive {
B::int_erode(tensor, kernel, opts)
}
fn dilate(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive {
B::int_dilate(tensor, kernel, opts)
}
}
impl<B: VisionBackend> MorphologyKind<B> for Bool {
fn erode(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive {
B::bool_erode(tensor, kernel, opts)
}
fn dilate(
tensor: Self::Primitive,
kernel: BoolTensor<B>,
opts: MorphOptions<B, Self>,
) -> Self::Primitive {
B::bool_dilate(tensor, kernel, opts)
}
}
impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
match (self.into_primitive(), scores.into_primitive()) {
(TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
}
_ => todo!("Quantized inputs are not yet supported"),
}
}
}

View File

@@ -0,0 +1,27 @@
use std::path::PathBuf;
use burn_tensor::{Shape, Tensor, TensorData, backend::Backend};
use image::{DynamicImage, ImageBuffer, Luma, Rgb};
mod connected_components;
mod morphology;
#[macro_export]
macro_rules! testgen_all {
() => {
use burn_tensor::{Bool, Float, Int};
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;
pub mod vision {
pub use super::*;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
burn_vision::testgen_connected_components!();
burn_vision::testgen_morphology!();
}
};
}

View File

@@ -0,0 +1,3 @@
mod transform2d;
pub use transform2d::*;

View File

@@ -0,0 +1,229 @@
use burn_tensor::{
Tensor,
backend::Backend,
grid::affine_grid_2d,
ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
};
/// 2D point transformation
///
/// Useful for resampling: rotating, scaling, translating, etc image tensors
pub struct Transform2D {
// 2x3 transformation matrix, to be used with column vectors:
// T(x) = Ax
transform: [[f32; 3]; 2],
}
impl Transform2D {
/// Transforms an image
///
/// * `img` - Images tensor with shape (batch_size, channels, height, width)
///
/// # Returns
///
/// A tensor with the same as the input
pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {
let [batch_size, channels, height, width] = img.shape().dims();
let transform = Tensor::<B, 2>::from(self.transform);
let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);
let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);
let options = GridSampleOptions::new(InterpolateMode::Bilinear)
.with_padding_mode(GridSamplePaddingMode::Border)
.with_align_corners(true);
img.grid_sample_2d(grid, options)
}
/// Makes a 2d transformation composed of other transformations
pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {
let mut result = Self::identity();
for t in transforms.into_iter() {
result = result.mul(t);
}
result
}
/// Multiply two affine transforms represented as 2x3 matrices
fn mul(self, other: Transform2D) -> Transform2D {
let mut result = [[0.0f32; 3]; 2];
// Row 0
result[0][0] = self.transform[0][0] * other.transform[0][0]
+ self.transform[0][1] * other.transform[1][0];
result[0][1] = self.transform[0][0] * other.transform[0][1]
+ self.transform[0][1] * other.transform[1][1];
result[0][2] = self.transform[0][0] * other.transform[0][2]
+ self.transform[0][1] * other.transform[1][2]
+ self.transform[0][2];
// Row 1
result[1][0] = self.transform[1][0] * other.transform[0][0]
+ self.transform[1][1] * other.transform[1][0];
result[1][1] = self.transform[1][0] * other.transform[0][1]
+ self.transform[1][1] * other.transform[1][1];
result[1][2] = self.transform[1][0] * other.transform[0][2]
+ self.transform[1][1] * other.transform[1][2]
+ self.transform[1][2];
Transform2D { transform: result }
}
/// Makes an identity transform (x = Ax)
pub fn identity() -> Self {
Self {
transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
}
}
/// Makes a [`Transform2D`] for rotating a tensor
///
/// * `theta` - In radians, the rotation
/// * `cx` - Center of rotation, x
/// * `cy` - Center of rotation, y
pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {
let cos_theta = theta.cos();
let sin_theta = theta.sin();
let transform = [
[cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],
[sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],
];
Self { transform }
}
/// Makes a [`Transform2D`] for scaling an image tensor
///
/// * `sx` - Scale factor in the x direction
/// * `sy` - Scale factor in the y direction
/// * `cx` - Center of scaling, x
/// * `cy` - Center of scaling, y
pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {
let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];
Self { transform }
}
/// Makes a [`Transform2D`] for translating an image tensor
///
/// * `tx` - Translation in the x direction
/// * `ty` - Translation in the y direction
pub fn translation(tx: f32, ty: f32) -> Self {
let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];
Self { transform }
}
/// Applies a general shear transformation around the image center,
/// combining both X and Y shear.
///
/// # Arguments
/// * `shx` - Shear factor along the X-axis.
/// * `shy` - Shear factor along the Y-axis.
/// * `cx`, `cy` - Coordinates of the image center.
///
/// # Returns
/// * `Self` with a combined shear transform matrix.
pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {
let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];
Self { transform }
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn_ndarray::NdArray;
use burn_tensor::Tolerance;
type B = NdArray;
#[test]
fn transform_identity_translation() {
let t = Transform2D::translation(0.0, 0.0);
let image_original = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);
let image_transformed = t.transform(image_original.clone());
image_original
.to_data()
.assert_approx_eq(&image_transformed.to_data(), Tolerance::<f32>::balanced());
}
#[test]
fn transform_translation() {
let t = Transform2D::translation(1., 1.);
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
// This result would change if the padding method is different
let image_expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);
let image = t.transform(image);
image_expected
.to_data()
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
}
#[test]
fn transform_rotation_90_degrees() {
let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
let image_expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);
let image = t.transform(image);
image_expected
.to_data()
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
}
#[test]
fn transform_rotation_around_corner() {
let cx = 1.;
let cy = -1.;
let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, cx, cy);
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
// This result would change if the padding method is different
let image_expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);
let image = t.transform(image);
image_expected
.to_data()
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
}
#[test]
fn transform_scale() {
let cx = 0.0;
let cy = 0.0;
let t = Transform2D::scale(0.5, 0.5, cx, cy);
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
let image_expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);
let image = t.transform(image);
image_expected
.to_data()
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
}
#[test]
fn transform_scale_around_corner() {
let cx = 1.;
let cy = -1.;
let t = Transform2D::scale(0.5, 0.5, cx, cy);
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
let image_expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);
let image = t.transform(image);
image_expected
.to_data()
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
}
#[test]
fn transform_combined() {
let t1 = Transform2D::translation(0.2, -0.5);
let t2 = Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.);
let t = Transform2D::composed([t1, t2]);
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
// This result would change if the padding method is different
let image_expected =
Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);
let image = t.transform(image);
image_expected
.to_data()
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
}
}

View File

@@ -0,0 +1,3 @@
mod save;
pub use save::*;

View File

@@ -0,0 +1,191 @@
//! Utilities for saving tensors as images
use burn_tensor::{ElementConversion, Tensor, backend::Backend};
use image::{Rgb, RgbImage};
use std::fs;
use std::path::Path;
/// How to save a tensor as an image
pub struct TensorDisplayOptions {
/// How should the dimensions be interpreted
pub dim_order: ImageDimOrder,
/// What colors should be used
pub color_opts: ColorDisplayOpts,
/// How to handle batches
pub batch_opts: Option<BatchDisplayOpts>,
/// Output image width
pub width_out: usize,
/// Output image height
pub height_out: usize,
}
/// How to interpret dimensions for image tensors
pub enum ImageDimOrder {
/// dims: (height, width)
Hw,
/// dims: (channels, height, width)
Chw,
/// dims: (height, width, channels)
Hwc,
/// dims: (batch_size, height, width)
Nhw,
/// dims: (batch_size, channels, height, width)
Nchw,
/// dims: (batch_size, height, width, channels)
Nhwc,
}
/// How to translate tensor values to colors
pub enum ColorDisplayOpts {
/// The values in each channel are respectively assigned to an RGB channel
Rgb,
/// The channel value is mapped between two colors
Monochrome {
/// Color assigned to the minimum value
min: [f32; 3],
/// Color assigned to the maximum value
max: [f32; 3],
},
}
/// How to handle multi-batch tensors
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum BatchDisplayOpts {
/// Each item is placed consecutively in the image
Tiled,
/// Each item is aggregated
Aggregated,
}
/// Save a tensor of a batch of images as an image
///
/// * `tensor` - Image batch with shape (N, height, width)
/// * `opts` - Options for how to draw the tensor
/// * `path` - The file path to use
pub fn save_tensor_as_image<B: Backend, const D: usize, P: AsRef<std::ffi::OsStr>>(
tensor: Tensor<B, D>,
opts: TensorDisplayOptions,
path: P,
) -> Result<(), Box<dyn std::error::Error>> {
// Output file
let path = Path::new(&path);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let tensor = normalize(tensor);
// convert to (N,C,H,W) format
let tensor: Tensor<B, 4> = match opts.dim_order {
ImageDimOrder::Hw => {
let [h, w] = tensor.shape().dims();
tensor.reshape([1, 1, h, w])
}
ImageDimOrder::Chw => {
let [c, h, w] = tensor.shape().dims();
tensor.reshape([1, c, h, w])
}
ImageDimOrder::Hwc => {
let [h, w, c] = tensor.shape().dims();
tensor.swap_dims(0, 2).swap_dims(1, 2).reshape([1, c, h, w])
}
ImageDimOrder::Nhw => {
let [n, h, w] = tensor.shape().dims();
tensor.reshape([n, 1, h, w])
}
ImageDimOrder::Nchw => tensor.reshape([0, 0, 0, 0]),
ImageDimOrder::Nhwc => tensor.swap_dims(1, 3).swap_dims(2, 3).reshape([0, 0, 0, 0]),
};
let data = tensor.to_data();
let shape = data.shape.clone();
let (batch, channels, src_height, src_width) = (shape[0], shape[1], shape[2], shape[3]);
let mut img = if let Some(batch_opts) = &opts.batch_opts
&& BatchDisplayOpts::Tiled == *batch_opts
{
RgbImage::new(opts.width_out as u32, (opts.height_out * batch) as u32)
} else {
RgbImage::new(opts.width_out as u32, opts.height_out as u32)
};
let data_vec = data.to_vec::<f32>().unwrap();
let mut channel_vals = vec![0 as f32; channels]; // value for each channel in a given pixel
for n in 0..batch {
for x in 0..opts.width_out {
for y in 0..opts.height_out {
let i = ((x as f32) / (opts.width_out as f32) * (src_width as f32))
.floor()
.clamp(0.0, src_width as f32) as usize;
let j = ((y as f32) / (opts.height_out as f32) * (src_height as f32))
.floor()
.clamp(0.0, src_height as f32) as usize;
for c in 0..channels {
channel_vals[c] =
data_vec[i + (j + (n * channels + c) * src_height) * src_width];
}
let (x, y) = if let Some(batch_opts) = opts.batch_opts
&& BatchDisplayOpts::Tiled == batch_opts
{
let batch_x = 0;
let batch_y = n as u32 * opts.height_out as u32;
(x as u32 + batch_x, y as u32 + batch_y)
} else {
(x as u32, y as u32)
};
let mut pixel = [0 as f32; 3];
match opts.color_opts {
ColorDisplayOpts::Rgb => match channels {
1 => {
pixel[0] = channel_vals[0];
pixel[1] = 0.0;
pixel[2] = 0.0;
}
2 => {
pixel[0] = channel_vals[0];
pixel[1] = channel_vals[1];
pixel[2] = 0.0;
}
3 => {
pixel[0] = channel_vals[0];
pixel[1] = channel_vals[1];
pixel[2] = channel_vals[2];
}
_ => unimplemented!("More than 3 channels not supported ({channels})"),
},
ColorDisplayOpts::Monochrome { min, max } => {
let val: f32 = channel_vals.iter().sum();
pixel[0] = min[0] * (1.0 - val) + max[0] * val;
pixel[1] = min[1] * (1.0 - val) + max[1] * val;
pixel[2] = min[2] * (1.0 - val) + max[2] * val;
}
}
let pixel = [
(pixel[0] * 255.0) as u8,
(pixel[1] * 255.0) as u8,
(pixel[2] * 255.0) as u8,
];
img.put_pixel(x, y, Rgb(pixel));
}
}
}
img.save(path)?;
Ok(())
}
/// Normalize values in 2D tensor from 0 to 1
fn normalize<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
let min = tensor.clone().min().into_scalar().elem::<f32>();
let max = tensor.clone().max().into_scalar().elem::<f32>();
let range = if max - min == 0.0 { 1.0 } else { max - min };
tensor
.sub_scalar(min.elem::<f32>())
.div_scalar(range.elem::<f32>())
}

View File

@@ -0,0 +1,82 @@
use std::path::PathBuf;
use burn_tensor::{Shape, Tensor, TensorData, backend::Backend};
use image::{DynamicImage, ImageBuffer, Luma, Rgb};
use burn_tensor::{Bool, Int};
#[cfg(all(
any(feature = "test-cpu", feature = "ndarray"),
not(any(feature = "test-wgpu", feature = "test-cuda"))
))]
pub type TestBackend = burn_ndarray::NdArray<f32, i32>;
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;
#[allow(unused)]
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
#[allow(unused)]
pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;
#[allow(unused)]
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
#[allow(missing_docs)]
#[macro_export]
macro_rules! as_type {
($ty:ident: [$($elem:tt),*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: [$($elem:tt,)*]) => {
[$($crate::as_type![$ty: $elem]),*]
};
($ty:ident: $elem:expr) => {
{
use cubecl::prelude::*;
$ty::new($elem)
}
};
}
#[allow(unused)]
pub fn test_image<B: Backend>(name: &str, device: &B::Device, luma: bool) -> Tensor<B, 3> {
let file = PathBuf::from("tests/images").join(name);
let image = image::open(file).unwrap();
if luma {
let image = image.to_luma32f();
let h = image.height() as usize;
let w = image.width() as usize;
let data = TensorData::new(image.into_vec(), Shape::new([h, w, 1]));
Tensor::from_data(data, device)
} else {
let image = image.to_rgb32f();
let h = image.height() as usize;
let w = image.width() as usize;
let data = TensorData::new(image.into_vec(), Shape::new([h, w, 3]));
Tensor::from_data(data, device)
}
}
#[allow(unused)]
pub fn save_test_image<B: Backend>(name: &str, tensor: Tensor<B, 3>, luma: bool) {
let file = PathBuf::from("tests/images").join(name);
let [h, w, _] = tensor.shape().dims();
let data = tensor
.into_data()
.convert::<f32>()
.into_vec::<f32>()
.unwrap();
if luma {
let image = ImageBuffer::<Luma<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();
DynamicImage::from(image).to_luma8().save(file).unwrap();
} else {
let image = ImageBuffer::<Rgb<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();
DynamicImage::from(image).to_rgb8().save(file).unwrap();
}
}

View File

@@ -0,0 +1,144 @@
#![allow(clippy::single_range_in_vec_init)]
use std::collections::HashMap;
use burn_tensor::TensorData;
use burn_vision::{ConnectedComponents, ConnectedStatsOptions, Connectivity};
mod common;
use common::*;
fn space_invader() -> [[IntType; 14]; 9] {
as_type!(IntType: [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1],
[1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
])
}
#[test]
fn should_support_8_connectivity() {
let tensor = TestTensorBool::<2>::from(space_invader());
let output = tensor.connected_components(Connectivity::Eight);
let expected = space_invader(); // All pixels are in the same group for 8-connected
let expected = TestTensorInt::<2>::from(expected);
normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);
}
#[test]
fn should_support_8_connectivity_with_stats() {
let tensor = TestTensorBool::<2>::from(space_invader());
let (output, stats) =
tensor.connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all());
let expected = space_invader(); // All pixels are in the same group for 8-connected
let expected = TestTensorInt::<2>::from(expected);
let (area, left, top, right, bottom) = (
stats.area.slice([1..2]).into_data(),
stats.left.slice([1..2]).into_data(),
stats.top.slice([1..2]).into_data(),
stats.right.slice([1..2]).into_data(),
stats.bottom.slice([1..2]).into_data(),
);
output.into_data().assert_eq(&expected.into_data(), false);
area.assert_eq(&TensorData::from([58]), false);
left.assert_eq(&TensorData::from([0]), false);
top.assert_eq(&TensorData::from([0]), false);
right.assert_eq(&TensorData::from([13]), false);
bottom.assert_eq(&TensorData::from([8]), false);
stats
.max_label
.into_data()
.assert_eq(&TensorData::from([1]), false);
}
#[test]
fn should_support_4_connectivity() {
let tensor = TestTensorBool::<2>::from(space_invader());
let output = tensor.connected_components(Connectivity::Four);
let expected = as_type!(IntType: [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],
[0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],
[0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],
[0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],
[4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],
[4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],
[4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],
[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
]);
let expected = TestTensorInt::<2>::from(expected);
normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);
}
#[test]
fn should_support_4_connectivity_with_stats() {
let tensor = TestTensorBool::<2>::from(space_invader());
let (output, stats) =
tensor.connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all());
let expected = as_type!(IntType: [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],
[0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],
[0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],
[0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],
[4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],
[4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],
[4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],
[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
]);
let expected = TestTensorInt::<2>::from(expected);
// Slice off background and limit to compacted labels
let (area, left, top, right, bottom) = (
stats.area.slice([1..6]).into_data(),
stats.left.slice([1..6]).into_data(),
stats.top.slice([1..6]).into_data(),
stats.right.slice([1..6]).into_data(),
stats.bottom.slice([1..6]).into_data(),
);
output.into_data().assert_eq(&expected.into_data(), false);
area.assert_eq(&TensorData::from([1, 1, 46, 5, 5]), false);
left.assert_eq(&TensorData::from([3, 10, 1, 0, 12]), false);
top.assert_eq(&TensorData::from([0, 0, 1, 5, 5]), false);
right.assert_eq(&TensorData::from([3, 10, 12, 1, 13]), false);
bottom.assert_eq(&TensorData::from([0, 0, 8, 7, 7]), false);
stats
.max_label
.into_data()
.assert_eq(&TensorData::from([5]), false);
}
/// Normalize labels to sequential since actual labels aren't required to be contiguous and
/// different algorithms can return different numbers even if correct
fn normalize_labels(mut labels: TensorData) -> TensorData {
let mut next_label = 0;
let mut mappings = HashMap::<i32, i32>::default();
let data = labels.as_mut_slice::<i32>().unwrap();
for label in data {
if *label != 0 {
let relabel = mappings.entry(*label).or_insert_with(|| {
next_label += 1;
next_label
});
*label = *relabel;
}
}
labels
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 396 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 370 B

View File

@@ -0,0 +1,638 @@
use burn_tensor::{Tolerance, ops::FloatElem};
use burn_vision::{
BorderType, KernelShape, MorphOptions, Morphology, Point, Size, create_structuring_element,
};
type FT = FloatElem<TestBackend>;
mod common;
use common::*;
#[test]
fn should_support_dilate_luma() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Rect.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Cross.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_ellipse() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Ellipse,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Ellipse.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_non_square_rect() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_3x5_Rect.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_luma_non_square_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_3x5_Cross.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_2_3x5_Rect.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(3, 5),
None,
&Default::default(),
);
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_2_3x5_Cross.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect101_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect101)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT101.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_reflect101_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Reflect101)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT101.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_replicate_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Replicate)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Rect_BORDER_REPLICATE.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_border_replicate_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(7, 7),
None,
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.border_type(BorderType::Replicate)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_7x7_Cross_BORDER_REPLICATE.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_anchor_rect() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 7),
Some(Point::new(1, 2)),
&Default::default(),
);
let output = tensor.dilate(
kernel,
MorphOptions::builder().anchor(Point::new(2, 1)).build(),
);
let expected = test_image(
"morphology/Dilate_2_5x7_Rect_ANCHOR.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_rgb_anchor_cross() {
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 7),
Some(Point::new(1, 2)),
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(
kernel,
MorphOptions::builder()
.anchor(Point::new(2, 1))
.border_type(BorderType::Replicate)
.build(),
);
let expected = test_image(
"morphology/Dilate_2_5x7_Cross_ANCHOR_BORDER_REPLICATE.png",
&Default::default(),
false,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_boolean_rect() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true).greater_elem(0);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Rect.png",
&Default::default(),
true,
)
.greater_elem(0);
let expected = TestTensorBool::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_boolean_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true).greater_elem(0);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Dilate_1_5x5_Cross.png",
&Default::default(),
true,
)
.greater_elem(0);
let expected = TestTensorBool::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_int_rect() {
let tensor = (test_image("morphology/Base_1.png", &Default::default(), true) * 255.0).int();
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = (test_image(
"morphology/Dilate_1_5x5_Rect.png",
&Default::default(),
true,
) * 255.0)
.int();
let expected = TestTensorInt::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_dilate_int_cross() {
let tensor = (test_image("morphology/Base_1.png", &Default::default(), true) * 255.0).int();
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
// With default border, bottom left pixel is undefined with this particular kernel and anchor
// Use replicate instead for comparability
let output = tensor.dilate(kernel, MorphOptions::default());
let expected = (test_image(
"morphology/Dilate_1_5x5_Cross.png",
&Default::default(),
true,
) * 255.0)
.int();
let expected = TestTensorInt::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_erode_luma() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = TestTensorBool::<2>::from([
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
]);
let output = tensor.erode(kernel, MorphOptions::default());
let expected = test_image("morphology/Erode_1_5x5_Rect.png", &Default::default(), true);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_erode_luma_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.erode(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Erode_1_5x5_Cross.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn should_support_erode_luma_ellipse() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Ellipse,
Size::new(5, 5),
None,
&Default::default(),
);
let output = tensor.erode(kernel, MorphOptions::default());
let expected = test_image(
"morphology/Erode_1_5x5_Ellipse.png",
&Default::default(),
true,
);
let expected = TestTensor::<3>::from(expected);
output
.into_data()
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
}
#[test]
fn create_structuring_element_should_match_manual_rect() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Rect,
Size::new(5, 5),
None,
&Default::default(),
);
let kernel_manual = TestTensorBool::<2>::from([
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
]);
let output = tensor.clone().dilate(kernel, MorphOptions::default());
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
output
.into_data()
.assert_eq(&output_manual.into_data(), false);
}
#[test]
fn create_structuring_element_should_match_manual_cross() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Cross,
Size::new(5, 5),
None,
&Default::default(),
);
let kernel_manual = TestTensorBool::<2>::from([
[false, false, true, false, false],
[false, false, true, false, false],
[true, true, true, true, true],
[false, false, true, false, false],
[false, false, true, false, false],
]);
let output = tensor.clone().dilate(kernel, MorphOptions::default());
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
output
.into_data()
.assert_eq(&output_manual.into_data(), false);
}
#[test]
fn create_structuring_element_should_match_manual_ellipse() {
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
let kernel = create_structuring_element::<TestBackend>(
KernelShape::Ellipse,
Size::new(5, 5),
None,
&Default::default(),
);
let kernel_manual = TestTensorBool::<2>::from([
[false, false, true, false, false],
[true, true, true, true, true],
[true, true, true, true, true],
[true, true, true, true, true],
[false, false, true, false, false],
]);
let output = tensor.clone().dilate(kernel, MorphOptions::default());
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
output
.into_data()
.assert_eq(&output_manual.into_data(), false);
}

View File

@@ -0,0 +1,92 @@
use burn_vision::{Nms, NmsOptions};
mod common;
use common::*;
#[test]
fn should_suppress_non_maximum() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.5,
score_threshold: 0.0,
max_output_boxes: 0,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4, 2, 1]);
output.into_data().assert_eq(&expected.into_data(), true);
}
#[test]
fn should_apply_score_threshold() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.5,
score_threshold: 0.3,
max_output_boxes: 0,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4, 2]);
output.into_data().assert_eq(&expected.into_data(), true);
}
#[test]
fn should_apply_iou_threshold() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.1,
score_threshold: 0.0,
max_output_boxes: 0,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4, 1]);
output.into_data().assert_eq(&expected.into_data(), true);
}
#[test]
fn should_apply_max_output_boxes() {
let boxes = TestTensor::<2>::from([
[0, 0, 100, 100],
[0, 1, 100, 100],
[0, 101, 200, 200],
[0, 100, 200, 200],
[0, 170, 300, 300],
]);
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
let options = NmsOptions {
iou_threshold: 0.5,
score_threshold: 0.0,
max_output_boxes: 1,
};
let output = boxes.nms(scores, options);
let expected = TestTensorInt::<1>::from([4]);
output.into_data().assert_eq(&expected.into_data(), true);
}