- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
341 lines
11 KiB
Rust
341 lines
11 KiB
Rust
use crate::{
|
|
Point,
|
|
backends::cpu::{self, MorphOp, morph},
|
|
};
|
|
use bon::Builder;
|
|
use burn_tensor::{
|
|
Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
|
|
backend::Backend,
|
|
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
|
|
};
|
|
|
|
/// Connected components connectivity
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
pub enum Connectivity {
|
|
/// Four-connected (only connected in cardinal directions)
|
|
Four,
|
|
/// Eight-connected (connected if any of the surrounding 8 pixels are in the foreground)
|
|
Eight,
|
|
}
|
|
|
|
/// Which stats should be enabled for `connected_components_with_stats`.
|
|
/// Currently only used by the GPU implementation to save on atomic operations for unneeded stats.
|
|
///
|
|
/// Disabled stats are aliased to the labels tensor
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
pub struct ConnectedStatsOptions {
|
|
/// Whether to enable bounding boxes
|
|
pub bounds_enabled: bool,
|
|
/// Whether to enable the max label
|
|
pub max_label_enabled: bool,
|
|
/// Whether labels must be contiguous starting at 1
|
|
pub compact_labels: bool,
|
|
}
|
|
|
|
/// Options for morphology ops
|
|
#[derive(Clone, Debug, Builder)]
|
|
pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
|
|
/// Anchor position within the kernel. Defaults to the center.
|
|
pub anchor: Option<Point>,
|
|
/// Number of iterations to apply
|
|
#[builder(default = 1)]
|
|
pub iterations: usize,
|
|
/// Border type. Default: constant based on operation
|
|
#[builder(default)]
|
|
pub border_type: BorderType,
|
|
/// Value of each channel for constant border type
|
|
pub border_value: Option<Tensor<B, 1, K>>,
|
|
}
|
|
|
|
impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
|
|
fn default() -> Self {
|
|
Self {
|
|
anchor: Default::default(),
|
|
iterations: 1,
|
|
border_type: Default::default(),
|
|
border_value: Default::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Morphology border type
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
|
|
pub enum BorderType {
|
|
/// Constant border with per-channel value. If no value is provided, the value is picked based
|
|
/// on the morph op.
|
|
#[default]
|
|
Constant,
|
|
/// Replicate first/last element
|
|
Replicate,
|
|
/// Reflect start/end elements
|
|
Reflect,
|
|
/// Reflect start/end elements, ignoring the first/last element
|
|
Reflect101,
|
|
/// Not supported for erode/dilate
|
|
Wrap,
|
|
}
|
|
|
|
/// Stats collected by the connected components analysis
|
|
///
|
|
/// Disabled analyses may be aliased to labels
|
|
#[derive(Clone, Debug)]
|
|
pub struct ConnectedStats<B: Backend> {
|
|
/// Total area of each component
|
|
pub area: Tensor<B, 1, Int>,
|
|
/// Topmost y coordinate in the component
|
|
pub top: Tensor<B, 1, Int>,
|
|
/// Leftmost x coordinate in the component
|
|
pub left: Tensor<B, 1, Int>,
|
|
/// Rightmost x coordinate in the component
|
|
pub right: Tensor<B, 1, Int>,
|
|
/// Bottommost y coordinate in the component
|
|
pub bottom: Tensor<B, 1, Int>,
|
|
/// Scalar tensor of the max label
|
|
pub max_label: Tensor<B, 1, Int>,
|
|
}
|
|
|
|
/// Primitive version of [`ConnectedStats`], to be returned by the backend
|
|
pub struct ConnectedStatsPrimitive<B: Backend> {
|
|
/// Total area of each component
|
|
pub area: IntTensor<B>,
|
|
/// Leftmost x coordinate in the component
|
|
pub left: IntTensor<B>,
|
|
/// Topmost y coordinate in the component
|
|
pub top: IntTensor<B>,
|
|
/// Rightmost x coordinate in the component
|
|
pub right: IntTensor<B>,
|
|
/// Bottommost y coordinate in the component
|
|
pub bottom: IntTensor<B>,
|
|
/// Scalar tensor of the max label
|
|
pub max_label: IntTensor<B>,
|
|
}
|
|
|
|
impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
|
|
fn from(value: ConnectedStatsPrimitive<B>) -> Self {
|
|
ConnectedStats {
|
|
area: Tensor::from_primitive(value.area),
|
|
top: Tensor::from_primitive(value.top),
|
|
left: Tensor::from_primitive(value.left),
|
|
right: Tensor::from_primitive(value.right),
|
|
bottom: Tensor::from_primitive(value.bottom),
|
|
max_label: Tensor::from_primitive(value.max_label),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<B: Backend> ConnectedStats<B> {
|
|
/// Convert a connected stats into the corresponding primitive
|
|
pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
|
|
ConnectedStatsPrimitive {
|
|
area: self.area.into_primitive(),
|
|
top: self.top.into_primitive(),
|
|
left: self.left.into_primitive(),
|
|
right: self.right.into_primitive(),
|
|
bottom: self.bottom.into_primitive(),
|
|
max_label: self.max_label.into_primitive(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for ConnectedStatsOptions {
|
|
fn default() -> Self {
|
|
Self::all()
|
|
}
|
|
}
|
|
|
|
impl ConnectedStatsOptions {
|
|
/// Don't collect any stats
|
|
pub fn none() -> Self {
|
|
Self {
|
|
bounds_enabled: false,
|
|
max_label_enabled: false,
|
|
compact_labels: false,
|
|
}
|
|
}
|
|
|
|
/// Collect all stats
|
|
pub fn all() -> Self {
|
|
Self {
|
|
bounds_enabled: true,
|
|
max_label_enabled: true,
|
|
compact_labels: true,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Non-Maximum Suppression options.
|
|
#[derive(Clone, Copy, Debug)]
|
|
pub struct NmsOptions {
|
|
/// IoU threshold for suppression (default: 0.5).
|
|
/// Boxes with IoU > threshold with a higher-scoring box are suppressed.
|
|
pub iou_threshold: f32,
|
|
/// Score threshold to filter boxes before NMS (default: 0.0, i.e., no filtering).
|
|
/// Boxes with score < score_threshold are discarded.
|
|
pub score_threshold: f32,
|
|
/// Maximum number of boxes to keep (0 = unlimited).
|
|
pub max_output_boxes: usize,
|
|
}
|
|
|
|
impl Default for NmsOptions {
|
|
fn default() -> Self {
|
|
Self {
|
|
iou_threshold: 0.5,
|
|
score_threshold: 0.0,
|
|
max_output_boxes: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Vision capable backend, implemented by each backend
|
|
pub trait VisionBackend:
|
|
BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
|
|
{
|
|
}
|
|
|
|
/// Vision ops on bool tensors
|
|
pub trait BoolVisionOps: Backend {
|
|
/// Computes the connected components labeled image of boolean image with 4 or 8 way
|
|
/// connectivity - returns a tensor of the component label of each pixel.
|
|
///
|
|
/// `img`- The boolean image tensor in the format [batches, height, width]
|
|
fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
|
|
cpu::connected_components::<Self>(img, connectivity)
|
|
}
|
|
|
|
/// Computes the connected components labeled image of boolean image with 4 or 8 way
|
|
/// connectivity and collects statistics on each component - returns a tensor of the component
|
|
/// label of each pixel, along with stats collected for each component.
|
|
///
|
|
/// `img`- The boolean image tensor in the format [batches, height, width]
|
|
fn connected_components_with_stats(
|
|
img: BoolTensor<Self>,
|
|
connectivity: Connectivity,
|
|
opts: ConnectedStatsOptions,
|
|
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
|
|
cpu::connected_components_with_stats(img, connectivity, opts)
|
|
}
|
|
|
|
/// Erodes an input tensor with the specified kernel.
|
|
fn bool_erode(
|
|
input: BoolTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Bool>,
|
|
) -> BoolTensor<Self> {
|
|
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
|
|
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
|
|
}
|
|
|
|
/// Dilates an input tensor with the specified kernel.
|
|
fn bool_dilate(
|
|
input: BoolTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Bool>,
|
|
) -> BoolTensor<Self> {
|
|
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
|
|
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
|
|
}
|
|
}
|
|
|
|
/// Vision ops on int tensors
|
|
pub trait IntVisionOps: Backend {
|
|
/// Erodes an input tensor with the specified kernel.
|
|
fn int_erode(
|
|
input: IntTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Int>,
|
|
) -> IntTensor<Self> {
|
|
let input = Tensor::<Self, 3, Int>::from_primitive(input);
|
|
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
|
|
}
|
|
|
|
/// Dilates an input tensor with the specified kernel.
|
|
fn int_dilate(
|
|
input: IntTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Int>,
|
|
) -> IntTensor<Self> {
|
|
let input = Tensor::<Self, 3, Int>::from_primitive(input);
|
|
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
|
|
}
|
|
}
|
|
|
|
/// Vision ops on float tensors
|
|
pub trait FloatVisionOps: Backend {
|
|
/// Erodes an input tensor with the specified kernel.
|
|
fn float_erode(
|
|
input: FloatTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Float>,
|
|
) -> FloatTensor<Self> {
|
|
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
|
|
|
|
morph(input, kernel, MorphOp::Erode, opts)
|
|
.into_primitive()
|
|
.tensor()
|
|
}
|
|
|
|
/// Dilates an input tensor with the specified kernel.
|
|
fn float_dilate(
|
|
input: FloatTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Float>,
|
|
) -> FloatTensor<Self> {
|
|
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
|
|
morph(input, kernel, MorphOp::Dilate, opts)
|
|
.into_primitive()
|
|
.tensor()
|
|
}
|
|
|
|
/// Perform Non-Maximum Suppression on bounding boxes.
|
|
///
|
|
/// Returns indices of kept boxes after suppressing overlapping detections.
|
|
/// Boxes are processed in descending score order; a box suppresses all
|
|
/// lower-scoring boxes with IoU > threshold.
|
|
///
|
|
/// # Arguments
|
|
/// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
|
|
/// * `scores` - Confidence scores as \[N\] tensor
|
|
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
|
|
///
|
|
/// # Returns
|
|
/// Indices of kept boxes as \[M\] tensor where M <= N
|
|
fn nms(
|
|
boxes: FloatTensor<Self>,
|
|
scores: FloatTensor<Self>,
|
|
options: NmsOptions,
|
|
) -> IntTensor<Self> {
|
|
let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
|
|
let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
|
|
cpu::nms::<Self>(boxes, scores, options).into_primitive()
|
|
}
|
|
}
|
|
|
|
/// Vision ops on quantized float tensors
|
|
pub trait QVisionOps: Backend {
|
|
/// Erodes an input tensor with the specified kernel.
|
|
fn q_erode(
|
|
input: QuantizedTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Float>,
|
|
) -> QuantizedTensor<Self> {
|
|
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
|
|
match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
|
|
TensorPrimitive::QFloat(tensor) => tensor,
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
|
|
/// Dilates an input tensor with the specified kernel.
|
|
fn q_dilate(
|
|
input: QuantizedTensor<Self>,
|
|
kernel: BoolTensor<Self>,
|
|
opts: MorphOptions<Self, Float>,
|
|
) -> QuantizedTensor<Self> {
|
|
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
|
|
match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
|
|
TensorPrimitive::QFloat(tensor) => tensor,
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
}
|