feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
@@ -0,0 +1,71 @@
|
||||
[package]
|
||||
authors = [
|
||||
"nathanielsimard <nathaniel.simard.42@gmail.com>",
|
||||
"wingertge <wingertge@gmail.com>",
|
||||
]
|
||||
categories = ["science"]
|
||||
description = "Vision processing operations for burn tensors"
|
||||
documentation = "https://docs.rs/burn-vision"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "gpu"]
|
||||
license.workspace = true
|
||||
name = "burn-vision"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-vision"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
|
||||
[features]
|
||||
default = ["ndarray", "cubecl-backend", "fusion", "std"]
|
||||
std = ["aligned-vec/std"]
|
||||
tracing = [
|
||||
"burn-cubecl?/tracing",
|
||||
"burn-fusion?/tracing",
|
||||
"burn-ir/tracing",
|
||||
"burn-ndarray?/tracing",
|
||||
"burn-tch?/tracing",
|
||||
"burn-tensor/tracing",
|
||||
"cubecl/tracing",
|
||||
]
|
||||
|
||||
cubecl-backend = ["cubecl", "burn-cubecl"]
|
||||
fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"]
|
||||
ndarray = ["burn-ndarray"]
|
||||
tch = ["burn-tch"]
|
||||
|
||||
# Test features
|
||||
test-cpu = []
|
||||
test-cuda = ["cubecl-backend", ]
|
||||
test-wgpu = ["cubecl-backend", ]
|
||||
test-vulkan = ["burn-wgpu/vulkan", "test-wgpu"]
|
||||
test-metal = ["burn-wgpu/metal", "test-wgpu"]
|
||||
|
||||
[dependencies]
|
||||
aligned-vec = { version = "0.6", default-features = false }
|
||||
bon = { workspace = true }
|
||||
burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2" }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2" }
|
||||
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2", optional = true }
|
||||
bytemuck = { workspace = true }
|
||||
cubecl = { workspace = true, optional = true }
|
||||
derive-new = { workspace = true }
|
||||
half = { workspace = true }
|
||||
image = { version = "0.25" }
|
||||
macerator = { workspace = true }
|
||||
ndarray = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
paste = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", default-features = false }
|
||||
cubecl = { workspace = true }
|
||||
@@ -0,0 +1,42 @@
|
||||
pub trait MinMax {
|
||||
fn min(self, other: Self) -> Self;
|
||||
fn max(self, other: Self) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_minmax {
|
||||
($ty: ty) => {
|
||||
impl MinMax for $ty {
|
||||
fn min(self, other: Self) -> Self {
|
||||
Ord::min(self, other)
|
||||
}
|
||||
fn max(self, other: Self) -> Self {
|
||||
Ord::max(self, other)
|
||||
}
|
||||
}
|
||||
};
|
||||
($($ty: ty),*) => {
|
||||
$(impl_minmax!($ty);)*
|
||||
}
|
||||
}
|
||||
|
||||
impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64);
|
||||
|
||||
impl MinMax for f32 {
|
||||
fn min(self, other: Self) -> Self {
|
||||
self.min(other)
|
||||
}
|
||||
|
||||
fn max(self, other: Self) -> Self {
|
||||
self.max(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl MinMax for f64 {
|
||||
fn min(self, other: Self) -> Self {
|
||||
self.min(other)
|
||||
}
|
||||
|
||||
fn max(self, other: Self) -> Self {
|
||||
self.max(other)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
use std::{cmp::Ordering, marker::PhantomData};
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{
|
||||
Bool, DType, Element, ElementConversion, ElementOrdered, Int, Shape, Tensor, TensorData,
|
||||
backend::Backend,
|
||||
ops::{BoolTensor, IntTensor},
|
||||
};
|
||||
use ndarray::Array2;
|
||||
|
||||
use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity};
|
||||
|
||||
mod spaghetti;
|
||||
mod spaghetti_4c;
|
||||
|
||||
/// Dispatches connected components based on `B::IntElem::dtype()`, binding a concrete
|
||||
/// integer type to enable generic instantiations without extra trait bounds (after removing
|
||||
/// `ElementComparison` from `Element`).
|
||||
macro_rules! dispatch_int_dtype {
|
||||
(|$ty:ident| $body:expr) => {
|
||||
match B::IntElem::dtype() {
|
||||
DType::I64 => {
|
||||
type $ty = i64;
|
||||
$body
|
||||
}
|
||||
DType::I32 => {
|
||||
type $ty = i32;
|
||||
$body
|
||||
}
|
||||
DType::I16 => {
|
||||
type $ty = i16;
|
||||
$body
|
||||
}
|
||||
DType::I8 => {
|
||||
type $ty = i8;
|
||||
$body
|
||||
}
|
||||
DType::U64 => {
|
||||
type $ty = u64;
|
||||
$body
|
||||
}
|
||||
DType::U32 => {
|
||||
type $ty = u32;
|
||||
$body
|
||||
}
|
||||
DType::U16 => {
|
||||
type $ty = u16;
|
||||
$body
|
||||
}
|
||||
DType::U8 => {
|
||||
type $ty = u8;
|
||||
$body
|
||||
}
|
||||
_ => unreachable!("Unsupported dtype"),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn connected_components<B: Backend>(
|
||||
img: BoolTensor<B>,
|
||||
connectivity: Connectivity,
|
||||
) -> IntTensor<B> {
|
||||
dispatch_int_dtype!(|I| run::<B, I, NoOp<_>>(img, connectivity, NoOp::default).0)
|
||||
}
|
||||
|
||||
pub fn connected_components_with_stats<B: Backend>(
|
||||
img: BoolTensor<B>,
|
||||
connectivity: Connectivity,
|
||||
_options: ConnectedStatsOptions,
|
||||
) -> (IntTensor<B>, ConnectedStatsPrimitive<B>) {
|
||||
let device = B::bool_device(&img);
|
||||
|
||||
dispatch_int_dtype!(|I| {
|
||||
let (labels, stats) =
|
||||
run::<B, I, ConnectedStatsOp<I>>(img, connectivity, ConnectedStatsOp::default);
|
||||
let stats = finalize_stats(&device, stats);
|
||||
(labels, stats)
|
||||
})
|
||||
}
|
||||
|
||||
fn run<B: Backend, I: ElementOrdered, Stats: StatsOp<Label = I>>(
|
||||
img: BoolTensor<B>,
|
||||
connectivity: Connectivity,
|
||||
stats: impl Fn() -> Stats,
|
||||
) -> (IntTensor<B>, Stats) {
|
||||
let device = B::bool_device(&img);
|
||||
let img = Tensor::<B, 2, Bool>::from_primitive(img);
|
||||
let [height, width] = img.shape().dims();
|
||||
let img = img.into_data();
|
||||
let img = img.into_vec::<B::BoolElem>().unwrap();
|
||||
|
||||
let mut stats = stats();
|
||||
|
||||
let out = match connectivity {
|
||||
Connectivity::Four => {
|
||||
spaghetti_4c::process::<B::BoolElem, UnionFind<_>>(img, height, width, &mut stats)
|
||||
}
|
||||
Connectivity::Eight => {
|
||||
// SAFETY: This is validated by `TensorData`
|
||||
let img = unsafe { Array2::from_shape_vec_unchecked((height, width), img) };
|
||||
spaghetti::process::<B::BoolElem, UnionFind<_>>(img, &mut stats)
|
||||
}
|
||||
};
|
||||
|
||||
let (data, _) = out.into_raw_vec_and_offset();
|
||||
let data = TensorData::new(data, Shape::new([height, width]));
|
||||
let labels = Tensor::<B, 2, Int>::from_data(data, &device).into_primitive();
|
||||
(labels, stats)
|
||||
}
|
||||
|
||||
pub trait Solver {
|
||||
type Label: ElementOrdered;
|
||||
|
||||
fn init(max_labels: usize) -> Self;
|
||||
/// Hack to get around mutable borrow limitations on methods
|
||||
fn merge(label_1: Self::Label, label_2: Self::Label, solver: &mut Self) -> Self::Label;
|
||||
fn new_label(&mut self) -> Self::Label;
|
||||
fn flatten(&mut self) -> Self::Label;
|
||||
fn get_label(&self, i_label: Self::Label) -> Self::Label;
|
||||
}
|
||||
|
||||
pub(crate) struct UnionFind<I: Element> {
|
||||
labels: Vec<I>,
|
||||
}
|
||||
|
||||
impl<I: ElementOrdered> Solver for UnionFind<I> {
|
||||
type Label = I;
|
||||
|
||||
fn init(max_labels: usize) -> Self {
|
||||
let mut labels = Vec::with_capacity(max_labels);
|
||||
labels.push(0.elem());
|
||||
Self { labels }
|
||||
}
|
||||
|
||||
fn merge(mut label_1: I, mut label_2: I, solver: &mut Self) -> I {
|
||||
use Ordering::Less;
|
||||
|
||||
while matches!(solver.labels[label_1.to_usize()].cmp(&label_1), Less) {
|
||||
label_1 = solver.labels[label_1.to_usize()];
|
||||
}
|
||||
|
||||
while matches!(solver.labels[label_2.to_usize()].cmp(&label_2), Less) {
|
||||
label_2 = solver.labels[label_2.to_usize()];
|
||||
}
|
||||
|
||||
if matches!(label_1.cmp(&label_2), Less) {
|
||||
solver.labels[label_2.to_usize()] = label_1;
|
||||
label_1
|
||||
} else {
|
||||
solver.labels[label_1.to_usize()] = label_2;
|
||||
label_2
|
||||
}
|
||||
}
|
||||
|
||||
fn new_label(&mut self) -> I {
|
||||
let len = I::from_elem(self.labels.len());
|
||||
self.labels.push(len);
|
||||
len
|
||||
}
|
||||
|
||||
fn flatten(&mut self) -> I {
|
||||
let mut k = 1;
|
||||
for i in 1..self.labels.len() {
|
||||
if matches!(self.labels[i].cmp(&I::from_elem(i)), Ordering::Less) {
|
||||
self.labels[i] = self.labels[self.labels[i].to_usize()];
|
||||
} else {
|
||||
self.labels[i] = k.elem();
|
||||
k += 1;
|
||||
}
|
||||
}
|
||||
k.elem()
|
||||
}
|
||||
|
||||
fn get_label(&self, i_label: I) -> I {
|
||||
self.labels[i_label.to_usize()]
|
||||
}
|
||||
}
|
||||
|
||||
pub trait StatsOp {
|
||||
type Label;
|
||||
|
||||
fn init(&mut self, num_labels: usize);
|
||||
fn update(&mut self, row: usize, column: usize, label: Self::Label);
|
||||
fn finish(&mut self);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct NoOp<I: Element> {
|
||||
_i: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<I: Element> StatsOp for NoOp<I> {
|
||||
type Label = I; // placeholder still required
|
||||
|
||||
fn init(&mut self, _num_labels: usize) {}
|
||||
|
||||
fn update(&mut self, _row: usize, _column: usize, _label: Self::Label) {}
|
||||
|
||||
fn finish(&mut self) {}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct ConnectedStatsOp<I: Element> {
|
||||
pub area: Vec<I>,
|
||||
pub left: Vec<I>,
|
||||
pub top: Vec<I>,
|
||||
pub right: Vec<I>,
|
||||
pub bottom: Vec<I>,
|
||||
}
|
||||
|
||||
impl<I: Element> StatsOp for ConnectedStatsOp<I> {
|
||||
type Label = I;
|
||||
|
||||
fn init(&mut self, num_labels: usize) {
|
||||
self.area = vec![0.elem(); num_labels];
|
||||
self.left = vec![I::MAX; num_labels];
|
||||
self.top = vec![I::MAX; num_labels];
|
||||
self.right = vec![0.elem(); num_labels];
|
||||
self.bottom = vec![0.elem(); num_labels];
|
||||
}
|
||||
|
||||
fn update(&mut self, row: usize, column: usize, label: I) {
|
||||
let l = label.to_usize();
|
||||
unsafe {
|
||||
*self.area.get_unchecked_mut(l) =
|
||||
I::from_elem((*self.area.get_unchecked(l)).to_usize() + 1);
|
||||
*self.left.get_unchecked_mut(l) =
|
||||
I::from_elem((*self.left.get_unchecked(l)).to_usize().min(column));
|
||||
*self.top.get_unchecked_mut(l) =
|
||||
I::from_elem((*self.top.get_unchecked(l)).to_usize().min(row));
|
||||
*self.right.get_unchecked_mut(l) =
|
||||
I::from_elem((*self.right.get_unchecked(l)).to_usize().max(column));
|
||||
*self.bottom.get_unchecked_mut(l) =
|
||||
I::from_elem((*self.bottom.get_unchecked(l)).to_usize().max(row));
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&mut self) {
|
||||
// Background shouldn't have stats
|
||||
self.area[0] = 0.elem();
|
||||
self.left[0] = 0.elem();
|
||||
self.right[0] = 0.elem();
|
||||
self.top[0] = 0.elem();
|
||||
self.bottom[0] = 0.elem();
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_stats<B: Backend, I: Element>(
|
||||
device: &B::Device,
|
||||
stats: ConnectedStatsOp<I>,
|
||||
) -> ConnectedStatsPrimitive<B> {
|
||||
let labels = stats.area.len();
|
||||
|
||||
let into_prim = |data: Vec<I>| {
|
||||
let data = TensorData::new(data, Shape::new([labels]));
|
||||
Tensor::<B, 1, Int>::from_data(data, device).into_primitive()
|
||||
};
|
||||
|
||||
let max_label = {
|
||||
let data = TensorData::new(vec![I::from_elem(labels - 1)], Shape::new([1]));
|
||||
Tensor::<B, 1, Int>::from_data(data, device).into_primitive()
|
||||
};
|
||||
|
||||
ConnectedStatsPrimitive {
|
||||
area: into_prim(stats.area),
|
||||
left: into_prim(stats.left),
|
||||
top: into_prim(stats.top),
|
||||
right: into_prim(stats.right),
|
||||
bottom: into_prim(stats.bottom),
|
||||
max_label,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_labels(h: usize, w: usize, conn: Connectivity) -> usize {
|
||||
match conn {
|
||||
Connectivity::Four => (h * w).div_ceil(2) + 1,
|
||||
Connectivity::Eight => h.div_ceil(2) * w.div_ceil(2) + 1,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
no_analyze!{{
|
||||
use firstLabels::*;let mut label = entry;
|
||||
while let Some(next) = (|label| -> Option<firstLabels> { match label {
|
||||
NODE_72=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_2);
|
||||
}
|
||||
}
|
||||
NODE_73=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(fl_tree_2);
|
||||
}
|
||||
}
|
||||
NODE_74=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(fl_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
fl_tree_0 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_0); } else { return Some(fl_break_1_0); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_72);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_72);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_74);
|
||||
}
|
||||
}
|
||||
}
|
||||
fl_tree_1 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_1); } else { return Some(fl_break_1_1); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_73);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_73);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_74);
|
||||
}
|
||||
}
|
||||
}
|
||||
fl_tree_2 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_2); } else { return Some(fl_break_1_2); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row01.add((c - 1) as usize)).to_bool() {
|
||||
return Some(NODE_73);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_72);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row01.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(fl_tree_2);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_74);
|
||||
}
|
||||
}
|
||||
}
|
||||
NODE_75=> {
|
||||
if (*img_row01.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
fl_break_0_0 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
fl_break_0_1 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
fl_break_0_2 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_75);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_75);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
NODE_76=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
}
|
||||
NODE_77=> {
|
||||
if (*img_row01.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
fl_break_1_0 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
return Some(NODE_76);
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
fl_break_1_1 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_76);
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
fl_break_1_2 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_77);
|
||||
}
|
||||
else {
|
||||
if (*img_row01.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
return Some(NODE_77);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_77);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_76);
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
fl_ => {},
|
||||
}; None})(label)
|
||||
{
|
||||
label = next;
|
||||
}
|
||||
}}
|
||||
@@ -0,0 +1,191 @@
|
||||
/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`.
|
||||
macro_rules! no_analyze {
|
||||
($tokens:tt) => {
|
||||
$tokens
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use no_analyze;
|
||||
|
||||
#[allow(non_snake_case, non_camel_case_types, unused)]
|
||||
pub enum centerLabels {
|
||||
NODE_1,
|
||||
NODE_2,
|
||||
NODE_3,
|
||||
NODE_4,
|
||||
NODE_5,
|
||||
NODE_6,
|
||||
NODE_7,
|
||||
NODE_8,
|
||||
NODE_9,
|
||||
NODE_10,
|
||||
NODE_11,
|
||||
NODE_12,
|
||||
NODE_13,
|
||||
NODE_14,
|
||||
NODE_15,
|
||||
NODE_16,
|
||||
NODE_17,
|
||||
NODE_18,
|
||||
NODE_19,
|
||||
NODE_20,
|
||||
NODE_21,
|
||||
NODE_22,
|
||||
NODE_23,
|
||||
NODE_24,
|
||||
NODE_25,
|
||||
NODE_26,
|
||||
NODE_27,
|
||||
NODE_28,
|
||||
NODE_29,
|
||||
NODE_30,
|
||||
NODE_31,
|
||||
NODE_32,
|
||||
NODE_33,
|
||||
NODE_34,
|
||||
NODE_35,
|
||||
NODE_36,
|
||||
NODE_37,
|
||||
NODE_38,
|
||||
NODE_39,
|
||||
NODE_40,
|
||||
NODE_41,
|
||||
NODE_42,
|
||||
NODE_43,
|
||||
NODE_44,
|
||||
NODE_45,
|
||||
NODE_46,
|
||||
NODE_47,
|
||||
NODE_48,
|
||||
NODE_49,
|
||||
NODE_50,
|
||||
NODE_51,
|
||||
NODE_52,
|
||||
NODE_53,
|
||||
NODE_54,
|
||||
NODE_55,
|
||||
NODE_56,
|
||||
NODE_57,
|
||||
NODE_58,
|
||||
NODE_59,
|
||||
NODE_60,
|
||||
NODE_61,
|
||||
NODE_62,
|
||||
NODE_63,
|
||||
NODE_64,
|
||||
NODE_65,
|
||||
NODE_66,
|
||||
NODE_67,
|
||||
NODE_68,
|
||||
NODE_69,
|
||||
NODE_70,
|
||||
NODE_71,
|
||||
cl_tree_0,
|
||||
cl_tree_1,
|
||||
cl_tree_2,
|
||||
cl_tree_3,
|
||||
cl_tree_4,
|
||||
cl_tree_5,
|
||||
cl_tree_6,
|
||||
cl_tree_7,
|
||||
cl_tree_8,
|
||||
cl_tree_9,
|
||||
cl_tree_10,
|
||||
cl_tree_11,
|
||||
cl_tree_12,
|
||||
cl_break_0_0,
|
||||
cl_break_0_1,
|
||||
cl_break_0_2,
|
||||
cl_break_0_3,
|
||||
cl_break_0_4,
|
||||
cl_break_0_5,
|
||||
cl_break_0_6,
|
||||
cl_break_0_7,
|
||||
cl_break_0_8,
|
||||
cl_break_1_0,
|
||||
cl_break_1_1,
|
||||
cl_break_1_2,
|
||||
cl_break_1_3,
|
||||
cl_break_1_4,
|
||||
cl_break_1_5,
|
||||
cl_break_1_6,
|
||||
cl_break_1_7,
|
||||
cl_break_1_8,
|
||||
cl_break_1_9,
|
||||
cl_break_1_10,
|
||||
cl_break_1_11,
|
||||
cl_break_1_12,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case, non_camel_case_types, unused)]
|
||||
pub enum firstLabels {
|
||||
NODE_72,
|
||||
NODE_73,
|
||||
NODE_74,
|
||||
NODE_75,
|
||||
NODE_76,
|
||||
NODE_77,
|
||||
fl_tree_0,
|
||||
fl_tree_1,
|
||||
fl_tree_2,
|
||||
fl_break_0_0,
|
||||
fl_break_0_1,
|
||||
fl_break_0_2,
|
||||
fl_break_1_0,
|
||||
fl_break_1_1,
|
||||
fl_break_1_2,
|
||||
fl_,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case, non_camel_case_types, unused)]
|
||||
pub enum lastLabels {
|
||||
NODE_78,
|
||||
NODE_79,
|
||||
NODE_80,
|
||||
NODE_81,
|
||||
NODE_82,
|
||||
NODE_83,
|
||||
NODE_84,
|
||||
NODE_85,
|
||||
NODE_86,
|
||||
NODE_87,
|
||||
NODE_88,
|
||||
NODE_89,
|
||||
NODE_90,
|
||||
NODE_91,
|
||||
NODE_92,
|
||||
ll_tree_0,
|
||||
ll_tree_1,
|
||||
ll_tree_2,
|
||||
ll_tree_3,
|
||||
ll_tree_4,
|
||||
ll_tree_5,
|
||||
ll_tree_6,
|
||||
ll_tree_7,
|
||||
ll_break_0_0,
|
||||
ll_break_0_1,
|
||||
ll_break_0_2,
|
||||
ll_break_0_3,
|
||||
ll_break_1_0,
|
||||
ll_break_1_1,
|
||||
ll_break_1_2,
|
||||
ll_break_1_3,
|
||||
ll_break_1_4,
|
||||
ll_break_1_5,
|
||||
ll_break_1_6,
|
||||
ll_break_1_7,
|
||||
ll_,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case, non_camel_case_types, unused)]
|
||||
pub enum singleLabels {
|
||||
NODE_93,
|
||||
NODE_94,
|
||||
sl_tree_0,
|
||||
sl_tree_1,
|
||||
sl_break_0_0,
|
||||
sl_break_0_1,
|
||||
sl_break_1_0,
|
||||
sl_break_1_1,
|
||||
sl_,
|
||||
}
|
||||
@@ -0,0 +1,787 @@
|
||||
no_analyze!{{
|
||||
use lastLabels::*;let mut label = entry;
|
||||
while let Some(next) = (|label| -> Option<lastLabels> { match label {
|
||||
NODE_78=> {
|
||||
if (*img_row12.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
NODE_79=> {
|
||||
if (*img_row12.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
NODE_80=> {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
NODE_81=> {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_82);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_3);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(ll_tree_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
NODE_83=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_5);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(ll_tree_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(ll_tree_1);
|
||||
}
|
||||
}
|
||||
NODE_84=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_5);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_81);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(ll_tree_1);
|
||||
}
|
||||
}
|
||||
NODE_82=> {
|
||||
if (*img_row12.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
NODE_85=> {
|
||||
if (*img_row12.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
NODE_86=> {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row12.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_7);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
ll_tree_0 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_0); } else { return Some(ll_break_1_0); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
return Some(NODE_81);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_84);
|
||||
}
|
||||
}
|
||||
ll_tree_1 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_1); } else { return Some(ll_break_1_1); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
return Some(NODE_80);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_82);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
return Some(NODE_85);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_3);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
return Some(ll_tree_2);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(ll_tree_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_84);
|
||||
}
|
||||
}
|
||||
ll_tree_2 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_2); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_7);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_83);
|
||||
}
|
||||
}
|
||||
ll_tree_3 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_3); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_79);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row12.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_78);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_7);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_83);
|
||||
}
|
||||
}
|
||||
ll_tree_4 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_4); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row12.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_7);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_5);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
return Some(NODE_82);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_3);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(ll_tree_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
ll_tree_5 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_5); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_86);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_84);
|
||||
}
|
||||
}
|
||||
ll_tree_6 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_3); } else { return Some(ll_break_1_6); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c - 1) as usize)).to_bool() {
|
||||
return Some(NODE_86);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_80);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_82);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_85);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_3);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
return Some(ll_tree_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_84);
|
||||
}
|
||||
}
|
||||
ll_tree_7 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_7); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c - 2) as usize)).to_bool() {
|
||||
return Some(NODE_79);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
return Some(ll_tree_6);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 2) as usize)).to_bool() {
|
||||
if (*img_row12.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c - 2) as usize)).to_bool() {
|
||||
return Some(NODE_78);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);
|
||||
return Some(ll_tree_4);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_7);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(ll_tree_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_83);
|
||||
}
|
||||
}
|
||||
ll_break_0_0 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
return None;}
|
||||
ll_break_0_1 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
return None;}
|
||||
ll_break_0_2 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
return None;}
|
||||
ll_break_0_3 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
return None;}
|
||||
NODE_87=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
return Some(NODE_88);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
NODE_88=> {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
}
|
||||
NODE_89=> {
|
||||
if (*img_row12.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
}
|
||||
}
|
||||
NODE_90=> {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
}
|
||||
NODE_91=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
NODE_92=> {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);
|
||||
}
|
||||
}
|
||||
ll_break_1_0 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_88);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_87);
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_1 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
return Some(NODE_92);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c - 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_87);
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_2 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_91);
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_3 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_89);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_91);
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_4 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_5 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
return Some(NODE_90);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_87);
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_6 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c - 1) as usize)).to_bool() {
|
||||
return Some(NODE_90);
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_92);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_87);
|
||||
}
|
||||
return None;}
|
||||
ll_break_1_7 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c + 1) as usize)).to_bool() {
|
||||
if (*img_row12.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c - 2) as usize)).to_bool() {
|
||||
return Some(NODE_89);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_91);
|
||||
}
|
||||
return None;}
|
||||
ll_ => {},
|
||||
}; None})(label)
|
||||
{
|
||||
label = next;
|
||||
}
|
||||
}}
|
||||
@@ -0,0 +1,91 @@
|
||||
no_analyze!{{
|
||||
use singleLabels::*;let mut label = entry;
|
||||
while let Some(next) = (|label| -> Option<singleLabels> { match label {
|
||||
NODE_93=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(sl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(sl_tree_0);
|
||||
}
|
||||
}
|
||||
sl_tree_0 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_0); } else { return Some(sl_break_1_0); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(sl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(sl_tree_0);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_93);
|
||||
}
|
||||
}
|
||||
sl_tree_1 => {
|
||||
if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_1); } else { return Some(sl_break_1_1); } }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(sl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
return Some(sl_tree_0);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Some(NODE_93);
|
||||
}
|
||||
}
|
||||
sl_break_0_0 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
return None;}
|
||||
sl_break_0_1 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
return None;}
|
||||
NODE_94=> {
|
||||
if (*img_row00.add((c + 1) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
}
|
||||
}
|
||||
sl_break_1_0 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
}
|
||||
else {
|
||||
return Some(NODE_94);
|
||||
}
|
||||
return None;}
|
||||
sl_break_1_1 => {
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);
|
||||
}
|
||||
else {
|
||||
return Some(NODE_94);
|
||||
}
|
||||
return None;}
|
||||
sl_ => {},
|
||||
}; None})(label)
|
||||
{
|
||||
label = next;
|
||||
}
|
||||
}}
|
||||
@@ -0,0 +1,273 @@
|
||||
//! Spaghetti algorithm for connected component labeling
|
||||
//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana,
|
||||
//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling,"
|
||||
//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019.
|
||||
//!
|
||||
//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN)
|
||||
//! as described in
|
||||
//!
|
||||
//! F. Bolelli, S. Allegretti, C. Grana.
|
||||
//! "One DAG to Rule Them All."
|
||||
//! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021
|
||||
|
||||
#![allow(
|
||||
unreachable_code,
|
||||
clippy::collapsible_else_if,
|
||||
clippy::if_same_then_else
|
||||
)]
|
||||
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use burn_tensor::{Element, ElementComparison, ElementConversion, cast::ToElement};
|
||||
use ndarray::{Array2, Axis, s};
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
mod Spaghetti_forest_labels;
|
||||
pub(crate) use Spaghetti_forest_labels::*;
|
||||
|
||||
use crate::Connectivity;
|
||||
|
||||
use super::{Solver, StatsOp, max_labels};
|
||||
|
||||
pub fn process<B: Element, LabelsSolver: Solver>(
|
||||
img_arr: Array2<B>,
|
||||
stats: &mut impl StatsOp<Label = LabelsSolver::Label>,
|
||||
) -> Array2<LabelsSolver::Label> {
|
||||
let (h, w) = img_arr.dim();
|
||||
let mut img_labels_arr = Array2::<LabelsSolver::Label>::default(img_arr.raw_dim());
|
||||
|
||||
let img = img_arr.as_ptr();
|
||||
|
||||
let e_rows = h as u32 & 0xfffffffe;
|
||||
let o_rows = h % 2 == 1;
|
||||
let e_cols = w as u32 & 0xfffffffe;
|
||||
let o_cols = w % 2 == 1;
|
||||
|
||||
let img_labels = img_labels_arr.as_mut_ptr();
|
||||
|
||||
let mut solver = LabelsSolver::init(max_labels(h, w, Connectivity::Eight));
|
||||
|
||||
let solver = &mut solver;
|
||||
|
||||
let w = w as i32;
|
||||
|
||||
// SAFETY:
|
||||
// Generated code includes mathematically proven bounds checks, so raw pointers are a safe speed
|
||||
// boost.
|
||||
unsafe {
|
||||
if h == 1 {
|
||||
// Single line
|
||||
let r = 0;
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.add(r * w as usize);
|
||||
|
||||
// Row pointers for the output image
|
||||
let img_labels_row00 = img_labels.add(r * w as usize);
|
||||
|
||||
let mut c = -2i32;
|
||||
let entry = singleLabels::sl_tree_0;
|
||||
|
||||
include!("Spaghetti_single_line_forest_code.rs");
|
||||
} else {
|
||||
// More than one line
|
||||
|
||||
// First couple of lines
|
||||
{
|
||||
let r = 0;
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.add(r * w as usize);
|
||||
let img_row01 = img.add((r + 1) * w as usize);
|
||||
|
||||
// Row pointers for the output image
|
||||
let img_labels_row00 = img_labels.add(r * w as usize);
|
||||
|
||||
let mut c = -2i32;
|
||||
let entry = firstLabels::fl_tree_0;
|
||||
|
||||
include!("Spaghetti_first_line_forest_code.rs");
|
||||
}
|
||||
|
||||
// Every other line but the last one if image has an odd number of rows
|
||||
for r in (2..e_rows as usize).step_by(2) {
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.add(r * w as usize);
|
||||
let img_row12 = img.add((r - 2) * w as usize);
|
||||
let img_row11 = img.add((r - 1) * w as usize);
|
||||
let img_row01 = img.add((r + 1) * w as usize);
|
||||
|
||||
// Row pointers for the output image
|
||||
let img_labels_row00 = img_labels.add(r * w as usize);
|
||||
let img_labels_row12 = img_labels.add((r - 2) * w as usize);
|
||||
|
||||
let mut c = -2;
|
||||
let entry = centerLabels::cl_tree_0;
|
||||
|
||||
include!("Spaghetti_center_line_forest_code.rs");
|
||||
}
|
||||
|
||||
if o_rows {
|
||||
let r = h - 1;
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.add(r * w as usize);
|
||||
let img_row12 = img.add((r - 2) * w as usize);
|
||||
let img_row11 = img.add((r - 1) * w as usize);
|
||||
|
||||
// Row pointers for the output image
|
||||
let img_labels_row00 = img_labels.add(r * w as usize);
|
||||
let img_labels_row12 = img_labels.add((r - 2) * w as usize);
|
||||
|
||||
let mut c = -2;
|
||||
let entry = lastLabels::ll_tree_0;
|
||||
|
||||
include!("Spaghetti_last_line_forest_code.rs");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let n_labels = solver.flatten();
|
||||
stats.init(n_labels.to_usize());
|
||||
|
||||
let img = img_arr;
|
||||
let mut img_labels = img_labels_arr;
|
||||
|
||||
for r in (0..e_rows as usize).step_by(2) {
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.index_axis(Axis(0), r);
|
||||
let img_row01 = img.index_axis(Axis(0), r + 1);
|
||||
|
||||
// Row pointers for the output image
|
||||
let (mut img_labels_row00, mut img_labels_row01) =
|
||||
img_labels.multi_slice_mut((s![r, ..], s![r + 1, ..]));
|
||||
|
||||
for c in (0..e_cols as usize).step_by(2) {
|
||||
let mut i_label = img_labels_row00[c];
|
||||
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
|
||||
i_label = solver.get_label(i_label);
|
||||
if img_row00[c].to_u8() > 0 {
|
||||
img_labels_row00[c] = i_label;
|
||||
stats.update(r, c, i_label);
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
}
|
||||
if img_row00[c + 1].to_u8() > 0 {
|
||||
img_labels_row00[c + 1] = i_label;
|
||||
stats.update(r, c + 1, i_label);
|
||||
} else {
|
||||
img_labels_row00[c + 1] = 0.elem();
|
||||
stats.update(r, c + 1, 0.elem());
|
||||
}
|
||||
if img_row01[c].to_u8() > 0 {
|
||||
img_labels_row01[c] = i_label;
|
||||
stats.update(r + 1, c, i_label);
|
||||
} else {
|
||||
img_labels_row01[c] = 0.elem();
|
||||
stats.update(r + 1, c, 0.elem());
|
||||
}
|
||||
if img_row01[c + 1].to_u8() > 0 {
|
||||
img_labels_row01[c + 1] = i_label;
|
||||
stats.update(r + 1, c + 1, i_label);
|
||||
} else {
|
||||
img_labels_row01[c + 1] = 0.elem();
|
||||
stats.update(r + 1, c + 1, 0.elem());
|
||||
}
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
img_labels_row00[c + 1] = 0.elem();
|
||||
stats.update(r, c + 1, 0.elem());
|
||||
img_labels_row01[c] = 0.elem();
|
||||
stats.update(r + 1, c, 0.elem());
|
||||
img_labels_row01[c + 1] = 0.elem();
|
||||
stats.update(r + 1, c + 1, 0.elem());
|
||||
}
|
||||
}
|
||||
if o_cols {
|
||||
let c = e_cols as usize;
|
||||
let mut i_label = img_labels_row00[c];
|
||||
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
|
||||
i_label = solver.get_label(i_label);
|
||||
if img_row00[c].to_u8() > 0 {
|
||||
img_labels_row00[c] = i_label;
|
||||
stats.update(r, c, i_label);
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
}
|
||||
if img_row01[c].to_u8() > 0 {
|
||||
img_labels_row01[c] = i_label;
|
||||
stats.update(r + 1, c, i_label);
|
||||
} else {
|
||||
img_labels_row01[c] = 0.elem();
|
||||
stats.update(r + 1, c, 0.elem());
|
||||
}
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
img_labels_row01[c] = 0.elem();
|
||||
stats.update(r + 1, c, 0.elem());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if o_rows {
|
||||
let r = e_rows as usize;
|
||||
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.index_axis(Axis(0), r);
|
||||
|
||||
// Row pointers for the output image
|
||||
let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]);
|
||||
|
||||
for c in (0..e_cols as usize).step_by(2) {
|
||||
let mut i_label = img_labels_row00[c];
|
||||
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
|
||||
i_label = solver.get_label(i_label);
|
||||
if img_row00[c].to_u8() > 0 {
|
||||
img_labels_row00[c] = i_label;
|
||||
stats.update(r, c, i_label);
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
}
|
||||
if img_row00[c + 1].to_u8() > 0 {
|
||||
img_labels_row00[c + 1] = i_label;
|
||||
stats.update(r, c + 1, i_label);
|
||||
} else {
|
||||
img_labels_row00[c + 1] = 0.elem();
|
||||
stats.update(r, c + 1, 0.elem());
|
||||
}
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
img_labels_row00[c + 1] = 0.elem();
|
||||
stats.update(r, c + 1, 0.elem());
|
||||
}
|
||||
}
|
||||
if o_cols {
|
||||
let c = e_cols as usize;
|
||||
let mut i_label = img_labels_row00[c];
|
||||
if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {
|
||||
i_label = solver.get_label(i_label);
|
||||
if img_row00[c].to_u8() > 0 {
|
||||
img_labels_row00[c] = i_label;
|
||||
stats.update(r, c, i_label);
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, 0.elem());
|
||||
}
|
||||
} else {
|
||||
img_labels_row00[c] = 0.elem();
|
||||
stats.update(r, c, i_label);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stats.finish();
|
||||
img_labels
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
no_analyze!{{
|
||||
use centerLabels::*;let mut label = entry;
|
||||
while let Some(next) = (|label| -> Option<centerLabels> { match label {
|
||||
cl_tree_0 => {
|
||||
if ({c+=1; c} >= w) { return None; }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row11.add((c) as usize);
|
||||
return Some(cl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(cl_tree_1);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(cl_tree_0);
|
||||
}
|
||||
}
|
||||
cl_tree_1 => {
|
||||
if ({c+=1; c} >= w) { return None; }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
if (*img_row11.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 1) as usize), *img_labels_row11.add((c) as usize), solver);
|
||||
return Some(cl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 1) as usize);
|
||||
return Some(cl_tree_1);
|
||||
}
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(cl_tree_0);
|
||||
}
|
||||
}
|
||||
}; None})(label)
|
||||
{
|
||||
label = next;
|
||||
}
|
||||
}}
|
||||
@@ -0,0 +1,31 @@
|
||||
no_analyze!{{
|
||||
use firstLabels::*;let mut label = entry;
|
||||
while let Some(next) = (|label| -> Option<firstLabels> { match label {
|
||||
fl_tree_0 => {
|
||||
if ({c+=1; c} >= w) { return None; }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = solver.new_label();
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(fl_tree_0);
|
||||
}
|
||||
}
|
||||
fl_tree_1 => {
|
||||
if ({c+=1; c} >= w) { return None; }
|
||||
if (*img_row00.add((c) as usize)).to_bool() {
|
||||
*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 1) as usize);
|
||||
return Some(fl_tree_1);
|
||||
}
|
||||
else {
|
||||
*img_labels_row00.add(c as usize) = 0.elem();
|
||||
return Some(fl_tree_0);
|
||||
}
|
||||
}
|
||||
fl_ => {},
|
||||
}; None})(label)
|
||||
{
|
||||
label = next;
|
||||
}
|
||||
}}
|
||||
@@ -0,0 +1,21 @@
|
||||
/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`.
|
||||
macro_rules! no_analyze {
|
||||
($tokens:tt) => {
|
||||
$tokens
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use no_analyze;
|
||||
|
||||
#[allow(non_snake_case, non_camel_case_types, unused)]
|
||||
pub enum centerLabels {
|
||||
cl_tree_0,
|
||||
cl_tree_1,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case, non_camel_case_types, unused)]
|
||||
pub enum firstLabels {
|
||||
fl_tree_0,
|
||||
fl_tree_1,
|
||||
fl_,
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
//! Spaghetti algorithm for connected component labeling, modified for 4-connectivity using the
|
||||
//! 4-connected Rosenfeld mask.
|
||||
//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana,
|
||||
//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling,"
|
||||
//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019.
|
||||
//!
|
||||
//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN)
|
||||
//! as described in
|
||||
//!
|
||||
//! F. Bolelli, S. Allegretti, C. Grana.
|
||||
//! "One DAG to Rule Them All."
|
||||
//! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021
|
||||
|
||||
#![allow(unreachable_code)]
|
||||
|
||||
use burn_tensor::{Element, ElementConversion, cast::ToElement};
|
||||
use ndarray::Array2;
|
||||
|
||||
use crate::Connectivity;
|
||||
|
||||
use super::{Solver, StatsOp, max_labels};
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
mod Spaghetti4C_forest_labels;
|
||||
pub(crate) use Spaghetti4C_forest_labels::*;
|
||||
|
||||
pub fn process<B: Element, LabelsSolver: Solver>(
|
||||
img: Vec<B>,
|
||||
h: usize,
|
||||
w: usize,
|
||||
stats: &mut impl StatsOp<Label = LabelsSolver::Label>,
|
||||
) -> Array2<LabelsSolver::Label> {
|
||||
let img = img.as_ptr();
|
||||
|
||||
let mut img_labels: Vec<LabelsSolver::Label> = vec![0.elem(); h * w];
|
||||
|
||||
// A quick and dirty upper bound for the maximum number of labels.
|
||||
// Following formula comes from the fact that a 2x2 block in 4-connectivity case
|
||||
// can never have more than 2 new labels and 1 label for background.
|
||||
// Worst case image example pattern:
|
||||
// 1 0 1 0 1...
|
||||
// 0 1 0 1 0...
|
||||
// 1 0 1 0 1...
|
||||
// ............
|
||||
let max_labels = max_labels(h, w, Connectivity::Four);
|
||||
|
||||
let mut solver = LabelsSolver::init(max_labels);
|
||||
let solver = &mut solver;
|
||||
|
||||
let w = w as i32;
|
||||
// SAFETY:
|
||||
// This code is generated from constraints and includes manual bounds checks, so unchecked pointer
|
||||
// indexes are always safe.
|
||||
unsafe {
|
||||
// First row
|
||||
{
|
||||
let r = 0;
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.add(r * w as usize);
|
||||
|
||||
// Row pointers for the output image
|
||||
let img_labels_row00 = img_labels.as_mut_ptr().add(r * w as usize);
|
||||
let mut c = -1i32;
|
||||
|
||||
let entry = firstLabels::fl_tree_0;
|
||||
|
||||
include!("Spaghetti4C_first_line_forest_code.rs");
|
||||
}
|
||||
|
||||
for r in 1..h {
|
||||
//Pointers:
|
||||
// Row pointers for the input image
|
||||
let img_row00 = img.add(r * w as usize);
|
||||
let img_row11 = img.add((r - 1) * w as usize);
|
||||
|
||||
// Row pointers for the output image
|
||||
let img_labels_row00 = img_labels.as_mut_ptr().add(r * w as usize);
|
||||
let img_labels_row11 = img_labels.as_mut_ptr().add((r - 1) * w as usize);
|
||||
let mut c = -1i32;
|
||||
|
||||
let entry = centerLabels::cl_tree_0;
|
||||
|
||||
include!("Spaghetti4C_center_line_forest_code.rs");
|
||||
}
|
||||
}
|
||||
|
||||
let n_labels = solver.flatten();
|
||||
stats.init(n_labels.to_usize());
|
||||
|
||||
// SAFETY: This is always valid
|
||||
let mut img_labels = unsafe { Array2::from_shape_vec_unchecked((h, w as usize), img_labels) };
|
||||
|
||||
img_labels.indexed_iter_mut().for_each(|((r, c), label)| {
|
||||
*label = solver.get_label(*label);
|
||||
stats.update(r, c, *label);
|
||||
});
|
||||
|
||||
stats.finish();
|
||||
|
||||
img_labels
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
mod base;
|
||||
mod connected_components;
|
||||
mod morphology;
|
||||
mod nms;
|
||||
mod ops;
|
||||
|
||||
pub use base::*;
|
||||
pub use connected_components::*;
|
||||
pub use morphology::*;
|
||||
pub use nms::*;
|
||||
@@ -0,0 +1,720 @@
|
||||
use core::slice;
|
||||
use std::{marker::PhantomData, ptr::null};
|
||||
|
||||
use burn_tensor::Element;
|
||||
use macerator::{
|
||||
Scalar, Simd, VOrd, Vector, vload, vload_low, vload_unaligned, vstore, vstore_low,
|
||||
vstore_unaligned,
|
||||
};
|
||||
|
||||
use crate::{Point, Size, backends::cpu::MinMax};
|
||||
|
||||
pub trait MorphOperator<T> {
|
||||
fn apply(a: T, b: T) -> T;
|
||||
}
|
||||
|
||||
pub trait VecMorphOperator<T: Scalar> {
|
||||
fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T>;
|
||||
}
|
||||
|
||||
pub struct MinOp;
|
||||
pub struct MaxOp;
|
||||
|
||||
impl<T: MinMax> MorphOperator<T> for MinOp {
|
||||
fn apply(a: T, b: T) -> T {
|
||||
MinMax::min(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd> VecMorphOperator<T> for MinOp {
|
||||
fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T> {
|
||||
T::vmin(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MinMax> MorphOperator<T> for MaxOp {
|
||||
fn apply(a: T, b: T) -> T {
|
||||
MinMax::max(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VOrd> VecMorphOperator<T> for MaxOp {
|
||||
fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T> {
|
||||
T::vmax(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MorphRowFilter<T: Scalar, S: MorphOperator<T>, Vec: VecRow<T>> {
|
||||
pub ksize: usize,
|
||||
pub anchor: usize,
|
||||
vec: Vec,
|
||||
_t: PhantomData<T>,
|
||||
_scalar: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<T: Scalar, SOp: MorphOperator<T>, Vec: VecRow<T>> MorphRowFilter<T, SOp, Vec> {
|
||||
pub fn new(ksize: usize, anchor: usize) -> Self {
|
||||
let vec = Vec::new(ksize, anchor);
|
||||
Self {
|
||||
ksize,
|
||||
anchor,
|
||||
vec,
|
||||
_t: PhantomData,
|
||||
_scalar: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, ch: usize) {
|
||||
let k_size = self.ksize * ch;
|
||||
|
||||
if k_size == ch {
|
||||
let width = width * ch;
|
||||
dst[..width].copy_from_slice(&src[..width]);
|
||||
return;
|
||||
}
|
||||
|
||||
let i0 = self.vec.apply::<S>(src, dst, width, ch);
|
||||
let width = width * ch;
|
||||
|
||||
for k in 0..ch {
|
||||
let mut last_i = i0;
|
||||
for i in (i0..width.saturating_sub(ch * 2)).step_by(ch * 2) {
|
||||
let mut m = src[k + i + ch];
|
||||
let mut last_j = ch * 2;
|
||||
for j in (ch * 2..k_size).step_by(ch) {
|
||||
m = SOp::apply(m, src[k + i + j]);
|
||||
last_j = j + ch;
|
||||
}
|
||||
dst[k + i] = SOp::apply(m, src[k + i]);
|
||||
dst[k + i + ch] = SOp::apply(m, src[k + i + last_j]);
|
||||
last_i = i + ch * 2;
|
||||
}
|
||||
|
||||
for i in (last_i..width).step_by(ch) {
|
||||
let mut m = src[k + i];
|
||||
for j in (ch..k_size).step_by(ch) {
|
||||
m = SOp::apply(m, src[k + i + j]);
|
||||
}
|
||||
dst[k + i] = m;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MorphRowVec<T: Scalar, Op: VecMorphOperator<T>> {
|
||||
k_size: usize,
|
||||
_t: PhantomData<T>,
|
||||
_op: PhantomData<Op>,
|
||||
}
|
||||
|
||||
pub trait VecRow<T: Scalar> {
|
||||
fn new(ksize: usize, anchor: usize) -> Self;
|
||||
fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, channels: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<T: Scalar, Op: VecMorphOperator<T>> VecRow<T> for MorphRowVec<T, Op> {
|
||||
fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, ch: usize) -> usize {
|
||||
let src = src.as_ptr();
|
||||
let dst = dst.as_mut_ptr();
|
||||
let k_size = self.k_size * ch;
|
||||
let width = (width * ch) as isize;
|
||||
let lanes = T::lanes::<S>();
|
||||
|
||||
// Safety: everything here is unsafe. Test thoroughly.
|
||||
unsafe {
|
||||
let mut x = 0;
|
||||
while x as isize <= width - 4 * lanes as isize {
|
||||
let mut s0 = vload(src.add(x));
|
||||
let mut s1 = vload(src.add(x + lanes));
|
||||
let mut s2 = vload(src.add(x + 2 * lanes));
|
||||
let mut s3 = vload(src.add(x + 3 * lanes));
|
||||
for k in (ch..k_size).step_by(ch) {
|
||||
let x = x + k;
|
||||
s0 = Op::apply::<S>(s0, vload_unaligned(src.add(x)));
|
||||
s1 = Op::apply::<S>(s1, vload_unaligned(src.add(x + lanes)));
|
||||
s2 = Op::apply::<S>(s2, vload_unaligned(src.add(x + 2 * lanes)));
|
||||
s3 = Op::apply::<S>(s3, vload_unaligned(src.add(x + 3 * lanes)));
|
||||
}
|
||||
vstore(dst.add(x), s0);
|
||||
vstore(dst.add(x + lanes), s1);
|
||||
vstore(dst.add(x + 2 * lanes), s2);
|
||||
vstore(dst.add(x + 3 * lanes), s3);
|
||||
x += 4 * lanes;
|
||||
}
|
||||
if x as isize <= width - 2 * lanes as isize {
|
||||
let mut s0 = vload(src.add(x));
|
||||
let mut s1 = vload(src.add(x + lanes));
|
||||
for k in (ch..k_size).step_by(ch) {
|
||||
s0 = Op::apply::<S>(s0, vload_unaligned(src.add(x + k)));
|
||||
s1 = Op::apply::<S>(s1, vload_unaligned(src.add(x + k + lanes)));
|
||||
}
|
||||
vstore(dst.add(x), s0);
|
||||
vstore(dst.add(x + lanes), s1);
|
||||
x += 2 * lanes;
|
||||
}
|
||||
if x as isize <= width - lanes as isize {
|
||||
let mut s = vload(src.add(x));
|
||||
for k in (ch..k_size).step_by(ch) {
|
||||
s = Op::apply::<S>(s, vload_unaligned(src.add(x + k)));
|
||||
}
|
||||
vstore(dst.add(x), s);
|
||||
x += lanes;
|
||||
}
|
||||
if x as isize <= width - lanes as isize / 2 {
|
||||
let mut s = vload_low(src.add(x));
|
||||
for k in (ch..k_size).step_by(ch) {
|
||||
s = Op::apply::<S>(s, vload_low(src.add(x + k)));
|
||||
}
|
||||
vstore_low(dst.add(x), s);
|
||||
x += lanes / 2;
|
||||
}
|
||||
x - x % ch
|
||||
}
|
||||
}
|
||||
|
||||
fn new(k_size: usize, _anchor: usize) -> Self {
|
||||
Self {
|
||||
k_size,
|
||||
_t: PhantomData,
|
||||
_op: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecColumn<T: Scalar> {
|
||||
fn new(ksize: usize, anchor: usize) -> Self;
|
||||
fn apply<S: Simd>(
|
||||
&self,
|
||||
|
||||
src: &[*const T],
|
||||
dst: &mut [T],
|
||||
dst_step: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub struct MorphColumnVec<T: Scalar, Op: VecMorphOperator<T>> {
|
||||
k_size: usize,
|
||||
_t: PhantomData<T>,
|
||||
_op: PhantomData<Op>,
|
||||
}
|
||||
|
||||
impl<T: VOrd, Op: VecMorphOperator<T>> VecColumn<T> for MorphColumnVec<T, Op> {
|
||||
fn new(k_size: usize, _anchor: usize) -> Self {
|
||||
Self {
|
||||
k_size,
|
||||
_t: PhantomData,
|
||||
_op: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply<S: Simd>(
|
||||
&self,
|
||||
|
||||
src: &[*const T],
|
||||
dst: &mut [T],
|
||||
dst_step: usize,
|
||||
mut count: usize,
|
||||
width: usize,
|
||||
) -> usize {
|
||||
let ksize = self.k_size;
|
||||
let width = width as isize;
|
||||
let mut dst = dst.as_mut_ptr();
|
||||
let lanes = T::lanes::<S>();
|
||||
let mut y = 0;
|
||||
let mut x = 0;
|
||||
|
||||
// Safety: everything here is unsafe. Test thoroughly.
|
||||
unsafe {
|
||||
while count > 1 && ksize > 1 {
|
||||
x = 0;
|
||||
while x as isize <= width - 4 * lanes as isize {
|
||||
let sptr = src[y + 1].add(x);
|
||||
let mut s0 = vload(sptr);
|
||||
let mut s1 = vload(sptr.add(lanes));
|
||||
let mut s2 = vload(sptr.add(2 * lanes));
|
||||
let mut s3 = vload(sptr.add(3 * lanes));
|
||||
|
||||
for k in 2..ksize {
|
||||
let sptr = src[y + k].add(x);
|
||||
s0 = Op::apply::<S>(s0, vload(sptr));
|
||||
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
|
||||
s2 = Op::apply::<S>(s2, vload(sptr.add(2 * lanes)));
|
||||
s3 = Op::apply::<S>(s3, vload(sptr.add(3 * lanes)));
|
||||
}
|
||||
|
||||
// Row 1
|
||||
{
|
||||
let sptr = src[y].add(x);
|
||||
let s0 = Op::apply(s0, vload(sptr));
|
||||
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
|
||||
let s2 = Op::apply(s2, vload(sptr.add(2 * lanes)));
|
||||
let s3 = Op::apply(s3, vload(sptr.add(3 * lanes)));
|
||||
vstore_unaligned(dst.add(x), s0);
|
||||
vstore_unaligned(dst.add(x + lanes), s1);
|
||||
vstore_unaligned(dst.add(x + 2 * lanes), s2);
|
||||
vstore_unaligned(dst.add(x + 3 * lanes), s3);
|
||||
}
|
||||
|
||||
// Row 2
|
||||
{
|
||||
let sptr = src[y + ksize].add(x);
|
||||
let s0 = Op::apply(s0, vload(sptr));
|
||||
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
|
||||
let s2 = Op::apply(s2, vload(sptr.add(2 * lanes)));
|
||||
let s3 = Op::apply(s3, vload(sptr.add(3 * lanes)));
|
||||
vstore_unaligned(dst.add(dst_step + x), s0);
|
||||
vstore_unaligned(dst.add(dst_step + x + lanes), s1);
|
||||
vstore_unaligned(dst.add(dst_step + x + 2 * lanes), s2);
|
||||
vstore_unaligned(dst.add(dst_step + x + 3 * lanes), s3);
|
||||
}
|
||||
x += 4 * lanes;
|
||||
}
|
||||
if x as isize <= width - 2 * lanes as isize {
|
||||
let sptr = src[y + 1].add(x);
|
||||
let mut s0 = vload(sptr);
|
||||
let mut s1 = vload(sptr.add(lanes));
|
||||
|
||||
for k in 2..ksize {
|
||||
let sptr = src[y + k].add(x);
|
||||
s0 = Op::apply::<S>(s0, vload(sptr));
|
||||
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
|
||||
}
|
||||
|
||||
// Row 1
|
||||
{
|
||||
let sptr = src[y].add(x);
|
||||
let s0 = Op::apply(s0, vload(sptr));
|
||||
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
|
||||
vstore_unaligned(dst.add(x), s0);
|
||||
vstore_unaligned(dst.add(x + lanes), s1);
|
||||
}
|
||||
|
||||
// Row 2
|
||||
{
|
||||
let sptr = src[y + ksize].add(x);
|
||||
let s0 = Op::apply(s0, vload(sptr));
|
||||
let s1 = Op::apply(s1, vload(sptr.add(lanes)));
|
||||
vstore_unaligned(dst.add(dst_step + x), s0);
|
||||
vstore_unaligned(dst.add(dst_step + x + lanes), s1);
|
||||
}
|
||||
x += 2 * lanes;
|
||||
}
|
||||
if x as isize <= width - lanes as isize {
|
||||
let mut s0 = vload(src[y + 1].add(x));
|
||||
for k in 2..ksize {
|
||||
s0 = Op::apply::<S>(s0, vload(src[y + k].add(x)));
|
||||
}
|
||||
// Row 1
|
||||
{
|
||||
let sptr = src[y].add(x);
|
||||
vstore_unaligned(dst.add(x), Op::apply(s0, vload(sptr)));
|
||||
}
|
||||
|
||||
// Row 2
|
||||
{
|
||||
let sptr = src[y + ksize].add(x);
|
||||
let s0 = Op::apply(s0, vload(sptr));
|
||||
vstore_unaligned(dst.add(dst_step + x), s0);
|
||||
}
|
||||
x += lanes;
|
||||
}
|
||||
if x as isize <= width - lanes as isize / 2 {
|
||||
let mut s0 = vload_low(src[y + 1].add(x));
|
||||
for k in 2..ksize {
|
||||
s0 = Op::apply::<S>(s0, vload_low(src[y + k].add(x)));
|
||||
}
|
||||
// Row 1
|
||||
{
|
||||
let sptr = src[y].add(x);
|
||||
let s0 = Op::apply(s0, vload_low(sptr));
|
||||
vstore_low(dst.add(x), s0);
|
||||
}
|
||||
|
||||
// Row 2
|
||||
{
|
||||
let sptr = src[y + ksize].add(x);
|
||||
let s0 = Op::apply(s0, vload_low(sptr));
|
||||
vstore_low(dst.add(dst_step + x), s0);
|
||||
}
|
||||
x += lanes / 2;
|
||||
}
|
||||
|
||||
count -= 2;
|
||||
dst = dst.add(dst_step * 2);
|
||||
y += 2;
|
||||
}
|
||||
|
||||
while count > 0 {
|
||||
x = 0;
|
||||
while x as isize <= width - 4 * lanes as isize {
|
||||
let sptr = src[y].add(x);
|
||||
let mut s0 = vload(sptr);
|
||||
let mut s1 = vload(sptr.add(lanes));
|
||||
let mut s2 = vload(sptr.add(2 * lanes));
|
||||
let mut s3 = vload(sptr.add(3 * lanes));
|
||||
|
||||
for k in 1..ksize {
|
||||
let sptr = src[y + k].add(x);
|
||||
s0 = Op::apply::<S>(s0, vload(sptr));
|
||||
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
|
||||
s2 = Op::apply::<S>(s2, vload(sptr.add(2 * lanes)));
|
||||
s3 = Op::apply::<S>(s3, vload(sptr.add(3 * lanes)));
|
||||
}
|
||||
|
||||
vstore_unaligned(dst.add(x), s0);
|
||||
vstore_unaligned(dst.add(x + lanes), s1);
|
||||
vstore_unaligned(dst.add(x + 2 * lanes), s2);
|
||||
vstore_unaligned(dst.add(x + 3 * lanes), s3);
|
||||
|
||||
x += 4 * lanes;
|
||||
}
|
||||
if x as isize <= width - 2 * lanes as isize {
|
||||
let sptr = src[y].add(x);
|
||||
let mut s0 = vload(sptr);
|
||||
let mut s1 = vload(sptr.add(lanes));
|
||||
|
||||
for k in 1..ksize {
|
||||
let sptr = src[y + k].add(x);
|
||||
s0 = Op::apply::<S>(s0, vload(sptr));
|
||||
s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));
|
||||
}
|
||||
|
||||
vstore_unaligned(dst.add(x), s0);
|
||||
vstore_unaligned(dst.add(x + lanes), s1);
|
||||
x += 2 * lanes;
|
||||
}
|
||||
if x as isize <= width - lanes as isize {
|
||||
let mut s0 = vload(src[y].add(x));
|
||||
|
||||
for k in 1..ksize {
|
||||
s0 = Op::apply::<S>(s0, vload(src[y + k].add(x)));
|
||||
}
|
||||
|
||||
vstore_unaligned(dst.add(x), s0);
|
||||
x += lanes;
|
||||
}
|
||||
if x as isize <= width - lanes as isize / 2 {
|
||||
let mut s0 = vload_low(src[y].add(x));
|
||||
|
||||
for k in 1..ksize {
|
||||
s0 = Op::apply::<S>(s0, vload_low(src[y + k].add(x)));
|
||||
}
|
||||
|
||||
vstore_low(dst.add(x), s0);
|
||||
x += lanes / 2;
|
||||
}
|
||||
|
||||
count -= 1;
|
||||
dst = dst.add(dst_step);
|
||||
y += 1;
|
||||
}
|
||||
}
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MorphColumnFilter<T: Scalar, Op: MorphOperator<T>, VecOp: VecColumn<T>> {
|
||||
pub ksize: usize,
|
||||
pub anchor: usize,
|
||||
vec: VecOp,
|
||||
_t: PhantomData<T>,
|
||||
_op: PhantomData<Op>,
|
||||
}
|
||||
|
||||
impl<T: Scalar, Op: MorphOperator<T>, VecOp: VecColumn<T>> MorphColumnFilter<T, Op, VecOp> {
|
||||
pub fn new(ksize: usize, anchor: usize) -> Self {
|
||||
let vec = VecOp::new(ksize, anchor);
|
||||
Self {
|
||||
ksize,
|
||||
anchor,
|
||||
vec,
|
||||
_t: PhantomData,
|
||||
_op: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply<S: Simd>(
|
||||
&self,
|
||||
|
||||
src: &[*const T],
|
||||
dst: &mut [T],
|
||||
dst_step: usize,
|
||||
mut count: usize,
|
||||
width: usize,
|
||||
) {
|
||||
let ksize = self.ksize;
|
||||
let x0 = self.vec.apply::<S>(src, dst, dst_step, count, width);
|
||||
let width = width as isize;
|
||||
|
||||
let mut d = 0;
|
||||
let mut x;
|
||||
let mut y = 0;
|
||||
|
||||
let slice = |row: *const T| unsafe { slice::from_raw_parts(row, width as usize) };
|
||||
|
||||
while ksize > 1 && count > 1 {
|
||||
x = x0;
|
||||
|
||||
while x as isize <= width - 4 {
|
||||
let row = slice(src[y + 1]);
|
||||
let mut s0 = row[x];
|
||||
let mut s1 = row[x + 1];
|
||||
let mut s2 = row[x + 2];
|
||||
let mut s3 = row[x + 3];
|
||||
|
||||
for k in 2..ksize {
|
||||
let row = slice(src[y + k]);
|
||||
s0 = Op::apply(s0, row[x]);
|
||||
s1 = Op::apply(s1, row[x + 1]);
|
||||
s2 = Op::apply(s2, row[x + 2]);
|
||||
s3 = Op::apply(s3, row[x + 3]);
|
||||
}
|
||||
|
||||
let row = slice(src[y]);
|
||||
dst[d + x] = Op::apply(s0, row[x]);
|
||||
dst[d + x + 1] = Op::apply(s1, row[x + 1]);
|
||||
dst[d + x + 2] = Op::apply(s2, row[x + 2]);
|
||||
dst[d + x + 3] = Op::apply(s3, row[x + 3]);
|
||||
|
||||
let row = slice(src[y + ksize]);
|
||||
dst[d + dst_step + x] = Op::apply(s0, row[x]);
|
||||
dst[d + dst_step + x + 1] = Op::apply(s1, row[x + 1]);
|
||||
dst[d + dst_step + x + 2] = Op::apply(s2, row[x + 2]);
|
||||
dst[d + dst_step + x + 3] = Op::apply(s3, row[x + 3]);
|
||||
|
||||
x += 4;
|
||||
}
|
||||
while (x as isize) < width {
|
||||
let mut s0 = slice(src[y + 1])[x];
|
||||
for k in 2..ksize {
|
||||
s0 = Op::apply(s0, slice(src[y + k])[x]);
|
||||
}
|
||||
dst[d + x] = Op::apply(s0, slice(src[y])[x]);
|
||||
dst[d + dst_step + x] = Op::apply(s0, slice(src[y + ksize])[x]);
|
||||
|
||||
x += 1;
|
||||
}
|
||||
|
||||
count -= 2;
|
||||
d += 2 * dst_step;
|
||||
y += 2;
|
||||
}
|
||||
|
||||
while count > 0 {
|
||||
x = x0;
|
||||
|
||||
while x as isize <= width - 4 {
|
||||
let row = slice(src[y]);
|
||||
let mut s0 = row[x];
|
||||
let mut s1 = row[x + 1];
|
||||
let mut s2 = row[x + 2];
|
||||
let mut s3 = row[x + 3];
|
||||
|
||||
for k in 1..ksize {
|
||||
let row = slice(src[y + k]);
|
||||
s0 = Op::apply(s0, row[x]);
|
||||
s1 = Op::apply(s1, row[x + 1]);
|
||||
s2 = Op::apply(s2, row[x + 2]);
|
||||
s3 = Op::apply(s3, row[x + 3]);
|
||||
}
|
||||
|
||||
dst[d + x] = s0;
|
||||
dst[d + x + 1] = s1;
|
||||
dst[d + x + 2] = s2;
|
||||
dst[d + x + 3] = s3;
|
||||
|
||||
x += 4;
|
||||
}
|
||||
while (x as isize) < width {
|
||||
let mut s0 = slice(src[y])[x];
|
||||
for k in 1..ksize {
|
||||
s0 = Op::apply(s0, slice(src[y + k])[x]);
|
||||
}
|
||||
|
||||
dst[d + x] = s0;
|
||||
|
||||
x += 1;
|
||||
}
|
||||
|
||||
count -= 1;
|
||||
d += dst_step;
|
||||
y += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecFilter<T: Scalar> {
|
||||
fn apply<S: Simd>(src: &[*const T], nz: usize, dst: &mut [T], width: usize) -> usize;
|
||||
}
|
||||
|
||||
pub struct MorphVec<T: Scalar, Op: VecMorphOperator<T>>(PhantomData<(T, Op)>);
|
||||
|
||||
impl<T: Scalar, Op: VecMorphOperator<T>> VecFilter<T> for MorphVec<T, Op> {
|
||||
fn apply<S: Simd>(src: &[*const T], nz: usize, dst: &mut [T], width: usize) -> usize {
|
||||
let dst = dst.as_mut_ptr();
|
||||
let mut i = 0;
|
||||
let lanes = T::lanes::<S>();
|
||||
let width = width as isize;
|
||||
|
||||
// Safety: everything here is unsafe. Test thoroughly.
|
||||
unsafe {
|
||||
while i as isize <= width - 4 * lanes as isize {
|
||||
let sptr = src[0].add(i);
|
||||
let mut s0 = vload_unaligned(sptr);
|
||||
let mut s1 = vload_unaligned(sptr.add(lanes));
|
||||
let mut s2 = vload_unaligned(sptr.add(2 * lanes));
|
||||
let mut s3 = vload_unaligned(sptr.add(3 * lanes));
|
||||
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
|
||||
s0 = Op::apply::<S>(s0, vload_unaligned(sptr));
|
||||
s1 = Op::apply::<S>(s1, vload_unaligned(sptr.add(lanes)));
|
||||
s2 = Op::apply::<S>(s2, vload_unaligned(sptr.add(2 * lanes)));
|
||||
s3 = Op::apply::<S>(s3, vload_unaligned(sptr.add(3 * lanes)));
|
||||
}
|
||||
vstore_unaligned(dst.add(i), s0);
|
||||
vstore_unaligned(dst.add(i + lanes), s1);
|
||||
vstore_unaligned(dst.add(i + 2 * lanes), s2);
|
||||
vstore_unaligned(dst.add(i + 3 * lanes), s3);
|
||||
i += 4 * lanes;
|
||||
}
|
||||
if i as isize <= width - 2 * lanes as isize {
|
||||
let sptr = src[0].add(i);
|
||||
let mut s0 = vload_unaligned(sptr);
|
||||
let mut s1 = vload_unaligned(sptr.add(lanes));
|
||||
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
|
||||
s0 = Op::apply::<S>(s0, vload_unaligned(sptr));
|
||||
s1 = Op::apply::<S>(s1, vload_unaligned(sptr.add(lanes)));
|
||||
}
|
||||
vstore_unaligned(dst.add(i), s0);
|
||||
vstore_unaligned(dst.add(i + lanes), s1);
|
||||
i += 2 * lanes;
|
||||
}
|
||||
if i as isize <= width - lanes as isize {
|
||||
let mut s0 = vload_unaligned(src[0].add(i));
|
||||
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
|
||||
s0 = Op::apply::<S>(s0, vload_unaligned(sptr));
|
||||
}
|
||||
vstore_unaligned(dst.add(i), s0);
|
||||
i += lanes;
|
||||
}
|
||||
if i as isize <= width - lanes as isize / 2 {
|
||||
let mut s = vload_low(src[0].add(i));
|
||||
for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {
|
||||
s = Op::apply::<S>(s, vload_low(sptr));
|
||||
}
|
||||
vstore_low(dst.add(i), s);
|
||||
i += lanes / 2;
|
||||
}
|
||||
}
|
||||
i
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MorphFilter<T: Scalar, Op: MorphOperator<T>, VecOp: VecFilter<T>> {
|
||||
pub ksize: Size,
|
||||
pub anchor: Point,
|
||||
coords: Vec<Point>,
|
||||
ptrs: Vec<*const T>,
|
||||
_op: PhantomData<(Op, VecOp)>,
|
||||
}
|
||||
|
||||
impl<T: Scalar, Op: MorphOperator<T>, VecOp: VecFilter<T>> MorphFilter<T, Op, VecOp> {
|
||||
pub fn new<B: Element>(kernel: &[B], ksize: Size, anchor: Point) -> Self {
|
||||
let coords = process_2d_kernel(kernel, ksize);
|
||||
let ptrs = vec![null(); coords.len()];
|
||||
|
||||
Self {
|
||||
ksize,
|
||||
anchor,
|
||||
coords,
|
||||
ptrs,
|
||||
_op: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn apply<S: Simd>(
|
||||
&mut self,
|
||||
|
||||
src: &[*const T],
|
||||
dst: &mut [T],
|
||||
dst_step: usize,
|
||||
mut count: usize,
|
||||
width: usize,
|
||||
ch: usize,
|
||||
) {
|
||||
let nz = self.coords.len();
|
||||
let width = (width * ch) as isize;
|
||||
let pt = &self.coords;
|
||||
let kp = &mut self.ptrs;
|
||||
|
||||
let mut dst_off = 0;
|
||||
let mut src_off = 0;
|
||||
|
||||
let slice = |ptr: *const T| unsafe { slice::from_raw_parts(ptr, width as usize) };
|
||||
|
||||
unsafe {
|
||||
while count > 0 {
|
||||
for k in 0..nz {
|
||||
kp[k] = src[src_off + pt[k].y].add(pt[k].x * ch);
|
||||
}
|
||||
|
||||
let mut i = VecOp::apply::<S>(kp, nz, &mut dst[dst_off..], width as usize);
|
||||
while i as isize <= width - 4 {
|
||||
let sptr = slice(kp[0].add(i));
|
||||
let mut s0 = sptr[0];
|
||||
let mut s1 = sptr[1];
|
||||
let mut s2 = sptr[2];
|
||||
let mut s3 = sptr[3];
|
||||
|
||||
for sptr in kp[1..nz].iter().map(|sptr| slice(sptr.add(i))) {
|
||||
s0 = Op::apply(s0, sptr[0]);
|
||||
s1 = Op::apply(s1, sptr[1]);
|
||||
s2 = Op::apply(s2, sptr[2]);
|
||||
s3 = Op::apply(s3, sptr[3]);
|
||||
}
|
||||
|
||||
dst[dst_off + i] = s0;
|
||||
dst[dst_off + i + 1] = s1;
|
||||
dst[dst_off + i + 2] = s2;
|
||||
dst[dst_off + i + 3] = s3;
|
||||
i += 4;
|
||||
}
|
||||
for i in i..width as usize {
|
||||
let mut s0 = *kp[0].add(i);
|
||||
for v in kp[1..nz].iter().map(|sptr| *sptr.add(i)) {
|
||||
s0 = Op::apply(s0, v);
|
||||
}
|
||||
dst[dst_off + i] = s0;
|
||||
}
|
||||
|
||||
count -= 1;
|
||||
dst_off += dst_step;
|
||||
src_off += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_2d_kernel<B: Element>(kernel: &[B], ksize: Size) -> Vec<Point> {
|
||||
let Size { width, height } = ksize;
|
||||
|
||||
let mut nz = kernel.iter().filter(|it| it.to_bool()).count();
|
||||
if nz == 0 {
|
||||
nz = 1;
|
||||
}
|
||||
|
||||
let mut coords = vec![Point::new(0, 0); nz];
|
||||
let mut k = 0;
|
||||
|
||||
for y in 0..height {
|
||||
let krow = &kernel[y * width..];
|
||||
for (x, _) in krow[..width].iter().enumerate().filter(|it| it.1.to_bool()) {
|
||||
coords[k] = Point::new(x, y);
|
||||
k += 1;
|
||||
}
|
||||
}
|
||||
|
||||
coords
|
||||
}
|
||||
@@ -0,0 +1,415 @@
|
||||
use std::{fmt::Debug, ptr::null_mut};
|
||||
|
||||
use burn_tensor::Shape;
|
||||
use bytemuck::{Zeroable, cast_slice, cast_slice_mut};
|
||||
use macerator::{Simd, VOrd, Vector};
|
||||
|
||||
use crate::{BorderType, Point, Size};
|
||||
|
||||
use super::filter::{
|
||||
MorphColumnFilter, MorphColumnVec, MorphFilter, MorphOperator, MorphRowFilter, MorphRowVec,
|
||||
MorphVec, VecMorphOperator,
|
||||
};
|
||||
|
||||
pub type RowFilter<T, Op> = MorphRowFilter<T, Op, MorphRowVec<T, Op>>;
|
||||
pub type ColFilter<T, Op> = MorphColumnFilter<T, Op, MorphColumnVec<T, Op>>;
|
||||
pub type Filter2D<T, Op> = MorphFilter<T, Op, MorphVec<T, Op>>;
|
||||
|
||||
pub enum Filter<T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> {
|
||||
Separable {
|
||||
row_filter: RowFilter<T, Op>,
|
||||
col_filter: ColFilter<T, Op>,
|
||||
},
|
||||
Fallback(Filter2D<T, Op>),
|
||||
}
|
||||
|
||||
pub struct FilterEngine<S: Simd, T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> {
|
||||
/// Vector aligned ring buffer to serve as intermediate, since image isn't always aligned
|
||||
ring_buf: Vec<Vector<S, T>>,
|
||||
/// Vector aligned row buffer to serve as intermediate, since image isn't always aligned
|
||||
src_row: Vec<Vector<S, T>>,
|
||||
const_border_value: Vec<T>,
|
||||
const_border_row: Vec<Vector<S, T>>,
|
||||
border_table: Vec<usize>,
|
||||
/// Pointers to each row offset in the ring buffer
|
||||
rows: Vec<*const T>,
|
||||
|
||||
filter: Filter<T, Op>,
|
||||
|
||||
ksize: Size,
|
||||
anchor: Point,
|
||||
dx1: usize,
|
||||
dx2: usize,
|
||||
row_count: usize,
|
||||
dst_y: usize,
|
||||
start_y: usize,
|
||||
start_y_0: usize,
|
||||
end_y: usize,
|
||||
|
||||
max_width: usize,
|
||||
buf_step: usize,
|
||||
width: usize,
|
||||
height: usize,
|
||||
border_type: BorderType,
|
||||
}
|
||||
|
||||
impl<S: Simd, T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> FilterEngine<S, T, Op> {
|
||||
fn resize_ring_buf(&mut self, size: usize) {
|
||||
let actual = size.div_ceil(T::lanes::<S>());
|
||||
self.ring_buf.resize(actual, Zeroable::zeroed());
|
||||
}
|
||||
fn resize_src_row(&mut self, size: usize) {
|
||||
let actual = size.div_ceil(T::lanes::<S>());
|
||||
self.src_row.resize(actual, Zeroable::zeroed());
|
||||
}
|
||||
fn is_separable(&self) -> bool {
|
||||
matches!(self.filter, Filter::Separable { .. })
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Simd, T: VOrd + Debug, Op: MorphOperator<T> + VecMorphOperator<T>> FilterEngine<S, T, Op> {
|
||||
pub fn new(
|
||||
filter: Filter<T, Op>,
|
||||
border_type: BorderType,
|
||||
border_value: &[T],
|
||||
ch: usize,
|
||||
) -> Self {
|
||||
let (ksize, anchor) = match &filter {
|
||||
Filter::Separable {
|
||||
row_filter,
|
||||
col_filter,
|
||||
} => {
|
||||
let ksize = Size::new(row_filter.ksize, col_filter.ksize);
|
||||
let anchor = Point::new(row_filter.anchor, col_filter.anchor);
|
||||
(ksize, anchor)
|
||||
}
|
||||
Filter::Fallback(f) => (f.ksize, f.anchor),
|
||||
};
|
||||
|
||||
let mut border_table = Vec::new();
|
||||
let border_length = (ksize.width - 1).max(1);
|
||||
let mut const_border_value = Vec::new();
|
||||
if matches!(border_type, BorderType::Constant) {
|
||||
const_border_value = vec![Zeroable::zeroed(); border_length * ch];
|
||||
for elem in cast_slice_mut::<_, T>(&mut const_border_value).chunks_exact_mut(ch) {
|
||||
elem.copy_from_slice(border_value);
|
||||
}
|
||||
} else {
|
||||
border_table = vec![0; border_length * ch];
|
||||
}
|
||||
|
||||
Self {
|
||||
ring_buf: Default::default(),
|
||||
src_row: Default::default(),
|
||||
rows: Default::default(),
|
||||
border_type,
|
||||
const_border_row: Default::default(),
|
||||
const_border_value,
|
||||
border_table,
|
||||
ksize,
|
||||
anchor,
|
||||
filter,
|
||||
max_width: 0,
|
||||
buf_step: 0,
|
||||
dx1: 0,
|
||||
dx2: 0,
|
||||
row_count: 0,
|
||||
dst_y: 0,
|
||||
start_y: 0,
|
||||
start_y_0: 0,
|
||||
end_y: 0,
|
||||
width: 0,
|
||||
height: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply(&mut self, tensor: &mut [T], src_shape: Shape) {
|
||||
let [_, w, ch] = src_shape.dims();
|
||||
let src_step = w * ch;
|
||||
self.start(src_shape);
|
||||
let y = self.start_y;
|
||||
self.proceed(
|
||||
&mut tensor[y * src_step..],
|
||||
src_step,
|
||||
self.end_y - self.start_y,
|
||||
ch,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn start(&mut self, shape: Shape) -> usize {
|
||||
let [height, width, ch] = shape.dims();
|
||||
|
||||
let max_buf_rows = (self.ksize.height + 3)
|
||||
.max(self.anchor.y)
|
||||
.max((self.ksize.height - self.anchor.y - 1) * 2 + 1);
|
||||
let k_offs = if !self.is_separable() {
|
||||
self.ksize.width - 1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let is_sep = self.is_separable();
|
||||
|
||||
if self.max_width < width || max_buf_rows != self.rows.len() {
|
||||
self.rows.resize(max_buf_rows, null_mut());
|
||||
self.max_width = self.max_width.max(width);
|
||||
self.resize_src_row((self.max_width + self.ksize.width - 1) * ch);
|
||||
|
||||
if matches!(self.border_type, BorderType::Constant) {
|
||||
self.const_border_row.resize(
|
||||
((self.max_width + self.ksize.width - 1) * ch).div_ceil(T::lanes::<S>()),
|
||||
Zeroable::zeroed(),
|
||||
);
|
||||
let mut n = self.const_border_value.len();
|
||||
let n1 = (self.max_width + self.ksize.width - 1) * ch;
|
||||
let const_val = &self.const_border_value;
|
||||
let dst = cast_slice_mut(&mut self.const_border_row);
|
||||
let t_dst = if is_sep {
|
||||
cast_slice_mut::<_, T>(&mut self.src_row)
|
||||
} else {
|
||||
alias_slice_mut(dst)
|
||||
};
|
||||
|
||||
for i in (0..n1).step_by(n) {
|
||||
n = n.min(n1 - i);
|
||||
t_dst[i..i + n].copy_from_slice(&const_val[..n]);
|
||||
}
|
||||
|
||||
if let Filter::Separable { row_filter, .. } = &self.filter {
|
||||
row_filter.apply::<S>(cast_slice(&self.src_row), dst, self.max_width, ch);
|
||||
}
|
||||
}
|
||||
|
||||
let max_buf_step =
|
||||
(self.max_width + k_offs).next_multiple_of(align_of::<Vector<S, T>>()) * ch;
|
||||
|
||||
self.resize_ring_buf(max_buf_step * self.rows.len());
|
||||
}
|
||||
|
||||
let const_val = &self.const_border_value;
|
||||
|
||||
self.buf_step = (width + k_offs).next_multiple_of(align_of::<Vector<S, T>>()) * ch;
|
||||
|
||||
self.dx1 = self.anchor.x;
|
||||
self.dx2 = self.ksize.width - self.anchor.x - 1;
|
||||
|
||||
if self.dx1 > 0 || self.dx2 > 0 {
|
||||
if matches!(self.border_type, BorderType::Constant) {
|
||||
let nr = if self.is_separable() {
|
||||
1
|
||||
} else {
|
||||
self.rows.len()
|
||||
};
|
||||
for i in 0..nr {
|
||||
let dst = if self.is_separable() {
|
||||
cast_slice_mut::<_, T>(&mut self.src_row)
|
||||
} else {
|
||||
&mut cast_slice_mut::<_, T>(&mut self.ring_buf)[self.buf_step * i..]
|
||||
};
|
||||
memcpy(dst, const_val, self.dx1 * ch);
|
||||
let right = (width + self.ksize.width - 1 - self.dx2) * ch;
|
||||
memcpy(&mut dst[right..], const_val, self.dx2 * ch);
|
||||
}
|
||||
} else {
|
||||
for i in 0..self.dx1 as isize {
|
||||
let p0 = border_interpolate(i - self.dx1 as isize, width, self.border_type);
|
||||
let p0 = p0 as usize * ch;
|
||||
for j in 0..ch {
|
||||
self.border_table[i as usize * ch + j] = p0 + j;
|
||||
}
|
||||
}
|
||||
for i in 0..self.dx2 {
|
||||
let p0 = border_interpolate((width + i) as isize, width, self.border_type)
|
||||
as usize
|
||||
* ch;
|
||||
for j in 0..ch {
|
||||
self.border_table[(i + self.dx1) * ch + j] = p0 + j;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.row_count = 0;
|
||||
self.dst_y = 0;
|
||||
self.start_y = 0;
|
||||
self.start_y_0 = 0;
|
||||
self.end_y = height;
|
||||
self.width = width;
|
||||
self.height = height;
|
||||
|
||||
self.start_y
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn proceed(
|
||||
&mut self,
|
||||
|
||||
src: &mut [T],
|
||||
src_step: usize,
|
||||
mut count: usize,
|
||||
ch: usize,
|
||||
) -> usize {
|
||||
let buf_rows = self.rows.len();
|
||||
let kheight = self.ksize.height;
|
||||
let kwidth = self.ksize.width;
|
||||
let ay = self.anchor.y as isize;
|
||||
let dx1 = self.dx1;
|
||||
let dx2 = self.dx2;
|
||||
let width1 = self.width + kwidth - 1;
|
||||
let btab = &self.border_table;
|
||||
let make_border = (dx1 > 0 || dx2 > 0) && !matches!(self.border_type, BorderType::Constant);
|
||||
let is_sep = self.is_separable();
|
||||
|
||||
count = count.min(self.remaining_input_rows());
|
||||
let mut dst_off = 0;
|
||||
let mut src_off = 0;
|
||||
let mut dy = 0;
|
||||
let mut i;
|
||||
let brows = &mut self.rows;
|
||||
|
||||
let src_row = cast_slice_mut::<_, T>(&mut self.src_row);
|
||||
let ring_buf = cast_slice_mut::<_, T>(&mut self.ring_buf);
|
||||
|
||||
loop {
|
||||
let dcount = buf_rows as isize - ay - self.start_y as isize - self.row_count as isize;
|
||||
let mut dcount = if dcount > 0 {
|
||||
dcount as usize
|
||||
} else {
|
||||
buf_rows + 1 - kheight
|
||||
};
|
||||
dcount = dcount.min(count);
|
||||
count -= dcount;
|
||||
|
||||
while dcount > 0 {
|
||||
let bi = (self.start_y - self.start_y_0 + self.row_count) % buf_rows;
|
||||
let brow = &mut ring_buf[bi * self.buf_step..];
|
||||
let row = if is_sep {
|
||||
&mut src_row[..]
|
||||
} else {
|
||||
alias_slice_mut(brow)
|
||||
};
|
||||
|
||||
if self.row_count + 1 > buf_rows {
|
||||
self.row_count -= 1;
|
||||
self.start_y += 1;
|
||||
}
|
||||
self.row_count += 1;
|
||||
|
||||
memcpy(
|
||||
&mut row[dx1 * ch..],
|
||||
&src[src_off..],
|
||||
(width1 - dx2 - dx1) * ch,
|
||||
);
|
||||
|
||||
if make_border {
|
||||
for i in 0..dx1 * ch {
|
||||
row[i] = src[src_off + btab[i]];
|
||||
}
|
||||
for i in 0..dx2 * ch {
|
||||
row[i + (width1 - dx2) * ch] = src[src_off + btab[i + dx1 * ch]];
|
||||
}
|
||||
}
|
||||
|
||||
if let Filter::Separable { row_filter, .. } = &self.filter {
|
||||
row_filter.apply::<S>(row, brow, self.width, ch);
|
||||
}
|
||||
|
||||
dcount -= 1;
|
||||
src_off += src_step;
|
||||
}
|
||||
|
||||
let max_i = buf_rows.min(self.height - (self.dst_y + dy) + (kheight - 1));
|
||||
i = 0;
|
||||
while i < max_i {
|
||||
let src_y = border_interpolate(
|
||||
(self.dst_y + dy + i) as isize - ay,
|
||||
self.height,
|
||||
self.border_type,
|
||||
);
|
||||
if src_y < 0 {
|
||||
brows[i] = self.const_border_row.as_ptr() as _;
|
||||
} else {
|
||||
if src_y as usize >= self.start_y + self.row_count {
|
||||
break;
|
||||
}
|
||||
let bi = (src_y as usize - self.start_y_0) % buf_rows;
|
||||
brows[i] = unsafe { ring_buf.as_ptr().add(bi * self.buf_step) };
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
if i < kheight {
|
||||
break;
|
||||
}
|
||||
i -= kheight - 1;
|
||||
match &mut self.filter {
|
||||
Filter::Separable { col_filter, .. } => {
|
||||
col_filter.apply::<S>(brows, &mut src[dst_off..], src_step, i, self.width * ch)
|
||||
}
|
||||
Filter::Fallback(filter) => {
|
||||
filter.apply::<S>(brows, &mut src[dst_off..], src_step, i, self.width, ch)
|
||||
}
|
||||
}
|
||||
|
||||
dst_off += src_step * i;
|
||||
dy += i;
|
||||
}
|
||||
|
||||
self.dst_y += dy;
|
||||
dy
|
||||
}
|
||||
|
||||
fn remaining_input_rows(&self) -> usize {
|
||||
self.end_y - self.start_y - self.row_count
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn memcpy<T: Copy>(to: &mut [T], from: &[T], len: usize) {
|
||||
to[..len].copy_from_slice(&from[..len]);
|
||||
}
|
||||
|
||||
/// Unsafely alias slice. Needed for the conditional slice targets that depend on the filter. The
|
||||
/// same slice shouldn't be used multiple times at once
|
||||
fn alias_slice_mut<'b, T>(slice: &mut [T]) -> &'b mut [T] {
|
||||
let ptr = slice.as_mut_ptr();
|
||||
let len = slice.len();
|
||||
unsafe { core::slice::from_raw_parts_mut(ptr, len) }
|
||||
}
|
||||
|
||||
fn border_interpolate(mut p: isize, len: usize, btype: BorderType) -> isize {
|
||||
let len = len as isize;
|
||||
if p < len && p >= 0 {
|
||||
return p;
|
||||
}
|
||||
match btype {
|
||||
BorderType::Constant => -1,
|
||||
BorderType::Replicate if p < 0 => 0,
|
||||
BorderType::Replicate => len - 1,
|
||||
BorderType::Reflect | BorderType::Reflect101 => {
|
||||
let delta = matches!(btype, BorderType::Reflect101) as isize;
|
||||
if len == 1 {
|
||||
return 0;
|
||||
}
|
||||
loop {
|
||||
if p < 0 {
|
||||
p = -p - 1 + delta;
|
||||
} else {
|
||||
p = len - 1 - (p - len) - delta;
|
||||
}
|
||||
if p < len && p >= 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
p
|
||||
}
|
||||
BorderType::Wrap => {
|
||||
if p < 0 {
|
||||
p -= ((p - len + 1) / len) * len;
|
||||
}
|
||||
if p >= len {
|
||||
p %= len;
|
||||
}
|
||||
p
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use burn_tensor::{
|
||||
BasicOps, Bool, DType, Element, Shape, Tensor, TensorData, backend::Backend, cast::ToElement,
|
||||
ops::BoolTensor,
|
||||
};
|
||||
use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};
|
||||
use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};
|
||||
use macerator::{Simd, VOrd};
|
||||
|
||||
use crate::{BorderType, MorphOptions, Point, Size};
|
||||
|
||||
use super::MinMax;
|
||||
|
||||
mod filter;
|
||||
mod filter_engine;
|
||||
|
||||
/// A morphology operation.
|
||||
/// TODO: Implement composite ops
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum MorphOp {
|
||||
Erode,
|
||||
Dilate,
|
||||
}
|
||||
|
||||
pub enum MorphKernel<B: Element> {
|
||||
Rect {
|
||||
size: Size,
|
||||
anchor: Point,
|
||||
},
|
||||
Other {
|
||||
kernel: Vec<B>,
|
||||
size: Size,
|
||||
anchor: Point,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn morph<B: Backend, K: BasicOps<B>>(
|
||||
input: Tensor<B, 3, K>,
|
||||
kernel: BoolTensor<B>,
|
||||
op: MorphOp,
|
||||
opts: MorphOptions<B, K>,
|
||||
) -> Tensor<B, 3, K> {
|
||||
let device = input.device();
|
||||
|
||||
let kernel = Tensor::<B, 2, Bool>::new(kernel);
|
||||
let kshape = kernel.shape().dims();
|
||||
let [kh, kw] = kshape;
|
||||
|
||||
let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();
|
||||
let is_rect = kernel.iter().all(|it| it.to_bool());
|
||||
let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));
|
||||
let iter = opts.iterations;
|
||||
let btype = opts.border_type;
|
||||
let bvalue = opts.border_value.map(|it| it.into_data());
|
||||
|
||||
let size = Size::new(kw, kh);
|
||||
let kernel = if is_rect {
|
||||
MorphKernel::Rect { size, anchor }
|
||||
} else {
|
||||
MorphKernel::Other {
|
||||
kernel,
|
||||
size,
|
||||
anchor,
|
||||
}
|
||||
};
|
||||
|
||||
let shape = input.shape();
|
||||
let data = input.into_data();
|
||||
match data.dtype {
|
||||
DType::F64 => {
|
||||
morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::F32 | DType::Flex32 => {
|
||||
morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(
|
||||
data.convert::<f32>(),
|
||||
shape,
|
||||
kernel,
|
||||
op,
|
||||
iter,
|
||||
btype,
|
||||
bvalue,
|
||||
&device,
|
||||
),
|
||||
DType::I64 => {
|
||||
morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::I32 => {
|
||||
morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::I16 => {
|
||||
morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),
|
||||
DType::U64 => {
|
||||
morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::U32 => {
|
||||
morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::U16 => {
|
||||
morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)
|
||||
}
|
||||
DType::U8 => morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device),
|
||||
DType::Bool => morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device),
|
||||
DType::QFloat(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(
|
||||
mut input: TensorData,
|
||||
shape: Shape,
|
||||
kernel: MorphKernel<B::BoolElem>,
|
||||
op: MorphOp,
|
||||
iter: usize,
|
||||
btype: BorderType,
|
||||
bvalue: Option<TensorData>,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, 3, K> {
|
||||
let data = input.as_mut_slice::<T>().unwrap();
|
||||
let bvalue = border_value(btype, bvalue, op, &shape);
|
||||
run_morph(data, shape, kernel, op, iter, btype, &bvalue);
|
||||
Tensor::from_data(input, device)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn morph_bool<B: Backend, K: BasicOps<B>>(
|
||||
mut input: TensorData,
|
||||
shape: Shape,
|
||||
kernel: MorphKernel<B::BoolElem>,
|
||||
op: MorphOp,
|
||||
iter: usize,
|
||||
btype: BorderType,
|
||||
bvalue: Option<TensorData>,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, 3, K> {
|
||||
let data = input.as_mut_slice::<bool>().unwrap();
|
||||
// SAFETY: Morph can't produce invalid boolean values
|
||||
let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };
|
||||
let bvalue = border_value(btype, bvalue, op, &shape);
|
||||
run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);
|
||||
Tensor::from_data(input, device)
|
||||
}
|
||||
|
||||
fn border_value<T: Element>(
|
||||
btype: BorderType,
|
||||
bvalue: Option<TensorData>,
|
||||
op: MorphOp,
|
||||
shape: &Shape,
|
||||
) -> Vec<T> {
|
||||
let [_, _, ch] = shape.dims();
|
||||
match (btype, bvalue) {
|
||||
(BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),
|
||||
(BorderType::Constant, None) => match op {
|
||||
MorphOp::Erode => vec![T::MAX; ch],
|
||||
MorphOp::Dilate => vec![T::MIN; ch],
|
||||
},
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn run_morph<T: VOrd + MinMax + Element, B: Element>(
|
||||
input: &mut [T],
|
||||
shape: Shape,
|
||||
kernel: MorphKernel<B>,
|
||||
op: MorphOp,
|
||||
iter: usize,
|
||||
btype: BorderType,
|
||||
bvalue: &[T],
|
||||
) {
|
||||
match op {
|
||||
MorphOp::Erode => {
|
||||
let filter = filter::<T, MinOp, B>(kernel);
|
||||
dispatch_morph(input, shape, filter, btype, bvalue, iter);
|
||||
}
|
||||
MorphOp::Dilate => {
|
||||
let filter = filter::<T, MaxOp, B>(kernel);
|
||||
dispatch_morph(input, shape, filter, btype, bvalue, iter);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(
|
||||
kernel: MorphKernel<B>,
|
||||
) -> Filter<T, Op> {
|
||||
match kernel {
|
||||
MorphKernel::Rect { size, anchor } => {
|
||||
let row_filter = RowFilter::new(size.width, anchor.x);
|
||||
let col_filter = ColFilter::new(size.height, anchor.y);
|
||||
Filter::Separable {
|
||||
row_filter,
|
||||
col_filter,
|
||||
}
|
||||
}
|
||||
MorphKernel::Other {
|
||||
kernel,
|
||||
size,
|
||||
anchor,
|
||||
} => {
|
||||
let filter = Filter2D::new(&kernel, size, anchor);
|
||||
Filter::Fallback(filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[macerator::with_simd]
|
||||
fn dispatch_morph<
|
||||
'a,
|
||||
S: Simd,
|
||||
T: VOrd + MinMax + Debug,
|
||||
Op: MorphOperator<T> + VecMorphOperator<T>,
|
||||
>(
|
||||
buffer: &'a mut [T],
|
||||
buffer_shape: Shape,
|
||||
filter: filter_engine::Filter<T, Op>,
|
||||
border_type: BorderType,
|
||||
border_value: &'a [T],
|
||||
iterations: usize,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let [_, _, ch] = buffer_shape.dims();
|
||||
let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);
|
||||
engine.apply(buffer, buffer_shape.clone());
|
||||
for _ in 1..iterations {
|
||||
engine.apply(buffer, buffer_shape.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Shape of the structuring element
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum KernelShape {
|
||||
/// Rectangular kernel
|
||||
Rect,
|
||||
/// Cross shaped kernel
|
||||
Cross,
|
||||
/// Ellipse shaped kernel
|
||||
Ellipse,
|
||||
}
|
||||
|
||||
/// Create a structuring element tensor for use with morphology ops
|
||||
pub fn create_structuring_element<B: Backend>(
|
||||
shape: KernelShape,
|
||||
ksize: Size,
|
||||
anchor: Option<Point>,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, 2, Bool> {
|
||||
fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {
|
||||
let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));
|
||||
let mut r = 0;
|
||||
let mut c = 0;
|
||||
let mut inv_r2 = 0.0;
|
||||
|
||||
if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {
|
||||
return vec![true; ksize.height * ksize.width];
|
||||
}
|
||||
|
||||
if shape == KernelShape::Ellipse {
|
||||
r = ksize.height / 2;
|
||||
c = ksize.width / 2;
|
||||
inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }
|
||||
}
|
||||
|
||||
let mut elem = vec![false; ksize.height * ksize.width];
|
||||
|
||||
for i in 0..ksize.height {
|
||||
let mut j1 = 0;
|
||||
let mut j2 = 0;
|
||||
if shape == KernelShape::Cross && i == anchor.y {
|
||||
j2 = ksize.width;
|
||||
} else if shape == KernelShape::Cross {
|
||||
j1 = anchor.x;
|
||||
j2 = j1 + 1;
|
||||
} else {
|
||||
let dy = i as isize - r as isize;
|
||||
if dy.abs() <= r as isize {
|
||||
let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())
|
||||
.round() as isize;
|
||||
j1 = (c as isize - dx).max(0) as usize;
|
||||
j2 = (c + dx as usize + 1).min(ksize.width);
|
||||
}
|
||||
}
|
||||
|
||||
for j in j1..j2 {
|
||||
elem[i * ksize.width + j] = true;
|
||||
}
|
||||
}
|
||||
elem
|
||||
}
|
||||
|
||||
let elem = create_kernel(shape, ksize, anchor);
|
||||
|
||||
let data = TensorData::new(elem, [ksize.height, ksize.width]);
|
||||
Tensor::from_data(data, device)
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
use crate::NmsOptions;
|
||||
use aligned_vec::{AVec, ConstAlign};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Int, Shape, Tensor, TensorData, backend::Backend};
|
||||
use macerator::{Scalar, Simd, Vector, vload};
|
||||
|
||||
/// Perform NMS on CPU using SIMD acceleration.
|
||||
///
|
||||
/// This implementation:
|
||||
/// 1. Sorts boxes by score (descending)
|
||||
/// 2. Iteratively selects the highest-scoring non-suppressed box
|
||||
/// 3. Suppresses all boxes with IoU > threshold using SIMD
|
||||
pub fn nms<B: Backend>(
|
||||
boxes: Tensor<B, 2>,
|
||||
scores: Tensor<B, 1>,
|
||||
options: NmsOptions,
|
||||
) -> Tensor<B, 1, Int> {
|
||||
let device = boxes.device();
|
||||
let [n_boxes, _] = boxes.shape().dims();
|
||||
if n_boxes == 0 {
|
||||
return Tensor::<B, 1, Int>::empty([0], &device);
|
||||
}
|
||||
|
||||
// Get raw data
|
||||
let boxes_data = boxes.to_data();
|
||||
let boxes_vec: Vec<f32> = boxes_data.to_vec().unwrap();
|
||||
|
||||
let scores_data = scores.to_data();
|
||||
let scores_vec: Vec<f32> = scores_data.to_vec().unwrap();
|
||||
|
||||
let keep = nms_vec(boxes_vec, scores_vec, options);
|
||||
let n_kept = keep.len();
|
||||
let indices_data = TensorData::new(keep, Shape::new([n_kept]));
|
||||
Tensor::<B, 1, Int>::from_data(indices_data, &device)
|
||||
}
|
||||
|
||||
/// Perform NMS on CPU using SIMD acceleration.
|
||||
fn nms_vec(boxes_vec: Vec<f32>, scores_vec: Vec<f32>, options: NmsOptions) -> Vec<i32> {
|
||||
let n_boxes = scores_vec.len();
|
||||
|
||||
if n_boxes == 0 {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Filter by score threshold first
|
||||
let mut filtered_indices = Vec::with_capacity(n_boxes);
|
||||
for (i, &score) in scores_vec.iter().enumerate() {
|
||||
if score >= options.score_threshold {
|
||||
filtered_indices.push(i); // original index
|
||||
}
|
||||
}
|
||||
|
||||
let n_filtered = filtered_indices.len();
|
||||
if n_filtered == 0 {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
filtered_indices.sort_by(|&a, &b| scores_vec[b].total_cmp(&scores_vec[a]));
|
||||
|
||||
const ALIGN: usize = 64;
|
||||
const FLOATS_PER_ALIGN: usize = ALIGN / size_of::<f32>(); // 16
|
||||
let stride = n_filtered.div_ceil(FLOATS_PER_ALIGN) * FLOATS_PER_ALIGN;
|
||||
let mut buf: AVec<f32, ConstAlign<64>> = AVec::with_capacity(ALIGN, stride * 5);
|
||||
buf.resize(stride * 5, 0.0);
|
||||
|
||||
let (x1s, rest) = buf.split_at_mut(stride);
|
||||
let (y1s, rest) = rest.split_at_mut(stride);
|
||||
let (x2s, rest) = rest.split_at_mut(stride);
|
||||
let (y2s, areas) = rest.split_at_mut(stride);
|
||||
|
||||
// Convert filtered boxes to SoA format
|
||||
for (j, &orig_idx) in filtered_indices.iter().enumerate() {
|
||||
let x1 = boxes_vec[orig_idx * 4];
|
||||
let y1 = boxes_vec[orig_idx * 4 + 1];
|
||||
let x2 = boxes_vec[orig_idx * 4 + 2];
|
||||
let y2 = boxes_vec[orig_idx * 4 + 3];
|
||||
x1s[j] = x1;
|
||||
y1s[j] = y1;
|
||||
x2s[j] = x2;
|
||||
y2s[j] = y2;
|
||||
areas[j] = (x2 - x1) * (y2 - y1);
|
||||
}
|
||||
|
||||
// Apply NMS with SIMD dispatch
|
||||
let mut suppressed = vec![false; stride];
|
||||
let mut keep = Vec::new();
|
||||
|
||||
for i in 0..n_filtered {
|
||||
if suppressed[i] {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Optimization to reduce inner loop comparisons
|
||||
suppressed[i] = true;
|
||||
keep.push(filtered_indices[i] as i32); // original index
|
||||
|
||||
if options.max_output_boxes > 0 && keep.len() >= options.max_output_boxes {
|
||||
break;
|
||||
}
|
||||
|
||||
// Suppress overlapping boxes using SIMD
|
||||
suppress_overlapping(
|
||||
x1s[i],
|
||||
y1s[i],
|
||||
x2s[i],
|
||||
y2s[i],
|
||||
areas[i],
|
||||
x1s,
|
||||
y1s,
|
||||
x2s,
|
||||
y2s,
|
||||
areas,
|
||||
&mut suppressed,
|
||||
stride,
|
||||
options.iou_threshold,
|
||||
);
|
||||
}
|
||||
|
||||
keep
|
||||
}
|
||||
|
||||
/// SIMD-accelerated suppression of overlapping boxes.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[inline(always)]
|
||||
#[macerator::with_simd]
|
||||
fn suppress_overlapping<'a, S: Simd>(
|
||||
ref_x1: f32,
|
||||
ref_y1: f32,
|
||||
ref_x2: f32,
|
||||
ref_y2: f32,
|
||||
ref_area: f32,
|
||||
x1s: &'a [f32],
|
||||
y1s: &'a [f32],
|
||||
x2s: &'a [f32],
|
||||
y2s: &'a [f32],
|
||||
areas: &'a [f32],
|
||||
suppressed: &'a mut [bool],
|
||||
n_boxes: usize, // stride, always multiple of lanes
|
||||
threshold: f32,
|
||||
) where
|
||||
'a: 'a,
|
||||
{
|
||||
let lanes = f32::lanes::<S>();
|
||||
|
||||
// Splat reference values
|
||||
let ref_x1_v: Vector<S, f32> = ref_x1.splat();
|
||||
let ref_y1_v: Vector<S, f32> = ref_y1.splat();
|
||||
let ref_x2_v: Vector<S, f32> = ref_x2.splat();
|
||||
let ref_y2_v: Vector<S, f32> = ref_y2.splat();
|
||||
let ref_area_v: Vector<S, f32> = ref_area.splat();
|
||||
let thresh_v: Vector<S, f32> = threshold.splat();
|
||||
let zero_v: Vector<S, f32> = 0.0f32.splat();
|
||||
|
||||
let mut i = 0;
|
||||
|
||||
let mut mask_buf = core::mem::MaybeUninit::<[bool; 16]>::uninit();
|
||||
// Process lanes boxes at a time with SIMD
|
||||
while i + lanes <= n_boxes {
|
||||
// Skip if all boxes in this chunk are already suppressed
|
||||
let all_suppressed = unsafe {
|
||||
match lanes {
|
||||
4 => *(suppressed.as_ptr().add(i) as *const u32) == 0x01010101,
|
||||
8 => *(suppressed.as_ptr().add(i) as *const u64) == 0x0101010101010101,
|
||||
16 => {
|
||||
*(suppressed.as_ptr().add(i) as *const u128)
|
||||
== 0x01010101010101010101010101010101
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
if !all_suppressed {
|
||||
let x1_v: Vector<S, f32> = unsafe { vload(x1s.as_ptr().add(i)) };
|
||||
let y1_v: Vector<S, f32> = unsafe { vload(y1s.as_ptr().add(i)) };
|
||||
let x2_v: Vector<S, f32> = unsafe { vload(x2s.as_ptr().add(i)) };
|
||||
let y2_v: Vector<S, f32> = unsafe { vload(y2s.as_ptr().add(i)) };
|
||||
let area_v: Vector<S, f32> = unsafe { vload(areas.as_ptr().add(i)) };
|
||||
|
||||
// Compute intersection coordinates
|
||||
let xx1 = ref_x1_v.max(x1_v);
|
||||
let yy1 = ref_y1_v.max(y1_v);
|
||||
let xx2 = ref_x2_v.min(x2_v);
|
||||
let yy2 = ref_y2_v.min(y2_v);
|
||||
|
||||
// Compute intersection area (clamp to 0 for non-overlapping)
|
||||
let w = (xx2 - xx1).max(zero_v);
|
||||
let h = (yy2 - yy1).max(zero_v);
|
||||
let inter = w * h;
|
||||
|
||||
// Compute IoU
|
||||
let union = ref_area_v + area_v - inter;
|
||||
let iou = inter / union;
|
||||
|
||||
// Get suppression mask (IoU > threshold)
|
||||
let suppress_mask = iou.gt(thresh_v);
|
||||
|
||||
// Extract mask to bool array and apply to suppressed
|
||||
// SAFETY: mask_store_as_bool writes exactly `lanes` bools, we only read 0..lanes
|
||||
unsafe { f32::mask_store_as_bool::<S>(mask_buf.as_mut_ptr().cast(), suppress_mask) };
|
||||
let mask_buf = unsafe { mask_buf.assume_init() };
|
||||
|
||||
for k in 0..lanes {
|
||||
if mask_buf[k] {
|
||||
suppressed[i + k] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i += lanes;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
#[cfg(feature = "ndarray")]
|
||||
mod ndarray {
|
||||
use crate::{BoolVisionOps, FloatVisionOps, IntVisionOps, QVisionOps, VisionBackend};
|
||||
use burn_ndarray::{
|
||||
FloatNdArrayElement, IntNdArrayElement, NdArray, NdArrayTensor, QuantElement, SharedArray,
|
||||
};
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolVisionOps
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
}
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntVisionOps
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
}
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatVisionOps
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
}
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QVisionOps for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
}
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> VisionBackend
|
||||
for NdArray<E, I, Q>
|
||||
where
|
||||
NdArrayTensor: From<SharedArray<E>>,
|
||||
NdArrayTensor: From<SharedArray<I>>,
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
mod tch {
|
||||
use crate::{BoolVisionOps, FloatVisionOps, IntVisionOps, QVisionOps, VisionBackend};
|
||||
use burn_tch::{LibTorch, TchElement};
|
||||
|
||||
impl<E: TchElement, Q: burn_tch::QuantElement> BoolVisionOps for LibTorch<E, Q> {}
|
||||
impl<E: TchElement, Q: burn_tch::QuantElement> IntVisionOps for LibTorch<E, Q> {}
|
||||
impl<E: TchElement, Q: burn_tch::QuantElement> FloatVisionOps for LibTorch<E, Q> {}
|
||||
impl<E: TchElement, Q: burn_tch::QuantElement> QVisionOps for LibTorch<E, Q> {}
|
||||
impl<E: TchElement, Q: burn_tch::QuantElement> VisionBackend for LibTorch<E, Q> {}
|
||||
}
|
||||
@@ -0,0 +1,624 @@
|
||||
//! Hardware Accelerated 4-connected, adapted from
|
||||
//! A. Hennequin, L. Lacassagne, L. Cabaret, Q. Meunier,
|
||||
//! "A new Direct Connected Component Labeling and Analysis Algorithms for GPUs",
|
||||
//! DASIP, 2018
|
||||
|
||||
use crate::{
|
||||
ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity,
|
||||
backends::cube::connected_components::stats_from_opts,
|
||||
};
|
||||
use burn_cubecl::{
|
||||
BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement, kernel,
|
||||
ops::{into_data_sync, numeric::zeros_client},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_tensor::{Shape, TensorMetadata, cast::ToElement, ops::IntTensorOps};
|
||||
use cubecl::{features::Plane, prelude::*};
|
||||
|
||||
use super::prefix_sum::prefix_sum;
|
||||
|
||||
const BLOCK_H: usize = 4;
|
||||
|
||||
#[cube]
|
||||
fn merge<I: Int>(labels: &Tensor<Atomic<I>>, label_1: u32, label_2: u32) {
|
||||
let mut label_1 = label_1 as usize;
|
||||
let mut label_2 = label_2 as usize;
|
||||
|
||||
while label_1 != label_2 && (label_1 != usize::cast_from(labels[label_1].load()) - 1) {
|
||||
label_1 = usize::cast_from(labels[label_1].load()) - 1;
|
||||
}
|
||||
while label_1 != label_2 && (label_2 != usize::cast_from(labels[label_2].load()) - 1) {
|
||||
label_2 = usize::cast_from(labels[label_2].load()) - 1;
|
||||
}
|
||||
while label_1 != label_2 {
|
||||
#[allow(clippy::manual_swap)]
|
||||
if label_1 < label_2 {
|
||||
let tmp = label_1;
|
||||
label_1 = label_2;
|
||||
label_2 = tmp;
|
||||
}
|
||||
let label_3 = usize::cast_from(labels[label_1].fetch_min(I::cast_from(label_2 + 1))) - 1;
|
||||
if label_1 == label_3 {
|
||||
label_1 = label_2;
|
||||
} else {
|
||||
label_1 = label_3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn start_distance(pixels: u32, tx: u32) -> u32 {
|
||||
(!(pixels << (32 - tx))).leading_zeros()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn end_distance(pixels: u32, tx: u32) -> u32 {
|
||||
(!(pixels >> (tx + 1))).find_first_set()
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(unconditional_panic, reason = "clippy thinks PLANE_DIM is always 2")]
|
||||
fn ballot_dyn(y: u32, pred: bool) -> u32 {
|
||||
let index = y % (PLANE_DIM / 32);
|
||||
plane_ballot(pred)[index as usize]
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn strip_labeling<I: Int, BT: CubePrimitive>(
|
||||
img: &Tensor<BT>,
|
||||
labels: &Tensor<Atomic<I>>,
|
||||
#[comptime] connectivity: Connectivity,
|
||||
) {
|
||||
let mut shared_pixels = SharedMemory::<u32>::new(BLOCK_H);
|
||||
|
||||
let y = ABSOLUTE_POS_Y;
|
||||
let rows = labels.shape(0) as u32;
|
||||
let cols = labels.shape(1) as u32;
|
||||
|
||||
if y >= rows {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let img_stride = img.stride(0) as u32;
|
||||
let labels_stride = labels.stride(0) as u32;
|
||||
|
||||
let img_line_base = y * img_stride + UNIT_POS_X;
|
||||
let labels_line_base = y * labels_stride + UNIT_POS_X;
|
||||
|
||||
let mut distance_y = 0u32;
|
||||
let mut distance_y_1 = 0;
|
||||
|
||||
for i in range_stepped(0, img.shape(1) as u32, PLANE_DIM) {
|
||||
let x = UNIT_POS_X + i;
|
||||
|
||||
if x < cols {
|
||||
let mut mask = 0xffffffffu32;
|
||||
let involved_cols = cols - i;
|
||||
if involved_cols < 32 {
|
||||
mask >>= 32 - involved_cols;
|
||||
}
|
||||
|
||||
let img_index = img_line_base + i;
|
||||
let labels_index = labels_line_base + i;
|
||||
|
||||
let p_y = bool::cast_from(img[img_index as usize]);
|
||||
|
||||
let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask;
|
||||
let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X);
|
||||
|
||||
if p_y && s_dist_y == 0 {
|
||||
labels[labels_index as usize].store(I::cast_from(
|
||||
labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1,
|
||||
));
|
||||
}
|
||||
|
||||
// Only needed pre-Volta, but we can't check that at present
|
||||
sync_cube();
|
||||
|
||||
if UNIT_POS_X == 0 {
|
||||
shared_pixels[UNIT_POS_Y as usize] = pixels_y;
|
||||
}
|
||||
|
||||
sync_cube();
|
||||
|
||||
// Requires if and not select, because `select` may execute the then branch even if the
|
||||
// condition is false (on non-CUDA backends), which can lead to OOB reads.
|
||||
let pixels_y_1 = if UNIT_POS_Y > 0 {
|
||||
shared_pixels[(UNIT_POS_Y - 1) as usize]
|
||||
} else {
|
||||
0u32.runtime()
|
||||
};
|
||||
|
||||
let p_y_1 = (pixels_y_1 >> UNIT_POS_X) & 1 != 0;
|
||||
let mut s_dist_y_1 = start_distance(pixels_y_1, UNIT_POS_X);
|
||||
|
||||
if UNIT_POS_X == 0 {
|
||||
s_dist_y = distance_y;
|
||||
s_dist_y_1 = distance_y_1;
|
||||
}
|
||||
|
||||
match connectivity {
|
||||
Connectivity::Four => {
|
||||
if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {
|
||||
let label_1 = labels_index - s_dist_y;
|
||||
let label_2 = labels_index - s_dist_y_1 - labels_stride;
|
||||
merge(labels, label_1, label_2);
|
||||
}
|
||||
}
|
||||
Connectivity::Eight => {
|
||||
let pixels_y_shifted = (pixels_y << 1) | (distance_y > 0) as u32;
|
||||
let pixels_y_1_shifted = (pixels_y_1 << 1) | (distance_y_1 > 0) as u32;
|
||||
|
||||
if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {
|
||||
let label_1 = labels_index - s_dist_y;
|
||||
let label_2 = labels_index - s_dist_y_1 - labels_stride;
|
||||
merge(labels, label_1, label_2);
|
||||
} else if p_y && s_dist_y == 0 && (pixels_y_1_shifted >> UNIT_POS_X) & 1 != 0 {
|
||||
let s_dist_y_1_prev = select(
|
||||
UNIT_POS_X == 0,
|
||||
distance_y_1 - 1,
|
||||
start_distance(pixels_y_1, UNIT_POS_X - 1),
|
||||
);
|
||||
let label_1 = labels_index;
|
||||
let label_2 = labels_index - labels_stride - 1 - s_dist_y_1_prev;
|
||||
merge(labels, label_1, label_2);
|
||||
} else if p_y_1 && s_dist_y_1 == 0 && (pixels_y_shifted >> UNIT_POS_X) & 1 != 0
|
||||
{
|
||||
let s_dist_y_prev = select(
|
||||
UNIT_POS_X == 0,
|
||||
distance_y - 1,
|
||||
start_distance(pixels_y, UNIT_POS_X - 1),
|
||||
);
|
||||
let label_1 = labels_index - 1 - s_dist_y_prev;
|
||||
let label_2 = labels_index - labels_stride;
|
||||
merge(labels, label_1, label_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {
|
||||
let label_1 = labels_index - s_dist_y;
|
||||
let label_2 = labels_index - s_dist_y_1 - labels_stride;
|
||||
merge(labels, label_1, label_2);
|
||||
}
|
||||
|
||||
let mut d = start_distance(pixels_y_1, 32);
|
||||
distance_y_1 = d + select(d == 32, distance_y_1, 0);
|
||||
d = start_distance(pixels_y, 32);
|
||||
distance_y = d + select(d == 32, distance_y, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn strip_merge<I: Int, BT: CubePrimitive>(
|
||||
img: &Tensor<BT>,
|
||||
labels: &Tensor<Atomic<I>>,
|
||||
#[comptime] connectivity: Connectivity,
|
||||
) {
|
||||
let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM;
|
||||
let y = (CUBE_POS_Y + 1) * BLOCK_H as u32;
|
||||
let x = plane_start_x + UNIT_POS_X;
|
||||
|
||||
let img_step = img.stride(0) as u32;
|
||||
let labels_step = labels.stride(0) as u32;
|
||||
let cols = img.shape(1) as u32;
|
||||
|
||||
if y < labels.shape(0) as u32 && x < labels.shape(1) as u32 {
|
||||
let mut mask = 0xffffffffu32;
|
||||
if cols - plane_start_x < 32 {
|
||||
mask >>= 32 - (cols - plane_start_x);
|
||||
}
|
||||
|
||||
let img_index = y * img_step + x;
|
||||
let labels_index = y * labels_step + x;
|
||||
|
||||
let img_index_up = img_index - img_step;
|
||||
let labels_index_up = labels_index - labels_step;
|
||||
|
||||
let p = bool::cast_from(img[img_index as usize]);
|
||||
let p_up = bool::cast_from(img[img_index_up as usize]);
|
||||
|
||||
let pixels = ballot_dyn(UNIT_POS_Z, p) & mask;
|
||||
let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask;
|
||||
|
||||
match connectivity {
|
||||
Connectivity::Four => {
|
||||
if p && p_up {
|
||||
let s_dist = start_distance(pixels, UNIT_POS_X);
|
||||
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
|
||||
if s_dist == 0 || s_dist_up == 0 {
|
||||
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
|
||||
}
|
||||
}
|
||||
}
|
||||
Connectivity::Eight => {
|
||||
let mut last_dist_vec = SharedMemory::<u32>::new(32usize);
|
||||
let mut last_dist_up_vec = SharedMemory::<u32>::new(32usize);
|
||||
|
||||
let s_dist = start_distance(pixels, UNIT_POS_X);
|
||||
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
|
||||
|
||||
if UNIT_POS_PLANE == PLANE_DIM - 1 {
|
||||
last_dist_vec[UNIT_POS_Z as usize] = start_distance(pixels, 32);
|
||||
last_dist_up_vec[UNIT_POS_Z as usize] = start_distance(pixels_up, 32);
|
||||
}
|
||||
|
||||
sync_cube();
|
||||
|
||||
if CUBE_POS_X == 0 || UNIT_POS_Z > 0 {
|
||||
let last_dist = if UNIT_POS_Z > 0 {
|
||||
last_dist_vec[(UNIT_POS_Z - 1) as usize]
|
||||
} else {
|
||||
0u32.runtime()
|
||||
};
|
||||
let last_dist_up = if UNIT_POS_Z > 0 {
|
||||
last_dist_up_vec[(UNIT_POS_Z - 1) as usize]
|
||||
} else {
|
||||
0u32.runtime()
|
||||
};
|
||||
|
||||
let p_prev =
|
||||
select(UNIT_POS_X > 0, (pixels >> (UNIT_POS_X - 1)) & 1, last_dist) != 0;
|
||||
let p_up_prev = select(
|
||||
UNIT_POS_X > 0,
|
||||
(pixels_up >> (UNIT_POS_X - 1)) & 1,
|
||||
last_dist_up,
|
||||
) != 0;
|
||||
|
||||
if p && p_up {
|
||||
let s_dist = start_distance(pixels, UNIT_POS_X);
|
||||
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
|
||||
if s_dist == 0 || s_dist_up == 0 {
|
||||
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
|
||||
}
|
||||
} else if p && p_up_prev && s_dist == 0 {
|
||||
let s_dist_up_prev = select(
|
||||
UNIT_POS_X == 0,
|
||||
last_dist_up - 1,
|
||||
start_distance(pixels_up, UNIT_POS_X - 1),
|
||||
);
|
||||
merge(labels, labels_index, labels_index_up - 1 - s_dist_up_prev);
|
||||
} else if p_prev && p_up && s_dist_up == 0 {
|
||||
let s_dist_prev = select(
|
||||
UNIT_POS_X == 0,
|
||||
last_dist - 1,
|
||||
start_distance(pixels, UNIT_POS_X - 1),
|
||||
);
|
||||
merge(labels, labels_index - 1 - s_dist_prev, labels_index_up);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn relabeling<I: Int, BT: CubePrimitive>(img: &Tensor<BT>, labels: &mut Tensor<I>) {
|
||||
let plane_start_x = CUBE_POS_X * CUBE_DIM_X;
|
||||
let y = ABSOLUTE_POS_Y;
|
||||
let x = plane_start_x + UNIT_POS_X;
|
||||
|
||||
let cols = labels.shape(1) as u32;
|
||||
let rows = labels.shape(0) as u32;
|
||||
let img_step = img.stride(0) as u32;
|
||||
let labels_step = labels.stride(0) as u32;
|
||||
|
||||
if x < cols && y < rows {
|
||||
let mut mask = 0xffffffffu32;
|
||||
if cols - plane_start_x < 32 {
|
||||
mask >>= 32 - (cols - plane_start_x);
|
||||
}
|
||||
|
||||
let img_index = y * img_step + x;
|
||||
let labels_index = y * labels_step + x;
|
||||
|
||||
let p = bool::cast_from(img[img_index as usize]);
|
||||
let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
|
||||
let s_dist = start_distance(pixels, UNIT_POS_X);
|
||||
let mut label = 0u32;
|
||||
|
||||
if p && s_dist == 0 {
|
||||
label = u32::cast_from(labels[labels_index as usize]) - 1;
|
||||
while label != u32::cast_from(labels[label as usize]) - 1 {
|
||||
label = u32::cast_from(labels[label as usize]) - 1;
|
||||
}
|
||||
}
|
||||
|
||||
label = plane_shuffle(label, UNIT_POS_X - s_dist);
|
||||
|
||||
if p {
|
||||
labels[labels_index as usize] = I::cast_from(label + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn analysis<I: Int, BT: CubePrimitive>(
|
||||
img: &Tensor<BT>,
|
||||
labels: &mut Tensor<I>,
|
||||
area: &mut Tensor<Atomic<I>>,
|
||||
top: &mut Tensor<Atomic<I>>,
|
||||
left: &mut Tensor<Atomic<I>>,
|
||||
right: &mut Tensor<Atomic<I>>,
|
||||
bottom: &mut Tensor<Atomic<I>>,
|
||||
max_label: &mut Tensor<Atomic<I>>,
|
||||
#[comptime] opts: ConnectedStatsOptions,
|
||||
) {
|
||||
let y = ABSOLUTE_POS_Y;
|
||||
let x = ABSOLUTE_POS_X;
|
||||
|
||||
let cols = labels.shape(1) as u32;
|
||||
let rows = labels.shape(0) as u32;
|
||||
let img_step = img.stride(0) as u32;
|
||||
let labels_step = labels.stride(0) as u32;
|
||||
|
||||
if x < cols && y < rows {
|
||||
let mut mask = 0xffffffffu32;
|
||||
if cols - CUBE_POS_X * CUBE_DIM_X < 32 {
|
||||
mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X);
|
||||
}
|
||||
|
||||
let img_index = y * img_step + x;
|
||||
let labels_index = y * labels_step + x;
|
||||
|
||||
let p = bool::cast_from(img[img_index as usize]);
|
||||
let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
|
||||
let s_dist = start_distance(pixels, UNIT_POS_X);
|
||||
let count = end_distance(pixels, UNIT_POS_X);
|
||||
let max_x = x + count - 1;
|
||||
|
||||
let mut label = 0u32;
|
||||
|
||||
if p && s_dist == 0 {
|
||||
label = u32::cast_from(labels[labels_index as usize]) - 1;
|
||||
while label != u32::cast_from(labels[label as usize]) - 1 {
|
||||
label = u32::cast_from(labels[label as usize]) - 1;
|
||||
}
|
||||
label += 1;
|
||||
|
||||
area[label as usize].fetch_add(I::cast_from(count));
|
||||
|
||||
if opts.bounds_enabled {
|
||||
left[label as usize].fetch_min(I::cast_from(x));
|
||||
top[label as usize].fetch_min(I::cast_from(y));
|
||||
right[label as usize].fetch_max(I::cast_from(max_x));
|
||||
bottom[label as usize].fetch_max(I::cast_from(y));
|
||||
}
|
||||
if comptime!(opts.max_label_enabled || opts.compact_labels) {
|
||||
max_label[0].fetch_max(I::cast_from(label));
|
||||
}
|
||||
}
|
||||
|
||||
label = plane_shuffle(label, UNIT_POS_X - s_dist);
|
||||
|
||||
if p {
|
||||
labels[labels_index as usize] = I::cast_from(label);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn compact_labels<I: Int>(
|
||||
labels: &mut Tensor<I>,
|
||||
remap: &Tensor<I>,
|
||||
max_label: &Tensor<Atomic<I>>,
|
||||
) {
|
||||
let x = ABSOLUTE_POS_X;
|
||||
let y = ABSOLUTE_POS_Y;
|
||||
|
||||
let labels_pos = y * labels.stride(0) as u32 + x;
|
||||
|
||||
if labels_pos as usize >= labels.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let label = u32::cast_from(labels[labels_pos as usize]);
|
||||
if label != 0 {
|
||||
let new_label = remap[label as usize];
|
||||
labels[labels_pos as usize] = new_label;
|
||||
max_label[0].fetch_max(new_label);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn compact_stats<I: Int>(
|
||||
area: &Tensor<I>,
|
||||
area_new: &mut Tensor<I>,
|
||||
top: &Tensor<I>,
|
||||
top_new: &mut Tensor<I>,
|
||||
left: &Tensor<I>,
|
||||
left_new: &mut Tensor<I>,
|
||||
right: &Tensor<I>,
|
||||
right_new: &mut Tensor<I>,
|
||||
bottom: &Tensor<I>,
|
||||
bottom_new: &mut Tensor<I>,
|
||||
remap: &Tensor<I>,
|
||||
) {
|
||||
let label = ABSOLUTE_POS_X;
|
||||
if label as usize >= remap.len() {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let area = area[label as usize];
|
||||
if area == I::new(0) {
|
||||
terminate!();
|
||||
}
|
||||
let new_label = u32::cast_from(remap[label as usize]);
|
||||
|
||||
area_new[new_label as usize] = area;
|
||||
// This should be gated but there's a problem with the Eq bound only being implemented for tuples
|
||||
// up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work.
|
||||
top_new[new_label as usize] = top[label as usize];
|
||||
left_new[new_label as usize] = left[label as usize];
|
||||
right_new[new_label as usize] = right[label as usize];
|
||||
bottom_new[new_label as usize] = bottom[label as usize];
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn hardware_accelerated<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
|
||||
img: CubeTensor<R>,
|
||||
stats_opt: ConnectedStatsOptions,
|
||||
connectivity: Connectivity,
|
||||
) -> Result<
|
||||
(
|
||||
CubeTensor<R>,
|
||||
ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>,
|
||||
),
|
||||
String,
|
||||
> {
|
||||
let client = img.client.clone();
|
||||
let device = img.device.clone();
|
||||
|
||||
if !client.properties().features.plane.contains(Plane::Ops) {
|
||||
return Err("Requires plane instructions".into());
|
||||
}
|
||||
|
||||
let props = &client.properties().hardware;
|
||||
|
||||
if props.plane_size_min == 32 && props.plane_size_max == 32 {
|
||||
return Err("Requires plane size of at least 32".into());
|
||||
}
|
||||
|
||||
// Somehow the kernel doesn't work on AMD and Apple Silicon.
|
||||
//
|
||||
// The check invalidates those, but probably not for the right reason.
|
||||
if props.plane_size_max != 32 {
|
||||
return Err("Requires plane size of at least 32".into());
|
||||
}
|
||||
|
||||
let [rows, cols] = img.meta.shape().dims();
|
||||
|
||||
let labels = zeros_client::<R>(client.clone(), device.clone(), img.shape(), I::dtype());
|
||||
|
||||
// Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32.
|
||||
// This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp
|
||||
// size at compile time. `REQUIRE_FULL_SUBGROUPS` or subgroup size controls are not supported
|
||||
// in wgpu.
|
||||
let warp_size = 32;
|
||||
let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H as u32);
|
||||
let cube_count = CubeCount::new_2d(1, (rows as u32).div_ceil(cube_dim.y));
|
||||
|
||||
unsafe {
|
||||
strip_labeling::launch_unchecked::<I, BT, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
img.as_tensor_arg(1),
|
||||
labels.as_tensor_arg(1),
|
||||
connectivity,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
let horizontal_warps = Ord::min((cols as u32).div_ceil(warp_size), 32);
|
||||
let cube_dim_merge = CubeDim::new_3d(warp_size, 1, horizontal_warps);
|
||||
let cube_count = CubeCount::new_2d(
|
||||
Ord::max((cols as u32 + warp_size * 30 - 1) / (warp_size * 31), 1),
|
||||
(rows as u32 - 1) / BLOCK_H as u32,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
strip_merge::launch_unchecked::<I, BT, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim_merge,
|
||||
img.as_tensor_arg(1),
|
||||
labels.as_tensor_arg(1),
|
||||
connectivity,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
let cube_count = CubeCount::new_2d(
|
||||
(cols as u32).div_ceil(cube_dim.x),
|
||||
(rows as u32).div_ceil(cube_dim.y),
|
||||
);
|
||||
|
||||
let mut stats = stats_from_opts(labels.clone(), stats_opt);
|
||||
|
||||
if stats_opt == ConnectedStatsOptions::none() {
|
||||
unsafe {
|
||||
relabeling::launch_unchecked::<I, BT, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
img.as_tensor_arg(1),
|
||||
labels.as_tensor_arg(1),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
} else {
|
||||
unsafe {
|
||||
analysis::launch_unchecked::<I, BT, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
img.as_tensor_arg(1),
|
||||
labels.as_tensor_arg(1),
|
||||
stats.area.as_tensor_arg(1),
|
||||
stats.top.as_tensor_arg(1),
|
||||
stats.left.as_tensor_arg(1),
|
||||
stats.right.as_tensor_arg(1),
|
||||
stats.bottom.as_tensor_arg(1),
|
||||
stats.max_label.as_tensor_arg(1),
|
||||
stats_opt,
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
if stats_opt.compact_labels {
|
||||
let max_label = CubeBackend::<R, F, I, BT>::int_max(stats.max_label);
|
||||
let max_label = into_data_sync::<R>(max_label);
|
||||
let max_label = ToElement::to_usize(&max_label.as_slice::<I>().unwrap()[0]);
|
||||
let sliced = kernel::slice::<R>(
|
||||
stats.area.clone(),
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
&[0..(max_label + 1).next_multiple_of(4)],
|
||||
);
|
||||
let relabel = prefix_sum::<R, I>(sliced);
|
||||
|
||||
let cube_dim = CubeDim::new_2d(32, 8);
|
||||
let cube_count = CubeCount::new_2d(
|
||||
(cols as u32).div_ceil(cube_dim.x),
|
||||
(rows as u32).div_ceil(cube_dim.y),
|
||||
);
|
||||
stats.max_label =
|
||||
zeros_client::<R>(client.clone(), device.clone(), Shape::new([1]), I::dtype());
|
||||
unsafe {
|
||||
compact_labels::launch_unchecked::<I, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
labels.as_tensor_arg(1),
|
||||
relabel.as_tensor_arg(1),
|
||||
stats.max_label.as_tensor_arg(1),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
let cube_dim = CubeDim::new_1d(256);
|
||||
let cube_count = CubeCount::new_1d((rows * cols).div_ceil(256) as u32);
|
||||
unsafe {
|
||||
compact_stats::launch_unchecked::<I, R>(
|
||||
&client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
stats.area.copy().as_tensor_arg(1),
|
||||
stats.area.as_tensor_arg(1),
|
||||
stats.top.copy().as_tensor_arg(1),
|
||||
stats.top.as_tensor_arg(1),
|
||||
stats.left.copy().as_tensor_arg(1),
|
||||
stats.left.as_tensor_arg(1),
|
||||
stats.right.copy().as_tensor_arg(1),
|
||||
stats.right.as_tensor_arg(1),
|
||||
stats.bottom.copy().as_tensor_arg(1),
|
||||
stats.bottom.as_tensor_arg(1),
|
||||
relabel.as_tensor_arg(1),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Ok((labels, stats))
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
mod hardware_accelerated;
|
||||
|
||||
/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops
|
||||
/// to really use it in a general case. Needs more work to use as a normal tensor method.
|
||||
mod prefix_sum;
|
||||
|
||||
use burn_cubecl::{
|
||||
BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement,
|
||||
ops::numeric::{full_client, zeros_client},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
pub use hardware_accelerated::*;
|
||||
|
||||
use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};
|
||||
|
||||
pub(crate) fn stats_from_opts<R, F, I, BT>(
|
||||
l: CubeTensor<R>,
|
||||
opts: ConnectedStatsOptions,
|
||||
) -> ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
let [height, width] = l.meta.shape().dims();
|
||||
let shape = Shape::new([height * width]);
|
||||
let zeros = || {
|
||||
zeros_client::<R>(
|
||||
l.client.clone(),
|
||||
l.device.clone(),
|
||||
shape.clone(),
|
||||
I::dtype(),
|
||||
)
|
||||
};
|
||||
let max = I::max_value();
|
||||
let max = || full_client::<R, I>(l.client.clone(), shape.clone(), l.device.clone(), max);
|
||||
let dummy = || {
|
||||
CubeTensor::new_contiguous(
|
||||
l.client.clone(),
|
||||
l.device.clone(),
|
||||
shape.clone(),
|
||||
l.handle.clone(),
|
||||
l.dtype,
|
||||
)
|
||||
};
|
||||
ConnectedStatsPrimitive {
|
||||
area: (opts != ConnectedStatsOptions::none())
|
||||
.then(zeros)
|
||||
.unwrap_or_else(dummy),
|
||||
left: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
|
||||
top: opts.bounds_enabled.then(max).unwrap_or_else(dummy),
|
||||
right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
|
||||
bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),
|
||||
max_label: zeros_client::<R>(
|
||||
l.client.clone(),
|
||||
l.device.clone(),
|
||||
Shape::new([1]),
|
||||
I::dtype(),
|
||||
),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
use burn_tensor::{Shape, TensorMetadata};
|
||||
use cubecl::prelude::*;
|
||||
|
||||
use burn_cubecl::{
|
||||
CubeRuntime, IntElement,
|
||||
ops::{
|
||||
numeric::{empty_device, zeros_client},
|
||||
reshape,
|
||||
},
|
||||
tensor::CubeTensor,
|
||||
};
|
||||
|
||||
const CUBE_SIZE: usize = 256;
|
||||
const MIN_SUBGROUP_SIZE: usize = 4;
|
||||
const MAX_REDUCE_SIZE: usize = CUBE_SIZE / MIN_SUBGROUP_SIZE;
|
||||
|
||||
const PART_SIZE: usize = 4096;
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn prefix_sum_kernel<I: Int>(
|
||||
scan_in: &Tensor<Line<I>>,
|
||||
scan_out: &mut Tensor<Line<I>>,
|
||||
scan_bump: &Tensor<Atomic<I>>,
|
||||
reduction: &Tensor<Atomic<I>>,
|
||||
cube_count_x: usize,
|
||||
) {
|
||||
let mut broadcast = SharedMemory::<I>::new(1usize);
|
||||
let mut reduce = SharedMemory::<I>::new(MAX_REDUCE_SIZE);
|
||||
let batch = CUBE_POS_Z as usize;
|
||||
let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size());
|
||||
let nums_per_cube = CUBE_SIZE * line_spt;
|
||||
let v_last = comptime!(scan_in.line_size() - 1);
|
||||
|
||||
//acquire partition index
|
||||
if UNIT_POS_X == 0 {
|
||||
broadcast[0] = scan_bump[batch].fetch_add(I::new(1));
|
||||
}
|
||||
sync_cube();
|
||||
let part_id = usize::cast_from(broadcast[0]);
|
||||
|
||||
let plane_id = UNIT_POS_X / PLANE_DIM;
|
||||
let dev_offs = part_id * nums_per_cube;
|
||||
let plane_offs = UNIT_POS_X as usize * line_spt;
|
||||
|
||||
// Exit if full plane is out of bounds
|
||||
if dev_offs + plane_offs >= scan_in.shape(1) {
|
||||
terminate!();
|
||||
}
|
||||
|
||||
let zero = I::new(0);
|
||||
|
||||
let flag_reduction = I::new(1);
|
||||
let flag_inclusive = I::new(2);
|
||||
let flag_mask = I::new(3);
|
||||
|
||||
let red_offs = batch * reduction.stride(0);
|
||||
let scan_offs = batch * scan_in.stride(0);
|
||||
|
||||
let mut t_scan = Array::<Line<I>>::lined(line_spt, scan_in.line_size());
|
||||
{
|
||||
let mut i = dev_offs + plane_offs + UNIT_POS_PLANE as usize;
|
||||
|
||||
if part_id < cube_count_x - 1 {
|
||||
for k in 0..line_spt {
|
||||
// Manually fuse not_equal and cast
|
||||
let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
|
||||
#[unroll]
|
||||
for v in 1..scan_in.line_size() {
|
||||
let prev = scan[v - 1];
|
||||
scan[v] += prev;
|
||||
}
|
||||
t_scan[k] = scan;
|
||||
i += PLANE_DIM as usize;
|
||||
}
|
||||
}
|
||||
|
||||
if part_id == cube_count_x - 1 {
|
||||
for k in 0..line_spt {
|
||||
if i < scan_in.shape(1) {
|
||||
// Manually fuse not_equal and cast
|
||||
let mut scan =
|
||||
Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
|
||||
#[unroll]
|
||||
for v in 1..scan_in.line_size() {
|
||||
let prev = scan[v - 1];
|
||||
scan[v] += prev;
|
||||
}
|
||||
t_scan[k] = scan;
|
||||
}
|
||||
i += PLANE_DIM as usize;
|
||||
}
|
||||
}
|
||||
|
||||
let mut prev = zero;
|
||||
let plane_mask = PLANE_DIM - 1;
|
||||
let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask;
|
||||
for k in 0..line_spt {
|
||||
let t = plane_shuffle(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
|
||||
t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev);
|
||||
prev += plane_broadcast(t, 0u32);
|
||||
}
|
||||
|
||||
if UNIT_POS_PLANE == 0 {
|
||||
reduce[plane_id as usize] = prev;
|
||||
}
|
||||
}
|
||||
sync_cube();
|
||||
|
||||
//Non-divergent subgroup agnostic inclusive scan across subgroup reductions
|
||||
let lane_log = count_trailing_zeros(PLANE_DIM);
|
||||
let spine_size = CUBE_DIM >> lane_log;
|
||||
{
|
||||
let mut offset_0 = 0;
|
||||
let mut offset_1 = 0;
|
||||
let aligned_size =
|
||||
1 << ((count_trailing_zeros(spine_size) + lane_log + 1) / lane_log * lane_log);
|
||||
let mut j = PLANE_DIM;
|
||||
while j <= aligned_size {
|
||||
let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0;
|
||||
let pred_0 = i_0 < spine_size;
|
||||
let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0 as usize], zero));
|
||||
if pred_0 {
|
||||
reduce[i_0 as usize] = t_0;
|
||||
}
|
||||
sync_cube();
|
||||
|
||||
if j != PLANE_DIM {
|
||||
let rshift = j >> lane_log;
|
||||
let i_1 = UNIT_POS_X + rshift;
|
||||
if (i_1 & (j - 1)) >= rshift {
|
||||
let pred_1 = i_1 < spine_size;
|
||||
let t_1 = select(
|
||||
pred_1,
|
||||
reduce[(((i_1 >> offset_1) << offset_1) - 1) as usize],
|
||||
zero,
|
||||
);
|
||||
if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 {
|
||||
reduce[i_1 as usize] += t_1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
offset_0 += 1;
|
||||
}
|
||||
offset_1 += lane_log;
|
||||
|
||||
j <<= lane_log;
|
||||
}
|
||||
}
|
||||
sync_cube();
|
||||
|
||||
//Device broadcast
|
||||
if UNIT_POS_X == 0 {
|
||||
reduction[part_id + red_offs].store(
|
||||
(reduce[(spine_size - 1) as usize] << I::new(2))
|
||||
| select(part_id != 0, flag_reduction, flag_inclusive),
|
||||
)
|
||||
}
|
||||
|
||||
//Lookback, single thread
|
||||
if part_id != 0 {
|
||||
if UNIT_POS_X == 0 {
|
||||
let mut lookback_id = part_id - 1;
|
||||
let mut prev_reduction = zero;
|
||||
loop {
|
||||
let flag_payload = reduction[lookback_id + red_offs].load();
|
||||
if (flag_payload & flag_mask) == flag_inclusive {
|
||||
prev_reduction += flag_payload >> I::new(2);
|
||||
reduction[part_id + red_offs].store(
|
||||
((prev_reduction + reduce[(spine_size - 1) as usize]) << I::new(2))
|
||||
| flag_inclusive,
|
||||
);
|
||||
broadcast[0] = prev_reduction;
|
||||
break;
|
||||
}
|
||||
|
||||
if (flag_payload & flag_mask) == flag_reduction {
|
||||
prev_reduction += flag_payload >> I::new(2);
|
||||
lookback_id -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
sync_cube();
|
||||
}
|
||||
|
||||
{
|
||||
let prev = if plane_id != 0 {
|
||||
reduce[(plane_id - 1) as usize]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let prev = Line::cast_from(broadcast[0] + prev);
|
||||
let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt as u32;
|
||||
let dev_offset = part_id * nums_per_cube;
|
||||
let mut i = s_offset as usize + dev_offset;
|
||||
|
||||
if part_id < cube_count_x - 1 {
|
||||
for k in 0..line_spt {
|
||||
scan_out[i + scan_offs] = t_scan[k] + prev;
|
||||
i += PLANE_DIM as usize;
|
||||
}
|
||||
}
|
||||
|
||||
if part_id == cube_count_x - 1 {
|
||||
for k in 0..line_spt {
|
||||
if i < scan_out.shape(1) {
|
||||
scan_out[i + scan_offs] = t_scan[k] + prev;
|
||||
}
|
||||
i += PLANE_DIM as usize;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn count_trailing_zeros(num: u32) -> u32 {
|
||||
u32::find_first_set(num) - 1
|
||||
}
|
||||
|
||||
/// Compute the prefix sum of a tensor
|
||||
pub fn prefix_sum<R: CubeRuntime, I: IntElement>(input: CubeTensor<R>) -> CubeTensor<R> {
|
||||
let client = input.client.clone();
|
||||
let device = input.device.clone();
|
||||
let num_elems = input.meta.num_elements();
|
||||
let numbers = *input.meta.shape().last().unwrap();
|
||||
let batches = num_elems / numbers;
|
||||
|
||||
let input = reshape(input, Shape::new([batches, numbers]));
|
||||
let out = empty_device::<R, I>(client.clone(), device.clone(), input.shape());
|
||||
|
||||
let cubes = numbers.div_ceil(PART_SIZE);
|
||||
let cube_dim = CubeDim::new_1d(CUBE_SIZE as u32);
|
||||
let cube_count = CubeCount::new_3d(cubes as u32, 1, batches as u32);
|
||||
|
||||
let bump = zeros_client::<R>(
|
||||
client.clone(),
|
||||
device.clone(),
|
||||
Shape::new([batches]),
|
||||
I::dtype(),
|
||||
);
|
||||
let reduction = zeros_client::<R>(
|
||||
client.clone(),
|
||||
device.clone(),
|
||||
Shape::new([batches, cubes]),
|
||||
I::dtype(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
prefix_sum_kernel::launch_unchecked::<I, R>(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
input.as_tensor_arg(4),
|
||||
out.as_tensor_arg(4),
|
||||
bump.as_tensor_arg(1),
|
||||
reduction.as_tensor_arg(1),
|
||||
ScalarArg::new(cubes),
|
||||
)
|
||||
.expect("Kernel to never fail");
|
||||
};
|
||||
|
||||
out
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
mod connected_components;
|
||||
mod ops;
|
||||
@@ -0,0 +1,211 @@
|
||||
use crate::{
|
||||
BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,
|
||||
IntVisionOps, QVisionOps, VisionBackend, backends::cpu,
|
||||
};
|
||||
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};
|
||||
|
||||
use burn_tensor::{
|
||||
Element,
|
||||
ops::{BoolTensor, IntTensor},
|
||||
};
|
||||
|
||||
use super::connected_components::hardware_accelerated;
|
||||
|
||||
impl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
|
||||
hardware_accelerated::<R, F, I, BT>(
|
||||
img.clone(),
|
||||
ConnectedStatsOptions::none(),
|
||||
connectivity,
|
||||
)
|
||||
.map(|it| it.0)
|
||||
.unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
|
||||
}
|
||||
|
||||
fn connected_components_with_stats(
|
||||
img: BoolTensor<Self>,
|
||||
connectivity: Connectivity,
|
||||
opts: ConnectedStatsOptions,
|
||||
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
|
||||
hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
|
||||
cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
}
|
||||
impl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
}
|
||||
impl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
}
|
||||
impl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>
|
||||
where
|
||||
R: CubeRuntime,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
BT: BoolElement,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "fusion")]
|
||||
mod fusion {
|
||||
use super::*;
|
||||
use burn_fusion::{
|
||||
Fusion, FusionBackend, FusionRuntime,
|
||||
stream::{Operation, OperationStreams},
|
||||
};
|
||||
use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {
|
||||
fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
|
||||
let height = img.shape[0];
|
||||
let width = img.shape[1];
|
||||
let client = img.client.clone();
|
||||
|
||||
#[derive(derive_new::new, Clone, Debug)]
|
||||
struct ConnComp<B> {
|
||||
desc: CustomOpIr,
|
||||
conn: Connectivity,
|
||||
_b: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {
|
||||
fn execute(
|
||||
&self,
|
||||
handles: &mut HandleContainer<
|
||||
<B1::FusionRuntime as FusionRuntime>::FusionHandle,
|
||||
>,
|
||||
) {
|
||||
let ([img], [labels]) = self.desc.as_fixed();
|
||||
let input = handles.get_bool_tensor::<B1>(img);
|
||||
let output = B1::connected_components(input, self.conn);
|
||||
|
||||
handles.register_int_tensor::<B1>(&labels.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let streams = OperationStreams::with_inputs([&img]);
|
||||
let out = TensorIr::uninit(
|
||||
client.create_empty_handle(),
|
||||
Shape::new([height, width]),
|
||||
B::IntElem::dtype(),
|
||||
);
|
||||
|
||||
let desc = CustomOpIr::new("connected_components", &[img.into_ir()], &[out]);
|
||||
client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Custom(desc.clone()),
|
||||
ConnComp::<B>::new(desc, conn),
|
||||
)
|
||||
.output()
|
||||
}
|
||||
|
||||
fn connected_components_with_stats(
|
||||
img: BoolTensor<Self>,
|
||||
conn: Connectivity,
|
||||
opts: ConnectedStatsOptions,
|
||||
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
|
||||
let height = img.shape[0];
|
||||
let width = img.shape[1];
|
||||
let client = img.client.clone();
|
||||
|
||||
#[derive(derive_new::new, Clone, Debug)]
|
||||
struct ConnCompStats<B> {
|
||||
desc: CustomOpIr,
|
||||
conn: Connectivity,
|
||||
opts: ConnectedStatsOptions,
|
||||
_b: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnCompStats<B1> {
|
||||
fn execute(
|
||||
&self,
|
||||
handles: &mut HandleContainer<
|
||||
<B1::FusionRuntime as FusionRuntime>::FusionHandle,
|
||||
>,
|
||||
) {
|
||||
let ([img], [labels, area, left, top, right, bottom, max_label]) =
|
||||
self.desc.as_fixed();
|
||||
let input = handles.get_bool_tensor::<B1>(img);
|
||||
let (output, stats) =
|
||||
B1::connected_components_with_stats(input, self.conn, self.opts);
|
||||
|
||||
handles.register_int_tensor::<B1>(&labels.id, output);
|
||||
handles.register_int_tensor::<B1>(&area.id, stats.area);
|
||||
handles.register_int_tensor::<B1>(&left.id, stats.left);
|
||||
handles.register_int_tensor::<B1>(&top.id, stats.top);
|
||||
handles.register_int_tensor::<B1>(&right.id, stats.right);
|
||||
handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);
|
||||
handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);
|
||||
}
|
||||
}
|
||||
|
||||
let dtype = B::IntElem::dtype();
|
||||
let shape = Shape::new([height, width]);
|
||||
let shape_flat = shape.clone().flatten();
|
||||
let streams = OperationStreams::with_inputs([&img]);
|
||||
let out = TensorIr::uninit(client.create_empty_handle(), shape.clone(), dtype);
|
||||
let area = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
|
||||
let left = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
|
||||
let top = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
|
||||
let right = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);
|
||||
let bottom = TensorIr::uninit(client.create_empty_handle(), shape_flat, dtype);
|
||||
let max_label = TensorIr::uninit(client.create_empty_handle(), [1].into(), dtype);
|
||||
|
||||
let desc = CustomOpIr::new(
|
||||
"connected_components",
|
||||
&[img.into_ir()],
|
||||
&[out, area, left, top, right, bottom, max_label],
|
||||
);
|
||||
let [out, area, left, top, right, bottom, max_label] = client
|
||||
.register(
|
||||
streams,
|
||||
OperationIr::Custom(desc.clone()),
|
||||
ConnCompStats::<B>::new(desc, conn, opts),
|
||||
)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let stats = ConnectedStatsPrimitive {
|
||||
area,
|
||||
left,
|
||||
top,
|
||||
right,
|
||||
bottom,
|
||||
max_label,
|
||||
};
|
||||
(out, stats)
|
||||
}
|
||||
}
|
||||
impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}
|
||||
impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}
|
||||
impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}
|
||||
impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
pub(crate) mod cpu;
|
||||
#[cfg(feature = "cubecl-backend")]
|
||||
mod cube;
|
||||
|
||||
pub use cpu::{KernelShape, create_structuring_element};
|
||||
@@ -0,0 +1,19 @@
|
||||
use derive_new::new;
|
||||
|
||||
/// 2D size used for vision ops.
|
||||
#[derive(new, Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct Size {
|
||||
/// Width of the element
|
||||
pub width: usize,
|
||||
/// Height of the element
|
||||
pub height: usize,
|
||||
}
|
||||
|
||||
/// 2D Point used for vision ops. Coordinates start at the top left.
|
||||
#[derive(new, Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct Point {
|
||||
/// X (horizontal) coordinate
|
||||
pub x: usize,
|
||||
/// Y (vertical) coordinate
|
||||
pub y: usize,
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
//! Vision ops for burn, with GPU acceleration where possible.
|
||||
//!
|
||||
//! # Operations
|
||||
//! Operation names are based on `opencv` wherever applicable.
|
||||
//!
|
||||
//! Currently implemented are:
|
||||
//! - `connected_components`
|
||||
//! - `connected_components_with_stats`
|
||||
//! - `nms` (Non-Maximum Suppression)
|
||||
//!
|
||||
|
||||
#![warn(missing_docs)]
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Backend implementations for JIT and CPU
|
||||
pub mod backends;
|
||||
mod base;
|
||||
mod ops;
|
||||
mod tensor;
|
||||
mod transform;
|
||||
|
||||
pub use base::*;
|
||||
pub use ops::*;
|
||||
pub use tensor::*;
|
||||
pub use transform::*;
|
||||
|
||||
/// Module for vision/image utilities
|
||||
pub mod utils;
|
||||
|
||||
pub use backends::{KernelShape, create_structuring_element};
|
||||
@@ -0,0 +1,340 @@
|
||||
use crate::{
|
||||
Point,
|
||||
backends::cpu::{self, MorphOp, morph},
|
||||
};
|
||||
use bon::Builder;
|
||||
use burn_tensor::{
|
||||
Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
|
||||
backend::Backend,
|
||||
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
|
||||
/// Connected components connectivity
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Connectivity {
|
||||
/// Four-connected (only connected in cardinal directions)
|
||||
Four,
|
||||
/// Eight-connected (connected if any of the surrounding 8 pixels are in the foreground)
|
||||
Eight,
|
||||
}
|
||||
|
||||
/// Which stats should be enabled for `connected_components_with_stats`.
|
||||
/// Currently only used by the GPU implementation to save on atomic operations for unneeded stats.
|
||||
///
|
||||
/// Disabled stats are aliased to the labels tensor
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct ConnectedStatsOptions {
|
||||
/// Whether to enable bounding boxes
|
||||
pub bounds_enabled: bool,
|
||||
/// Whether to enable the max label
|
||||
pub max_label_enabled: bool,
|
||||
/// Whether labels must be contiguous starting at 1
|
||||
pub compact_labels: bool,
|
||||
}
|
||||
|
||||
/// Options for morphology ops
|
||||
#[derive(Clone, Debug, Builder)]
|
||||
pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
|
||||
/// Anchor position within the kernel. Defaults to the center.
|
||||
pub anchor: Option<Point>,
|
||||
/// Number of iterations to apply
|
||||
#[builder(default = 1)]
|
||||
pub iterations: usize,
|
||||
/// Border type. Default: constant based on operation
|
||||
#[builder(default)]
|
||||
pub border_type: BorderType,
|
||||
/// Value of each channel for constant border type
|
||||
pub border_value: Option<Tensor<B, 1, K>>,
|
||||
}
|
||||
|
||||
impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
anchor: Default::default(),
|
||||
iterations: 1,
|
||||
border_type: Default::default(),
|
||||
border_value: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Morphology border type
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
|
||||
pub enum BorderType {
|
||||
/// Constant border with per-channel value. If no value is provided, the value is picked based
|
||||
/// on the morph op.
|
||||
#[default]
|
||||
Constant,
|
||||
/// Replicate first/last element
|
||||
Replicate,
|
||||
/// Reflect start/end elements
|
||||
Reflect,
|
||||
/// Reflect start/end elements, ignoring the first/last element
|
||||
Reflect101,
|
||||
/// Not supported for erode/dilate
|
||||
Wrap,
|
||||
}
|
||||
|
||||
/// Stats collected by the connected components analysis
|
||||
///
|
||||
/// Disabled analyses may be aliased to labels
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConnectedStats<B: Backend> {
|
||||
/// Total area of each component
|
||||
pub area: Tensor<B, 1, Int>,
|
||||
/// Topmost y coordinate in the component
|
||||
pub top: Tensor<B, 1, Int>,
|
||||
/// Leftmost x coordinate in the component
|
||||
pub left: Tensor<B, 1, Int>,
|
||||
/// Rightmost x coordinate in the component
|
||||
pub right: Tensor<B, 1, Int>,
|
||||
/// Bottommost y coordinate in the component
|
||||
pub bottom: Tensor<B, 1, Int>,
|
||||
/// Scalar tensor of the max label
|
||||
pub max_label: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
/// Primitive version of [`ConnectedStats`], to be returned by the backend
|
||||
pub struct ConnectedStatsPrimitive<B: Backend> {
|
||||
/// Total area of each component
|
||||
pub area: IntTensor<B>,
|
||||
/// Leftmost x coordinate in the component
|
||||
pub left: IntTensor<B>,
|
||||
/// Topmost y coordinate in the component
|
||||
pub top: IntTensor<B>,
|
||||
/// Rightmost x coordinate in the component
|
||||
pub right: IntTensor<B>,
|
||||
/// Bottommost y coordinate in the component
|
||||
pub bottom: IntTensor<B>,
|
||||
/// Scalar tensor of the max label
|
||||
pub max_label: IntTensor<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
|
||||
fn from(value: ConnectedStatsPrimitive<B>) -> Self {
|
||||
ConnectedStats {
|
||||
area: Tensor::from_primitive(value.area),
|
||||
top: Tensor::from_primitive(value.top),
|
||||
left: Tensor::from_primitive(value.left),
|
||||
right: Tensor::from_primitive(value.right),
|
||||
bottom: Tensor::from_primitive(value.bottom),
|
||||
max_label: Tensor::from_primitive(value.max_label),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConnectedStats<B> {
|
||||
/// Convert a connected stats into the corresponding primitive
|
||||
pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
|
||||
ConnectedStatsPrimitive {
|
||||
area: self.area.into_primitive(),
|
||||
top: self.top.into_primitive(),
|
||||
left: self.left.into_primitive(),
|
||||
right: self.right.into_primitive(),
|
||||
bottom: self.bottom.into_primitive(),
|
||||
max_label: self.max_label.into_primitive(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConnectedStatsOptions {
|
||||
fn default() -> Self {
|
||||
Self::all()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectedStatsOptions {
|
||||
/// Don't collect any stats
|
||||
pub fn none() -> Self {
|
||||
Self {
|
||||
bounds_enabled: false,
|
||||
max_label_enabled: false,
|
||||
compact_labels: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect all stats
|
||||
pub fn all() -> Self {
|
||||
Self {
|
||||
bounds_enabled: true,
|
||||
max_label_enabled: true,
|
||||
compact_labels: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Non-Maximum Suppression options.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct NmsOptions {
|
||||
/// IoU threshold for suppression (default: 0.5).
|
||||
/// Boxes with IoU > threshold with a higher-scoring box are suppressed.
|
||||
pub iou_threshold: f32,
|
||||
/// Score threshold to filter boxes before NMS (default: 0.0, i.e., no filtering).
|
||||
/// Boxes with score < score_threshold are discarded.
|
||||
pub score_threshold: f32,
|
||||
/// Maximum number of boxes to keep (0 = unlimited).
|
||||
pub max_output_boxes: usize,
|
||||
}
|
||||
|
||||
impl Default for NmsOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
iou_threshold: 0.5,
|
||||
score_threshold: 0.0,
|
||||
max_output_boxes: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Vision capable backend, implemented by each backend
|
||||
pub trait VisionBackend:
|
||||
BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
|
||||
{
|
||||
}
|
||||
|
||||
/// Vision ops on bool tensors
|
||||
pub trait BoolVisionOps: Backend {
|
||||
/// Computes the connected components labeled image of boolean image with 4 or 8 way
|
||||
/// connectivity - returns a tensor of the component label of each pixel.
|
||||
///
|
||||
/// `img`- The boolean image tensor in the format [batches, height, width]
|
||||
fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
|
||||
cpu::connected_components::<Self>(img, connectivity)
|
||||
}
|
||||
|
||||
/// Computes the connected components labeled image of boolean image with 4 or 8 way
|
||||
/// connectivity and collects statistics on each component - returns a tensor of the component
|
||||
/// label of each pixel, along with stats collected for each component.
|
||||
///
|
||||
/// `img`- The boolean image tensor in the format [batches, height, width]
|
||||
fn connected_components_with_stats(
|
||||
img: BoolTensor<Self>,
|
||||
connectivity: Connectivity,
|
||||
opts: ConnectedStatsOptions,
|
||||
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
|
||||
cpu::connected_components_with_stats(img, connectivity, opts)
|
||||
}
|
||||
|
||||
/// Erodes an input tensor with the specified kernel.
|
||||
fn bool_erode(
|
||||
input: BoolTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Bool>,
|
||||
) -> BoolTensor<Self> {
|
||||
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
|
||||
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
|
||||
}
|
||||
|
||||
/// Dilates an input tensor with the specified kernel.
|
||||
fn bool_dilate(
|
||||
input: BoolTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Bool>,
|
||||
) -> BoolTensor<Self> {
|
||||
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
|
||||
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Vision ops on int tensors
|
||||
pub trait IntVisionOps: Backend {
|
||||
/// Erodes an input tensor with the specified kernel.
|
||||
fn int_erode(
|
||||
input: IntTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Int>,
|
||||
) -> IntTensor<Self> {
|
||||
let input = Tensor::<Self, 3, Int>::from_primitive(input);
|
||||
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
|
||||
}
|
||||
|
||||
/// Dilates an input tensor with the specified kernel.
|
||||
fn int_dilate(
|
||||
input: IntTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Int>,
|
||||
) -> IntTensor<Self> {
|
||||
let input = Tensor::<Self, 3, Int>::from_primitive(input);
|
||||
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Vision ops on float tensors
|
||||
pub trait FloatVisionOps: Backend {
|
||||
/// Erodes an input tensor with the specified kernel.
|
||||
fn float_erode(
|
||||
input: FloatTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Float>,
|
||||
) -> FloatTensor<Self> {
|
||||
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
|
||||
|
||||
morph(input, kernel, MorphOp::Erode, opts)
|
||||
.into_primitive()
|
||||
.tensor()
|
||||
}
|
||||
|
||||
/// Dilates an input tensor with the specified kernel.
|
||||
fn float_dilate(
|
||||
input: FloatTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Float>,
|
||||
) -> FloatTensor<Self> {
|
||||
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
|
||||
morph(input, kernel, MorphOp::Dilate, opts)
|
||||
.into_primitive()
|
||||
.tensor()
|
||||
}
|
||||
|
||||
/// Perform Non-Maximum Suppression on bounding boxes.
|
||||
///
|
||||
/// Returns indices of kept boxes after suppressing overlapping detections.
|
||||
/// Boxes are processed in descending score order; a box suppresses all
|
||||
/// lower-scoring boxes with IoU > threshold.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
|
||||
/// * `scores` - Confidence scores as \[N\] tensor
|
||||
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
|
||||
///
|
||||
/// # Returns
|
||||
/// Indices of kept boxes as \[M\] tensor where M <= N
|
||||
fn nms(
|
||||
boxes: FloatTensor<Self>,
|
||||
scores: FloatTensor<Self>,
|
||||
options: NmsOptions,
|
||||
) -> IntTensor<Self> {
|
||||
let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
|
||||
let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
|
||||
cpu::nms::<Self>(boxes, scores, options).into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Vision ops on quantized float tensors
|
||||
pub trait QVisionOps: Backend {
|
||||
/// Erodes an input tensor with the specified kernel.
|
||||
fn q_erode(
|
||||
input: QuantizedTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Float>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
|
||||
match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
|
||||
TensorPrimitive::QFloat(tensor) => tensor,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Dilates an input tensor with the specified kernel.
|
||||
fn q_dilate(
|
||||
input: QuantizedTensor<Self>,
|
||||
kernel: BoolTensor<Self>,
|
||||
opts: MorphOptions<Self, Float>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
|
||||
match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
|
||||
TensorPrimitive::QFloat(tensor) => tensor,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,186 @@
|
||||
use burn_tensor::{
|
||||
BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
|
||||
ops::BoolTensor,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
|
||||
VisionBackend,
|
||||
};
|
||||
|
||||
/// Connected components tensor extensions
|
||||
pub trait ConnectedComponents<B: Backend> {
|
||||
/// Computes the connected components labeled image of boolean image with 4 or 8 way
|
||||
/// connectivity - returns a tensor of the component label of each pixel.
|
||||
///
|
||||
/// `img`- The boolean image tensor in the format [batches, height, width]
|
||||
fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;
|
||||
|
||||
/// Computes the connected components labeled image of boolean image with 4 or 8 way
|
||||
/// connectivity and collects statistics on each component - returns a tensor of the component
|
||||
/// label of each pixel, along with stats collected for each component.
|
||||
///
|
||||
/// `img`- The boolean image tensor in the format [batches, height, width]
|
||||
fn connected_components_with_stats(
|
||||
self,
|
||||
connectivity: Connectivity,
|
||||
options: ConnectedStatsOptions,
|
||||
) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
|
||||
}
|
||||
|
||||
/// Morphology tensor operations
|
||||
pub trait Morphology<B: Backend, K: TensorKind<B>> {
|
||||
/// Erodes this tensor using the specified kernel.
|
||||
/// Assumes NHWC layout.
|
||||
fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
|
||||
/// Dilates this tensor using the specified kernel.
|
||||
/// Assumes NHWC layout.
|
||||
fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
|
||||
}
|
||||
|
||||
/// Morphology tensor operations
|
||||
pub trait MorphologyKind<B: Backend>: BasicOps<B> {
|
||||
/// Erodes this tensor using the specified kernel
|
||||
fn erode(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive;
|
||||
/// Dilates this tensor using the specified kernel
|
||||
fn dilate(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive;
|
||||
}
|
||||
|
||||
/// Non-maximum suppression tensor operations
|
||||
pub trait Nms<B: Backend> {
|
||||
/// Perform Non-Maximum Suppression on this tensor of bounding boxes.
|
||||
///
|
||||
/// Returns indices of kept boxes after suppressing overlapping detections.
|
||||
/// Boxes are processed in descending score order; a box suppresses all
|
||||
/// lower-scoring boxes with IoU > threshold.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `self` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
|
||||
/// * `scores` - Confidence scores as \[N\] tensor
|
||||
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
|
||||
///
|
||||
/// # Returns
|
||||
/// Indices of kept boxes as \[M\] tensor where M <= N
|
||||
fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;
|
||||
}
|
||||
|
||||
impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
|
||||
fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
|
||||
Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
|
||||
}
|
||||
|
||||
fn connected_components_with_stats(
|
||||
self,
|
||||
connectivity: Connectivity,
|
||||
options: ConnectedStatsOptions,
|
||||
) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
|
||||
let (labels, stats) =
|
||||
B::connected_components_with_stats(self.into_primitive(), connectivity, options);
|
||||
(Tensor::from_primitive(labels), stats.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
|
||||
fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
|
||||
Tensor::new(K::erode(
|
||||
self.into_primitive(),
|
||||
kernel.into_primitive(),
|
||||
opts,
|
||||
))
|
||||
}
|
||||
|
||||
fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
|
||||
Tensor::new(K::dilate(
|
||||
self.into_primitive(),
|
||||
kernel.into_primitive(),
|
||||
opts,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: VisionBackend> MorphologyKind<B> for Float {
|
||||
fn erode(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive {
|
||||
match tensor {
|
||||
TensorPrimitive::Float(tensor) => {
|
||||
TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
|
||||
}
|
||||
TensorPrimitive::QFloat(tensor) => {
|
||||
TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dilate(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive {
|
||||
match tensor {
|
||||
TensorPrimitive::Float(tensor) => {
|
||||
TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
|
||||
}
|
||||
TensorPrimitive::QFloat(tensor) => {
|
||||
TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: VisionBackend> MorphologyKind<B> for Int {
|
||||
fn erode(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive {
|
||||
B::int_erode(tensor, kernel, opts)
|
||||
}
|
||||
|
||||
fn dilate(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive {
|
||||
B::int_dilate(tensor, kernel, opts)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: VisionBackend> MorphologyKind<B> for Bool {
|
||||
fn erode(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive {
|
||||
B::bool_erode(tensor, kernel, opts)
|
||||
}
|
||||
|
||||
fn dilate(
|
||||
tensor: Self::Primitive,
|
||||
kernel: BoolTensor<B>,
|
||||
opts: MorphOptions<B, Self>,
|
||||
) -> Self::Primitive {
|
||||
B::bool_dilate(tensor, kernel, opts)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
|
||||
fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
|
||||
match (self.into_primitive(), scores.into_primitive()) {
|
||||
(TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
|
||||
Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
|
||||
}
|
||||
_ => todo!("Quantized inputs are not yet supported"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use burn_tensor::{Shape, Tensor, TensorData, backend::Backend};
|
||||
use image::{DynamicImage, ImageBuffer, Luma, Rgb};
|
||||
|
||||
mod connected_components;
|
||||
mod morphology;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! testgen_all {
|
||||
() => {
|
||||
use burn_tensor::{Bool, Float, Int};
|
||||
|
||||
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
|
||||
pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;
|
||||
|
||||
pub mod vision {
|
||||
pub use super::*;
|
||||
|
||||
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
|
||||
|
||||
burn_vision::testgen_connected_components!();
|
||||
burn_vision::testgen_morphology!();
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod transform2d;
|
||||
|
||||
pub use transform2d::*;
|
||||
@@ -0,0 +1,229 @@
|
||||
use burn_tensor::{
|
||||
Tensor,
|
||||
backend::Backend,
|
||||
grid::affine_grid_2d,
|
||||
ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
|
||||
};
|
||||
|
||||
/// 2D point transformation
|
||||
///
|
||||
/// Useful for resampling: rotating, scaling, translating, etc image tensors
|
||||
pub struct Transform2D {
|
||||
// 2x3 transformation matrix, to be used with column vectors:
|
||||
// T(x) = Ax
|
||||
transform: [[f32; 3]; 2],
|
||||
}
|
||||
|
||||
impl Transform2D {
|
||||
/// Transforms an image
|
||||
///
|
||||
/// * `img` - Images tensor with shape (batch_size, channels, height, width)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same as the input
|
||||
pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [batch_size, channels, height, width] = img.shape().dims();
|
||||
let transform = Tensor::<B, 2>::from(self.transform);
|
||||
let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);
|
||||
let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);
|
||||
|
||||
let options = GridSampleOptions::new(InterpolateMode::Bilinear)
|
||||
.with_padding_mode(GridSamplePaddingMode::Border)
|
||||
.with_align_corners(true);
|
||||
img.grid_sample_2d(grid, options)
|
||||
}
|
||||
|
||||
/// Makes a 2d transformation composed of other transformations
|
||||
pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {
|
||||
let mut result = Self::identity();
|
||||
for t in transforms.into_iter() {
|
||||
result = result.mul(t);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Multiply two affine transforms represented as 2x3 matrices
|
||||
fn mul(self, other: Transform2D) -> Transform2D {
|
||||
let mut result = [[0.0f32; 3]; 2];
|
||||
|
||||
// Row 0
|
||||
result[0][0] = self.transform[0][0] * other.transform[0][0]
|
||||
+ self.transform[0][1] * other.transform[1][0];
|
||||
result[0][1] = self.transform[0][0] * other.transform[0][1]
|
||||
+ self.transform[0][1] * other.transform[1][1];
|
||||
result[0][2] = self.transform[0][0] * other.transform[0][2]
|
||||
+ self.transform[0][1] * other.transform[1][2]
|
||||
+ self.transform[0][2];
|
||||
|
||||
// Row 1
|
||||
result[1][0] = self.transform[1][0] * other.transform[0][0]
|
||||
+ self.transform[1][1] * other.transform[1][0];
|
||||
result[1][1] = self.transform[1][0] * other.transform[0][1]
|
||||
+ self.transform[1][1] * other.transform[1][1];
|
||||
result[1][2] = self.transform[1][0] * other.transform[0][2]
|
||||
+ self.transform[1][1] * other.transform[1][2]
|
||||
+ self.transform[1][2];
|
||||
|
||||
Transform2D { transform: result }
|
||||
}
|
||||
|
||||
/// Makes an identity transform (x = Ax)
|
||||
pub fn identity() -> Self {
|
||||
Self {
|
||||
transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
||||
}
|
||||
}
|
||||
|
||||
/// Makes a [`Transform2D`] for rotating a tensor
|
||||
///
|
||||
/// * `theta` - In radians, the rotation
|
||||
/// * `cx` - Center of rotation, x
|
||||
/// * `cy` - Center of rotation, y
|
||||
pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {
|
||||
let cos_theta = theta.cos();
|
||||
let sin_theta = theta.sin();
|
||||
|
||||
let transform = [
|
||||
[cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],
|
||||
[sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],
|
||||
];
|
||||
|
||||
Self { transform }
|
||||
}
|
||||
|
||||
/// Makes a [`Transform2D`] for scaling an image tensor
|
||||
///
|
||||
/// * `sx` - Scale factor in the x direction
|
||||
/// * `sy` - Scale factor in the y direction
|
||||
/// * `cx` - Center of scaling, x
|
||||
/// * `cy` - Center of scaling, y
|
||||
pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {
|
||||
let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];
|
||||
|
||||
Self { transform }
|
||||
}
|
||||
|
||||
/// Makes a [`Transform2D`] for translating an image tensor
|
||||
///
|
||||
/// * `tx` - Translation in the x direction
|
||||
/// * `ty` - Translation in the y direction
|
||||
pub fn translation(tx: f32, ty: f32) -> Self {
|
||||
let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];
|
||||
|
||||
Self { transform }
|
||||
}
|
||||
|
||||
/// Applies a general shear transformation around the image center,
|
||||
/// combining both X and Y shear.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shx` - Shear factor along the X-axis.
|
||||
/// * `shy` - Shear factor along the Y-axis.
|
||||
/// * `cx`, `cy` - Coordinates of the image center.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Self` with a combined shear transform matrix.
|
||||
pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {
|
||||
let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];
|
||||
|
||||
Self { transform }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_ndarray::NdArray;
|
||||
use burn_tensor::Tolerance;
|
||||
type B = NdArray;
|
||||
|
||||
#[test]
|
||||
fn transform_identity_translation() {
|
||||
let t = Transform2D::translation(0.0, 0.0);
|
||||
let image_original = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);
|
||||
let image_transformed = t.transform(image_original.clone());
|
||||
image_original
|
||||
.to_data()
|
||||
.assert_approx_eq(&image_transformed.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_translation() {
|
||||
let t = Transform2D::translation(1., 1.);
|
||||
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
|
||||
// This result would change if the padding method is different
|
||||
let image_expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);
|
||||
let image = t.transform(image);
|
||||
image_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_rotation_90_degrees() {
|
||||
let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);
|
||||
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
|
||||
let image_expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);
|
||||
let image = t.transform(image);
|
||||
image_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_rotation_around_corner() {
|
||||
let cx = 1.;
|
||||
let cy = -1.;
|
||||
let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, cx, cy);
|
||||
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
|
||||
// This result would change if the padding method is different
|
||||
let image_expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);
|
||||
let image = t.transform(image);
|
||||
image_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_scale() {
|
||||
let cx = 0.0;
|
||||
let cy = 0.0;
|
||||
let t = Transform2D::scale(0.5, 0.5, cx, cy);
|
||||
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
|
||||
let image_expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);
|
||||
let image = t.transform(image);
|
||||
image_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_scale_around_corner() {
|
||||
let cx = 1.;
|
||||
let cy = -1.;
|
||||
let t = Transform2D::scale(0.5, 0.5, cx, cy);
|
||||
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
|
||||
let image_expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);
|
||||
let image = t.transform(image);
|
||||
image_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_combined() {
|
||||
let t1 = Transform2D::translation(0.2, -0.5);
|
||||
let t2 = Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.);
|
||||
let t = Transform2D::composed([t1, t2]);
|
||||
|
||||
let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
|
||||
// This result would change if the padding method is different
|
||||
let image_expected =
|
||||
Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);
|
||||
let image = t.transform(image);
|
||||
image_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod save;
|
||||
|
||||
pub use save::*;
|
||||
@@ -0,0 +1,191 @@
|
||||
//! Utilities for saving tensors as images
|
||||
|
||||
use burn_tensor::{ElementConversion, Tensor, backend::Backend};
|
||||
use image::{Rgb, RgbImage};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
/// How to save a tensor as an image
|
||||
pub struct TensorDisplayOptions {
|
||||
/// How should the dimensions be interpreted
|
||||
pub dim_order: ImageDimOrder,
|
||||
/// What colors should be used
|
||||
pub color_opts: ColorDisplayOpts,
|
||||
/// How to handle batches
|
||||
pub batch_opts: Option<BatchDisplayOpts>,
|
||||
/// Output image width
|
||||
pub width_out: usize,
|
||||
/// Output image height
|
||||
pub height_out: usize,
|
||||
}
|
||||
|
||||
/// How to interpret dimensions for image tensors
|
||||
pub enum ImageDimOrder {
|
||||
/// dims: (height, width)
|
||||
Hw,
|
||||
/// dims: (channels, height, width)
|
||||
Chw,
|
||||
/// dims: (height, width, channels)
|
||||
Hwc,
|
||||
/// dims: (batch_size, height, width)
|
||||
Nhw,
|
||||
/// dims: (batch_size, channels, height, width)
|
||||
Nchw,
|
||||
/// dims: (batch_size, height, width, channels)
|
||||
Nhwc,
|
||||
}
|
||||
|
||||
/// How to translate tensor values to colors
|
||||
pub enum ColorDisplayOpts {
|
||||
/// The values in each channel are respectively assigned to an RGB channel
|
||||
Rgb,
|
||||
/// The channel value is mapped between two colors
|
||||
Monochrome {
|
||||
/// Color assigned to the minimum value
|
||||
min: [f32; 3],
|
||||
/// Color assigned to the maximum value
|
||||
max: [f32; 3],
|
||||
},
|
||||
}
|
||||
|
||||
/// How to handle multi-batch tensors
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum BatchDisplayOpts {
|
||||
/// Each item is placed consecutively in the image
|
||||
Tiled,
|
||||
/// Each item is aggregated
|
||||
Aggregated,
|
||||
}
|
||||
|
||||
/// Save a tensor of a batch of images as an image
|
||||
///
|
||||
/// * `tensor` - Image batch with shape (N, height, width)
|
||||
/// * `opts` - Options for how to draw the tensor
|
||||
/// * `path` - The file path to use
|
||||
pub fn save_tensor_as_image<B: Backend, const D: usize, P: AsRef<std::ffi::OsStr>>(
|
||||
tensor: Tensor<B, D>,
|
||||
opts: TensorDisplayOptions,
|
||||
path: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Output file
|
||||
let path = Path::new(&path);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let tensor = normalize(tensor);
|
||||
|
||||
// convert to (N,C,H,W) format
|
||||
let tensor: Tensor<B, 4> = match opts.dim_order {
|
||||
ImageDimOrder::Hw => {
|
||||
let [h, w] = tensor.shape().dims();
|
||||
tensor.reshape([1, 1, h, w])
|
||||
}
|
||||
ImageDimOrder::Chw => {
|
||||
let [c, h, w] = tensor.shape().dims();
|
||||
tensor.reshape([1, c, h, w])
|
||||
}
|
||||
ImageDimOrder::Hwc => {
|
||||
let [h, w, c] = tensor.shape().dims();
|
||||
tensor.swap_dims(0, 2).swap_dims(1, 2).reshape([1, c, h, w])
|
||||
}
|
||||
ImageDimOrder::Nhw => {
|
||||
let [n, h, w] = tensor.shape().dims();
|
||||
tensor.reshape([n, 1, h, w])
|
||||
}
|
||||
ImageDimOrder::Nchw => tensor.reshape([0, 0, 0, 0]),
|
||||
ImageDimOrder::Nhwc => tensor.swap_dims(1, 3).swap_dims(2, 3).reshape([0, 0, 0, 0]),
|
||||
};
|
||||
|
||||
let data = tensor.to_data();
|
||||
let shape = data.shape.clone();
|
||||
let (batch, channels, src_height, src_width) = (shape[0], shape[1], shape[2], shape[3]);
|
||||
|
||||
let mut img = if let Some(batch_opts) = &opts.batch_opts
|
||||
&& BatchDisplayOpts::Tiled == *batch_opts
|
||||
{
|
||||
RgbImage::new(opts.width_out as u32, (opts.height_out * batch) as u32)
|
||||
} else {
|
||||
RgbImage::new(opts.width_out as u32, opts.height_out as u32)
|
||||
};
|
||||
|
||||
let data_vec = data.to_vec::<f32>().unwrap();
|
||||
|
||||
let mut channel_vals = vec![0 as f32; channels]; // value for each channel in a given pixel
|
||||
for n in 0..batch {
|
||||
for x in 0..opts.width_out {
|
||||
for y in 0..opts.height_out {
|
||||
let i = ((x as f32) / (opts.width_out as f32) * (src_width as f32))
|
||||
.floor()
|
||||
.clamp(0.0, src_width as f32) as usize;
|
||||
let j = ((y as f32) / (opts.height_out as f32) * (src_height as f32))
|
||||
.floor()
|
||||
.clamp(0.0, src_height as f32) as usize;
|
||||
|
||||
for c in 0..channels {
|
||||
channel_vals[c] =
|
||||
data_vec[i + (j + (n * channels + c) * src_height) * src_width];
|
||||
}
|
||||
|
||||
let (x, y) = if let Some(batch_opts) = opts.batch_opts
|
||||
&& BatchDisplayOpts::Tiled == batch_opts
|
||||
{
|
||||
let batch_x = 0;
|
||||
let batch_y = n as u32 * opts.height_out as u32;
|
||||
(x as u32 + batch_x, y as u32 + batch_y)
|
||||
} else {
|
||||
(x as u32, y as u32)
|
||||
};
|
||||
|
||||
let mut pixel = [0 as f32; 3];
|
||||
match opts.color_opts {
|
||||
ColorDisplayOpts::Rgb => match channels {
|
||||
1 => {
|
||||
pixel[0] = channel_vals[0];
|
||||
pixel[1] = 0.0;
|
||||
pixel[2] = 0.0;
|
||||
}
|
||||
2 => {
|
||||
pixel[0] = channel_vals[0];
|
||||
pixel[1] = channel_vals[1];
|
||||
pixel[2] = 0.0;
|
||||
}
|
||||
3 => {
|
||||
pixel[0] = channel_vals[0];
|
||||
pixel[1] = channel_vals[1];
|
||||
pixel[2] = channel_vals[2];
|
||||
}
|
||||
_ => unimplemented!("More than 3 channels not supported ({channels})"),
|
||||
},
|
||||
ColorDisplayOpts::Monochrome { min, max } => {
|
||||
let val: f32 = channel_vals.iter().sum();
|
||||
pixel[0] = min[0] * (1.0 - val) + max[0] * val;
|
||||
pixel[1] = min[1] * (1.0 - val) + max[1] * val;
|
||||
pixel[2] = min[2] * (1.0 - val) + max[2] * val;
|
||||
}
|
||||
}
|
||||
|
||||
let pixel = [
|
||||
(pixel[0] * 255.0) as u8,
|
||||
(pixel[1] * 255.0) as u8,
|
||||
(pixel[2] * 255.0) as u8,
|
||||
];
|
||||
img.put_pixel(x, y, Rgb(pixel));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
img.save(path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Normalize values in 2D tensor from 0 to 1
|
||||
fn normalize<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let min = tensor.clone().min().into_scalar().elem::<f32>();
|
||||
let max = tensor.clone().max().into_scalar().elem::<f32>();
|
||||
let range = if max - min == 0.0 { 1.0 } else { max - min };
|
||||
|
||||
tensor
|
||||
.sub_scalar(min.elem::<f32>())
|
||||
.div_scalar(range.elem::<f32>())
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use burn_tensor::{Shape, Tensor, TensorData, backend::Backend};
|
||||
use image::{DynamicImage, ImageBuffer, Luma, Rgb};
|
||||
|
||||
use burn_tensor::{Bool, Int};
|
||||
|
||||
#[cfg(all(
|
||||
any(feature = "test-cpu", feature = "ndarray"),
|
||||
not(any(feature = "test-wgpu", feature = "test-cuda"))
|
||||
))]
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32, i32>;
|
||||
|
||||
#[cfg(all(test, feature = "test-wgpu"))]
|
||||
pub type TestBackend = burn_wgpu::Wgpu;
|
||||
|
||||
#[cfg(all(test, feature = "test-cuda"))]
|
||||
pub type TestBackend = burn_cuda::Cuda;
|
||||
|
||||
#[allow(unused)]
|
||||
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
|
||||
#[allow(unused)]
|
||||
pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export]
|
||||
macro_rules! as_type {
|
||||
($ty:ident: [$($elem:tt),*]) => {
|
||||
[$($crate::as_type![$ty: $elem]),*]
|
||||
};
|
||||
($ty:ident: [$($elem:tt,)*]) => {
|
||||
[$($crate::as_type![$ty: $elem]),*]
|
||||
};
|
||||
($ty:ident: $elem:expr) => {
|
||||
{
|
||||
use cubecl::prelude::*;
|
||||
|
||||
$ty::new($elem)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn test_image<B: Backend>(name: &str, device: &B::Device, luma: bool) -> Tensor<B, 3> {
|
||||
let file = PathBuf::from("tests/images").join(name);
|
||||
let image = image::open(file).unwrap();
|
||||
if luma {
|
||||
let image = image.to_luma32f();
|
||||
let h = image.height() as usize;
|
||||
let w = image.width() as usize;
|
||||
let data = TensorData::new(image.into_vec(), Shape::new([h, w, 1]));
|
||||
Tensor::from_data(data, device)
|
||||
} else {
|
||||
let image = image.to_rgb32f();
|
||||
let h = image.height() as usize;
|
||||
let w = image.width() as usize;
|
||||
let data = TensorData::new(image.into_vec(), Shape::new([h, w, 3]));
|
||||
Tensor::from_data(data, device)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn save_test_image<B: Backend>(name: &str, tensor: Tensor<B, 3>, luma: bool) {
|
||||
let file = PathBuf::from("tests/images").join(name);
|
||||
let [h, w, _] = tensor.shape().dims();
|
||||
let data = tensor
|
||||
.into_data()
|
||||
.convert::<f32>()
|
||||
.into_vec::<f32>()
|
||||
.unwrap();
|
||||
if luma {
|
||||
let image = ImageBuffer::<Luma<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();
|
||||
DynamicImage::from(image).to_luma8().save(file).unwrap();
|
||||
} else {
|
||||
let image = ImageBuffer::<Rgb<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();
|
||||
DynamicImage::from(image).to_rgb8().save(file).unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
#![allow(clippy::single_range_in_vec_init)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use burn_tensor::TensorData;
|
||||
use burn_vision::{ConnectedComponents, ConnectedStatsOptions, Connectivity};
|
||||
|
||||
mod common;
|
||||
use common::*;
|
||||
|
||||
fn space_invader() -> [[IntType; 14]; 9] {
|
||||
as_type!(IntType: [
|
||||
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],
|
||||
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
|
||||
[1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1],
|
||||
[1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1],
|
||||
[1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1],
|
||||
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
|
||||
])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_8_connectivity() {
|
||||
let tensor = TestTensorBool::<2>::from(space_invader());
|
||||
|
||||
let output = tensor.connected_components(Connectivity::Eight);
|
||||
let expected = space_invader(); // All pixels are in the same group for 8-connected
|
||||
let expected = TestTensorInt::<2>::from(expected);
|
||||
|
||||
normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_8_connectivity_with_stats() {
|
||||
let tensor = TestTensorBool::<2>::from(space_invader());
|
||||
|
||||
let (output, stats) =
|
||||
tensor.connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all());
|
||||
let expected = space_invader(); // All pixels are in the same group for 8-connected
|
||||
let expected = TestTensorInt::<2>::from(expected);
|
||||
|
||||
let (area, left, top, right, bottom) = (
|
||||
stats.area.slice([1..2]).into_data(),
|
||||
stats.left.slice([1..2]).into_data(),
|
||||
stats.top.slice([1..2]).into_data(),
|
||||
stats.right.slice([1..2]).into_data(),
|
||||
stats.bottom.slice([1..2]).into_data(),
|
||||
);
|
||||
|
||||
output.into_data().assert_eq(&expected.into_data(), false);
|
||||
|
||||
area.assert_eq(&TensorData::from([58]), false);
|
||||
left.assert_eq(&TensorData::from([0]), false);
|
||||
top.assert_eq(&TensorData::from([0]), false);
|
||||
right.assert_eq(&TensorData::from([13]), false);
|
||||
bottom.assert_eq(&TensorData::from([8]), false);
|
||||
stats
|
||||
.max_label
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([1]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_4_connectivity() {
|
||||
let tensor = TestTensorBool::<2>::from(space_invader());
|
||||
|
||||
let output = tensor.connected_components(Connectivity::Four);
|
||||
let expected = as_type!(IntType: [
|
||||
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
|
||||
[0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],
|
||||
[0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],
|
||||
[0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],
|
||||
[0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],
|
||||
[4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],
|
||||
[4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],
|
||||
[4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],
|
||||
[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
|
||||
]);
|
||||
let expected = TestTensorInt::<2>::from(expected);
|
||||
|
||||
normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_4_connectivity_with_stats() {
|
||||
let tensor = TestTensorBool::<2>::from(space_invader());
|
||||
|
||||
let (output, stats) =
|
||||
tensor.connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all());
|
||||
let expected = as_type!(IntType: [
|
||||
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
|
||||
[0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],
|
||||
[0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],
|
||||
[0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],
|
||||
[0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],
|
||||
[4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],
|
||||
[4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],
|
||||
[4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],
|
||||
[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
|
||||
]);
|
||||
let expected = TestTensorInt::<2>::from(expected);
|
||||
|
||||
// Slice off background and limit to compacted labels
|
||||
let (area, left, top, right, bottom) = (
|
||||
stats.area.slice([1..6]).into_data(),
|
||||
stats.left.slice([1..6]).into_data(),
|
||||
stats.top.slice([1..6]).into_data(),
|
||||
stats.right.slice([1..6]).into_data(),
|
||||
stats.bottom.slice([1..6]).into_data(),
|
||||
);
|
||||
|
||||
output.into_data().assert_eq(&expected.into_data(), false);
|
||||
|
||||
area.assert_eq(&TensorData::from([1, 1, 46, 5, 5]), false);
|
||||
left.assert_eq(&TensorData::from([3, 10, 1, 0, 12]), false);
|
||||
top.assert_eq(&TensorData::from([0, 0, 1, 5, 5]), false);
|
||||
right.assert_eq(&TensorData::from([3, 10, 12, 1, 13]), false);
|
||||
bottom.assert_eq(&TensorData::from([0, 0, 8, 7, 7]), false);
|
||||
stats
|
||||
.max_label
|
||||
.into_data()
|
||||
.assert_eq(&TensorData::from([5]), false);
|
||||
}
|
||||
|
||||
/// Normalize labels to sequential since actual labels aren't required to be contiguous and
|
||||
/// different algorithms can return different numbers even if correct
|
||||
fn normalize_labels(mut labels: TensorData) -> TensorData {
|
||||
let mut next_label = 0;
|
||||
let mut mappings = HashMap::<i32, i32>::default();
|
||||
let data = labels.as_mut_slice::<i32>().unwrap();
|
||||
for label in data {
|
||||
if *label != 0 {
|
||||
let relabel = mappings.entry(*label).or_insert_with(|| {
|
||||
next_label += 1;
|
||||
next_label
|
||||
});
|
||||
*label = *relabel;
|
||||
}
|
||||
}
|
||||
labels
|
||||
}
|
||||
|
After Width: | Height: | Size: 422 B |
|
After Width: | Height: | Size: 1.6 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 396 B |
|
After Width: | Height: | Size: 2.9 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 3.1 KiB |
|
After Width: | Height: | Size: 3.0 KiB |
|
After Width: | Height: | Size: 3.1 KiB |
|
After Width: | Height: | Size: 3.1 KiB |
|
After Width: | Height: | Size: 3.1 KiB |
|
After Width: | Height: | Size: 3.0 KiB |
|
After Width: | Height: | Size: 3.0 KiB |
|
After Width: | Height: | Size: 3.0 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 370 B |
@@ -0,0 +1,638 @@
|
||||
use burn_tensor::{Tolerance, ops::FloatElem};
|
||||
use burn_vision::{
|
||||
BorderType, KernelShape, MorphOptions, Morphology, Point, Size, create_structuring_element,
|
||||
};
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
mod common;
|
||||
use common::*;
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_luma() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_5x5_Rect.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_luma_cross() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_5x5_Cross.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_luma_ellipse() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Ellipse,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_5x5_Ellipse.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_luma_non_square_rect() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(3, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_3x5_Rect.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_luma_non_square_cross() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(3, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_3x5_Cross.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_rect() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(3, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_3x5_Rect.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_cross() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(3, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_3x5_Cross.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_border_reflect_rect() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(7, 7),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.border_type(BorderType::Reflect)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_border_reflect_cross() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(7, 7),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.border_type(BorderType::Reflect)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_border_reflect101_rect() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(7, 7),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.border_type(BorderType::Reflect101)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT101.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_border_reflect101_cross() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(7, 7),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.border_type(BorderType::Reflect101)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT101.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_border_replicate_rect() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(7, 7),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.border_type(BorderType::Replicate)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_7x7_Rect_BORDER_REPLICATE.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_border_replicate_cross() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(7, 7),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.border_type(BorderType::Replicate)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_7x7_Cross_BORDER_REPLICATE.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_anchor_rect() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(5, 7),
|
||||
Some(Point::new(1, 2)),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder().anchor(Point::new(2, 1)).build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_5x7_Rect_ANCHOR.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_rgb_anchor_cross() {
|
||||
let tensor = test_image("morphology/Base_2.png", &Default::default(), false);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(5, 7),
|
||||
Some(Point::new(1, 2)),
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
// With default border, bottom left pixel is undefined with this particular kernel and anchor
|
||||
// Use replicate instead for comparability
|
||||
let output = tensor.dilate(
|
||||
kernel,
|
||||
MorphOptions::builder()
|
||||
.anchor(Point::new(2, 1))
|
||||
.border_type(BorderType::Replicate)
|
||||
.build(),
|
||||
);
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_2_5x7_Cross_ANCHOR_BORDER_REPLICATE.png",
|
||||
&Default::default(),
|
||||
false,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_boolean_rect() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true).greater_elem(0);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
// With default border, bottom left pixel is undefined with this particular kernel and anchor
|
||||
// Use replicate instead for comparability
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_5x5_Rect.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
)
|
||||
.greater_elem(0);
|
||||
let expected = TestTensorBool::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_boolean_cross() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true).greater_elem(0);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
// With default border, bottom left pixel is undefined with this particular kernel and anchor
|
||||
// Use replicate instead for comparability
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Dilate_1_5x5_Cross.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
)
|
||||
.greater_elem(0);
|
||||
let expected = TestTensorBool::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_int_rect() {
|
||||
let tensor = (test_image("morphology/Base_1.png", &Default::default(), true) * 255.0).int();
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
// With default border, bottom left pixel is undefined with this particular kernel and anchor
|
||||
// Use replicate instead for comparability
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = (test_image(
|
||||
"morphology/Dilate_1_5x5_Rect.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
) * 255.0)
|
||||
.int();
|
||||
let expected = TestTensorInt::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dilate_int_cross() {
|
||||
let tensor = (test_image("morphology/Base_1.png", &Default::default(), true) * 255.0).int();
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
// With default border, bottom left pixel is undefined with this particular kernel and anchor
|
||||
// Use replicate instead for comparability
|
||||
let output = tensor.dilate(kernel, MorphOptions::default());
|
||||
let expected = (test_image(
|
||||
"morphology/Dilate_1_5x5_Cross.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
) * 255.0)
|
||||
.int();
|
||||
let expected = TestTensorInt::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_erode_luma() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = TestTensorBool::<2>::from([
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
]);
|
||||
|
||||
let output = tensor.erode(kernel, MorphOptions::default());
|
||||
let expected = test_image("morphology/Erode_1_5x5_Rect.png", &Default::default(), true);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_erode_luma_cross() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.erode(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Erode_1_5x5_Cross.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_erode_luma_ellipse() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Ellipse,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
|
||||
let output = tensor.erode(kernel, MorphOptions::default());
|
||||
let expected = test_image(
|
||||
"morphology/Erode_1_5x5_Ellipse.png",
|
||||
&Default::default(),
|
||||
true,
|
||||
);
|
||||
let expected = TestTensor::<3>::from(expected);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_structuring_element_should_match_manual_rect() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Rect,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
let kernel_manual = TestTensorBool::<2>::from([
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
]);
|
||||
|
||||
let output = tensor.clone().dilate(kernel, MorphOptions::default());
|
||||
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&output_manual.into_data(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_structuring_element_should_match_manual_cross() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Cross,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
let kernel_manual = TestTensorBool::<2>::from([
|
||||
[false, false, true, false, false],
|
||||
[false, false, true, false, false],
|
||||
[true, true, true, true, true],
|
||||
[false, false, true, false, false],
|
||||
[false, false, true, false, false],
|
||||
]);
|
||||
|
||||
let output = tensor.clone().dilate(kernel, MorphOptions::default());
|
||||
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&output_manual.into_data(), false);
|
||||
}
|
||||
#[test]
|
||||
fn create_structuring_element_should_match_manual_ellipse() {
|
||||
let tensor = test_image("morphology/Base_1.png", &Default::default(), true);
|
||||
let kernel = create_structuring_element::<TestBackend>(
|
||||
KernelShape::Ellipse,
|
||||
Size::new(5, 5),
|
||||
None,
|
||||
&Default::default(),
|
||||
);
|
||||
let kernel_manual = TestTensorBool::<2>::from([
|
||||
[false, false, true, false, false],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[true, true, true, true, true],
|
||||
[false, false, true, false, false],
|
||||
]);
|
||||
|
||||
let output = tensor.clone().dilate(kernel, MorphOptions::default());
|
||||
let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_eq(&output_manual.into_data(), false);
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
use burn_vision::{Nms, NmsOptions};
|
||||
|
||||
mod common;
|
||||
use common::*;
|
||||
|
||||
#[test]
|
||||
fn should_suppress_non_maximum() {
|
||||
let boxes = TestTensor::<2>::from([
|
||||
[0, 0, 100, 100],
|
||||
[0, 1, 100, 100],
|
||||
[0, 101, 200, 200],
|
||||
[0, 100, 200, 200],
|
||||
[0, 170, 300, 300],
|
||||
]);
|
||||
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
|
||||
let options = NmsOptions {
|
||||
iou_threshold: 0.5,
|
||||
score_threshold: 0.0,
|
||||
max_output_boxes: 0,
|
||||
};
|
||||
|
||||
let output = boxes.nms(scores, options);
|
||||
|
||||
let expected = TestTensorInt::<1>::from([4, 2, 1]);
|
||||
output.into_data().assert_eq(&expected.into_data(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_apply_score_threshold() {
|
||||
let boxes = TestTensor::<2>::from([
|
||||
[0, 0, 100, 100],
|
||||
[0, 1, 100, 100],
|
||||
[0, 101, 200, 200],
|
||||
[0, 100, 200, 200],
|
||||
[0, 170, 300, 300],
|
||||
]);
|
||||
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
|
||||
let options = NmsOptions {
|
||||
iou_threshold: 0.5,
|
||||
score_threshold: 0.3,
|
||||
max_output_boxes: 0,
|
||||
};
|
||||
|
||||
let output = boxes.nms(scores, options);
|
||||
|
||||
let expected = TestTensorInt::<1>::from([4, 2]);
|
||||
output.into_data().assert_eq(&expected.into_data(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_apply_iou_threshold() {
|
||||
let boxes = TestTensor::<2>::from([
|
||||
[0, 0, 100, 100],
|
||||
[0, 1, 100, 100],
|
||||
[0, 101, 200, 200],
|
||||
[0, 100, 200, 200],
|
||||
[0, 170, 300, 300],
|
||||
]);
|
||||
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
|
||||
let options = NmsOptions {
|
||||
iou_threshold: 0.1,
|
||||
score_threshold: 0.0,
|
||||
max_output_boxes: 0,
|
||||
};
|
||||
|
||||
let output = boxes.nms(scores, options);
|
||||
|
||||
let expected = TestTensorInt::<1>::from([4, 1]);
|
||||
output.into_data().assert_eq(&expected.into_data(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_apply_max_output_boxes() {
|
||||
let boxes = TestTensor::<2>::from([
|
||||
[0, 0, 100, 100],
|
||||
[0, 1, 100, 100],
|
||||
[0, 101, 200, 200],
|
||||
[0, 100, 200, 200],
|
||||
[0, 170, 300, 300],
|
||||
]);
|
||||
let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
|
||||
let options = NmsOptions {
|
||||
iou_threshold: 0.5,
|
||||
score_threshold: 0.0,
|
||||
max_output_boxes: 1,
|
||||
};
|
||||
|
||||
let output = boxes.nms(scores, options);
|
||||
|
||||
let expected = TestTensorInt::<1>::from([4]);
|
||||
output.into_data().assert_eq(&expected.into_data(), true);
|
||||
}
|
||||