// NOTE(review): as received, this file had its line breaks collapsed and its generic
// parameter/argument lists appear to be missing (for example the stray `>` after
// `Morphology` and after `impl`, the empty turbofish in `Tensor::::from_primitive`, and
// `B`/`K` being used without a visible declaration). Only layout and comments are changed
// here; every code token is kept exactly as received. Restore the `<...>` lists from the
// authoritative copy before building.
use burn_tensor::{
    BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
    ops::BoolTensor,
};

use crate::{
    BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
    VisionBackend,
};

// NOTE(review): the `[batches, height, width]` format below is carried over from the
// original comments; the tensor rank is not visible in this view - confirm.
/// Connected components tensor extensions.
///
/// Both methods consume `self`. The implementation in this file forwards to the
/// backend operations of the same names.
pub trait ConnectedComponents {
    /// Computes the connected components labeled image of a boolean image with 4 or 8 way
    /// connectivity - returns a tensor of the component label of each pixel.
    ///
    /// * `self` - The boolean image tensor in the format [batches, height, width]
    /// * `connectivity` - The 4 or 8 way connectivity mode to use
    fn connected_components(self, connectivity: Connectivity) -> Tensor;

    /// Computes the connected components labeled image of a boolean image with 4 or 8 way
    /// connectivity and collects statistics on each component - returns a tensor of the
    /// component label of each pixel, along with stats collected for each component.
    ///
    /// * `self` - The boolean image tensor in the format [batches, height, width]
    /// * `connectivity` - The 4 or 8 way connectivity mode to use
    /// * `options` - Statistics options, passed through unchanged to the backend
    fn connected_components_with_stats(
        self,
        connectivity: Connectivity,
        options: ConnectedStatsOptions,
    ) -> (Tensor, ConnectedStats);
}

/// Morphology tensor operations.
///
/// High-level, tensor-facing API: both methods consume `self` and return a tensor of
/// the same type (`Self`).
pub trait Morphology> {
    /// Erodes this tensor using the specified kernel.
    /// Assumes NHWC layout.
    ///
    /// * `kernel` - The kernel tensor to erode with
    /// * `opts` - Morphology options, passed through to the kind-specific implementation
    fn erode(self, kernel: Tensor, opts: MorphOptions) -> Self;
    /// Dilates this tensor using the specified kernel.
    /// Assumes NHWC layout.
    ///
    /// * `kernel` - The kernel tensor to dilate with
    /// * `opts` - Morphology options, passed through to the kind-specific implementation
    fn dilate(self, kernel: Tensor, opts: MorphOptions) -> Self;
}

/// Per-kind morphology operations working on tensor primitives.
///
/// This is the dispatch layer underneath [`Morphology`]: it is implemented in this file
/// for the `Float`, `Int` and `Bool` kinds, each forwarding to the matching backend op.
/// The functions are associated functions (no `self`); they take and return
/// `Self::Primitive`.
pub trait MorphologyKind: BasicOps {
    /// Erodes this tensor using the specified kernel
    ///
    /// * `tensor` - The primitive to erode
    /// * `kernel` - The boolean kernel primitive
    /// * `opts` - Morphology options, passed through to the backend op
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive;
    /// Dilates this tensor using the specified kernel
    ///
    /// * `tensor` - The primitive to dilate
    /// * `kernel` - The boolean kernel primitive
    /// * `opts` - Morphology options, passed through to the backend op
    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive;
}

/// Non-maximum suppression tensor operations
pub trait Nms {
    /// Perform Non-Maximum Suppression on this tensor of bounding boxes.
    ///
    /// Returns indices of kept boxes after suppressing overlapping detections.
    /// Boxes are processed in descending score order; a box suppresses all
    /// lower-scoring boxes with IoU > threshold.
    ///
    /// # Arguments
    /// * `self` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
    /// * `scores` - Confidence scores as \[N\] tensor
    /// * `opts` - NMS options (IoU threshold, score threshold, max boxes)
    ///
    /// # Returns
    /// Indices of kept boxes as \[M\] tensor where M <= N
    fn nms(self, scores: Tensor, opts: NmsOptions) -> Tensor;
}

