use burn_vision::{Nms, NmsOptions}; mod common; use common::*; #[test] fn should_suppress_non_maximum() { let boxes = TestTensor::<2>::from([ [0, 0, 100, 100], [0, 1, 100, 100], [0, 101, 200, 200], [0, 100, 200, 200], [0, 170, 300, 300], ]); let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]); let options = NmsOptions { iou_threshold: 0.5, score_threshold: 0.0, max_output_boxes: 0, }; let output = boxes.nms(scores, options); let expected = TestTensorInt::<1>::from([4, 2, 1]); output.into_data().assert_eq(&expected.into_data(), true); } #[test] fn should_apply_score_threshold() { let boxes = TestTensor::<2>::from([ [0, 0, 100, 100], [0, 1, 100, 100], [0, 101, 200, 200], [0, 100, 200, 200], [0, 170, 300, 300], ]); let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]); let options = NmsOptions { iou_threshold: 0.5, score_threshold: 0.3, max_output_boxes: 0, }; let output = boxes.nms(scores, options); let expected = TestTensorInt::<1>::from([4, 2]); output.into_data().assert_eq(&expected.into_data(), true); } #[test] fn should_apply_iou_threshold() { let boxes = TestTensor::<2>::from([ [0, 0, 100, 100], [0, 1, 100, 100], [0, 101, 200, 200], [0, 100, 200, 200], [0, 170, 300, 300], ]); let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]); let options = NmsOptions { iou_threshold: 0.1, score_threshold: 0.0, max_output_boxes: 0, }; let output = boxes.nms(scores, options); let expected = TestTensorInt::<1>::from([4, 1]); output.into_data().assert_eq(&expected.into_data(), true); } #[test] fn should_apply_max_output_boxes() { let boxes = TestTensor::<2>::from([ [0, 0, 100, 100], [0, 1, 100, 100], [0, 101, 200, 200], [0, 100, 200, 200], [0, 170, 300, 300], ]); let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]); let options = NmsOptions { iou_threshold: 0.5, score_threshold: 0.0, max_output_boxes: 1, }; let output = boxes.nms(scores, options); let expected = TestTensorInt::<1>::from([4]); output.into_data().assert_eq(&expected.into_data(), true); }