/// Connected components for `Tensor`, delegating to the backend `B`.
impl ConnectedComponents for Tensor {
    /// Unwraps `self` to its primitive, runs `B::connected_components` on it and wraps
    /// the resulting primitive back into a `Tensor`.
    fn connected_components(self, connectivity: Connectivity) -> Tensor {
        Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
    }

    /// Unwraps `self` to its primitive and runs `B::connected_components_with_stats`,
    /// forwarding `connectivity` and `options` unchanged.
    fn connected_components_with_stats(
        self,
        connectivity: Connectivity,
        options: ConnectedStatsOptions,
    ) -> (Tensor, ConnectedStats) {
        let (labels, stats) =
            B::connected_components_with_stats(self.into_primitive(), connectivity, options);
        // The label primitive is wrapped into a tensor; the backend's stats value is
        // converted into the declared return type via `Into`.
        (Tensor::from_primitive(labels), stats.into())
    }
}

/// Tensor-level morphology, implemented by delegating to the tensor kind `K`.
impl> Morphology for Tensor {
    /// Converts `self` and `kernel` to primitives, calls `K::erode` and wraps the
    /// returned primitive with `Tensor::new`.
    fn erode(self, kernel: Tensor, opts: MorphOptions) -> Self {
        Tensor::new(K::erode(
            self.into_primitive(),
            kernel.into_primitive(),
            opts,
        ))
    }

    /// Converts `self` and `kernel` to primitives, calls `K::dilate` and wraps the
    /// returned primitive with `Tensor::new`.
    fn dilate(self, kernel: Tensor, opts: MorphOptions) -> Self {
        Tensor::new(K::dilate(
            self.into_primitive(),
            kernel.into_primitive(),
            opts,
        ))
    }
}

/// Morphology for the `Float` kind.
///
/// The primitive is a `TensorPrimitive`; the `Float` and `QFloat` variants are each
/// routed to their own backend op and the result is re-wrapped in the same variant.
impl MorphologyKind for Float {
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive {
        match tensor {
            // Plain float tensor: `B::float_erode`.
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))
            }
            // Quantized float tensor: `B::q_erode`.
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))
            }
        }
    }

    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive {
        match tensor {
            // Plain float tensor: `B::float_dilate`.
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))
            }
            // Quantized float tensor: `B::q_dilate`.
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))
            }
        }
    }
}

/// Morphology for the `Int` kind: forwards directly to `B::int_erode` / `B::int_dilate`.
impl MorphologyKind for Int {
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive {
        B::int_erode(tensor, kernel, opts)
    }

    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive {
        B::int_dilate(tensor, kernel, opts)
    }
}

/// Morphology for the `Bool` kind: forwards directly to `B::bool_erode` / `B::bool_dilate`.
impl MorphologyKind for Bool {
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive {
        B::bool_erode(tensor, kernel, opts)
    }

    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor,
        opts: MorphOptions,
    ) -> Self::Primitive {
        B::bool_dilate(tensor, kernel, opts)
    }
}

/// Non-maximum suppression for `Tensor`, delegating to `B::nms`.
impl Nms for Tensor {
    /// Runs `B::nms` on the primitives of `self` (boxes) and `scores`.
    ///
    /// The third parameter is named `options` here and `opts` in the trait declaration;
    /// it is the same argument.
    ///
    /// # Panics
    /// Only the case where both `self` and `scores` unwrap to `TensorPrimitive::Float`
    /// is handled. Any other combination reaches `todo!` with the message
    /// "Quantized inputs are not yet supported" and panics.
    fn nms(self, scores: Tensor, options: NmsOptions) -> Tensor {
        match (self.into_primitive(), scores.into_primitive()) {
            (TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
                Tensor::::from_primitive(B::nms(boxes, scores, options))
            }
            // Every other primitive combination is currently unimplemented.
            _ => todo!("Quantized inputs are not yet supported"),
        }
    }
}