From 3b2b60cdd6ab073136185b3def6a9da4522b5eb4 Mon Sep 17 00:00:00 2001 From: Yann HALLOUARD Date: Sun, 3 Nov 2024 01:59:16 +0100 Subject: [PATCH] fix bugs --- supervision/detection/core.py | 13 ++++------ supervision/detection/overlap_filter.py | 34 ++++++++++--------------- test/detection/test_overlap_filter.py | 28 +++----------------- 3 files changed, 21 insertions(+), 54 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 4731e6f85..ab3ab348d 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1323,19 +1323,17 @@ def with_nms( return self[indices] def with_soft_nms( - self, threshold: float = 0.5, class_agnostic: bool = False, sigma: float = 0.5 + self, sigma: float = 0.5, class_agnostic: bool = False ) -> Detections: """ Perform soft non-maximum suppression on the current set of object detections. Args: - threshold (float): The intersection-over-union threshold - to use for non-maximum suppression. Defaults to 0.5. + sigma (float): The sigma value to use for the soft non-maximum suppression + algorithm. Defaults to 0.5. class_agnostic (bool): Whether to perform class-agnostic non-maximum suppression. If True, the class_id of each detection will be ignored. Defaults to False. - sigma (float): The sigma value to use for the soft non-maximum suppression - algorithm. Defaults to 0.5. Returns: Detections: A new Detections object containing the subset of detections @@ -1370,13 +1368,12 @@ def with_soft_nms( soft_confidences = mask_soft_non_max_suppression( predictions=predictions, masks=self.mask, - iou_threshold=threshold, sigma=sigma, ) self.confidence = soft_confidences else: - indices, soft_confidences = box_soft_non_max_suppression( - predictions=predictions, iou_threshold=threshold, sigma=sigma + soft_confidences = box_soft_non_max_suppression( + predictions=predictions, sigma=sigma ) self.confidence = soft_confidences diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py index 9739a709b..a7ef40c19 100644 --- a/supervision/detection/overlap_filter.py +++ b/supervision/detection/overlap_filter.py @@ -39,7 +39,6 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: def __prepare_data_for_mask_nms( - iou_threshold: float, mask_dimension: int, masks: np.ndarray, predictions: np.ndarray, @@ -48,8 +47,6 @@ def __prepare_data_for_mask_nms( Get IOUs from mask. Prepare the data for non-max suppression. Args: - iou_threshold (float): The intersection-over-union threshold - to use for non-maximum suppression. mask_dimension (int): The dimension to which the masks should be resized before computing IOU values. masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. @@ -68,10 +65,6 @@ def __prepare_data_for_mask_nms( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." - ) rows, columns = predictions.shape if columns == 5: @@ -117,8 +110,12 @@ def mask_non_max_suppression( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." + ) _, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( - iou_threshold, mask_dimension, masks, predictions + mask_dimension, masks, predictions ) keep = np.ones(rows, dtype=bool) @@ -133,7 +130,6 @@ def mask_non_max_suppression( def mask_soft_non_max_suppression( predictions: np.ndarray, masks: np.ndarray, - iou_threshold: float = 0.5, mask_dimension: int = 640, sigma: float = 0.5, ) -> np.ndarray: @@ -160,7 +156,7 @@ def mask_soft_non_max_suppression( 0 < sigma < 1 ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( - iou_threshold, mask_dimension, masks, predictions + mask_dimension, masks, predictions ) not_this_row = np.ones(rows) @@ -175,14 +171,12 @@ def mask_soft_non_max_suppression( def __prepare_data_for_box_nsm( - iou_threshold: float, predictions: np.ndarray + predictions: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: """ Prepare the data for non-max suppression. Args: - iou_threshold (float): The intersection-over-union threshold - to use for non-maximum suppression. predictions (np.ndarray): An array of object detection predictions in the format of `(x_min, y_min, x_max, y_max, score)` or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`, @@ -198,10 +192,6 @@ def __prepare_data_for_box_nsm( """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." - ) rows, columns = predictions.shape # add column #5 - category filled with zeros for agnostic nms @@ -240,9 +230,11 @@ def box_non_max_suppression( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ - _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( - iou_threshold, predictions + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." ) + _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions) keep = np.ones(rows, dtype=bool) for index, (iou, category) in enumerate(zip(ious, categories)): @@ -258,7 +250,7 @@ def box_non_max_suppression( def box_soft_non_max_suppression( - predictions: np.ndarray, iou_threshold: float = 0.5, sigma: float = 0.5 + predictions: np.ndarray, sigma: float = 0.5 ) -> np.ndarray: """ Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions. @@ -283,7 +275,7 @@ def box_soft_non_max_suppression( 0 < sigma < 1 ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( - iou_threshold, predictions + predictions ) not_this_row = np.ones(rows) diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py index 1a5fb05b4..6b0df77a4 100644 --- a/test/detection/test_overlap_filter.py +++ b/test/detection/test_overlap_filter.py @@ -246,25 +246,22 @@ def test_box_non_max_suppression( @pytest.mark.parametrize( - "predictions, iou_threshold, sigma, expected_result, exception", + "predictions, sigma, expected_result, exception", [ ( np.empty(shape=(0, 5)), - 0.5, 0.1, np.array([]), DoesNotRaise(), ), # single box with no category ( np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]), - 0.5, 0.8, np.array([0.8]), DoesNotRaise(), ), # single box with no category ( np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]), - 0.5, 0.9, np.array([0.8]), DoesNotRaise(), @@ -276,7 +273,6 @@ def test_box_non_max_suppression( [15.0, 15.0, 40.0, 40.0, 0.9], ] ), - 0.5, 0.2, np.array([0.07176137, 0.9]), DoesNotRaise(), @@ -288,7 +284,6 @@ def test_box_non_max_suppression( [15.0, 15.0, 40.0, 40.0, 0.9, 1], ] ), - 0.5, 0.3, np.array([0.8, 0.9]), DoesNotRaise(), @@ -300,7 +295,6 @@ def test_box_non_max_suppression( [15.0, 15.0, 40.0, 40.0, 0.9, 0], ] ), - 0.5, 0.9, np.array([0.46814354, 0.9]), DoesNotRaise(), @@ -313,7 +307,6 @@ def test_box_non_max_suppression( [10.0, 10.0, 40.0, 50.0, 0.85], ] ), - 0.5, 0.7, np.array([0.42648529, 0.9, 0.53109062]), DoesNotRaise(), @@ -327,7 +320,6 @@ def test_box_non_max_suppression( ] ), 0.5, - 0.5, np.array([0.8, 0.9, 0.85]), DoesNotRaise(), ), # three boxes with same category @@ -339,7 +331,6 @@ def test_box_non_max_suppression( [10.0, 10.0, 40.0, 50.0, 0.85, 1], ] ), - 0.5, 0.9, np.array([0.55491779, 0.9, 0.85]), DoesNotRaise(), @@ -348,15 +339,12 @@ def test_box_non_max_suppression( ) def test_box_soft_non_max_suppression( predictions: np.ndarray, - iou_threshold: float, sigma: float, expected_result: Optional[np.ndarray], exception: Exception, ) -> None: with exception: - result = box_soft_non_max_suppression( - predictions=predictions, iou_threshold=iou_threshold, sigma=sigma - ) + result = box_soft_non_max_suppression(predictions=predictions, sigma=sigma) np.testing.assert_almost_equal(result, expected_result, decimal=5) @@ -567,12 +555,11 @@ def test_mask_non_max_suppression( @pytest.mark.parametrize( - "predictions, masks, iou_threshold, sigma, expected_result, exception", + "predictions, masks, sigma, expected_result, exception", [ ( np.empty((0, 6)), np.empty((0, 5, 5)), - 0.5, 0.1, np.array([]), DoesNotRaise(), @@ -590,7 +577,6 @@ def test_mask_non_max_suppression( ] ] ), - 0.5, 0.2, np.array([0.8]), DoesNotRaise(), @@ -608,7 +594,6 @@ def test_mask_non_max_suppression( ] ] ), - 0.5, 0.99, np.array([0.8]), DoesNotRaise(), @@ -633,7 +618,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.8, np.array([0.8, 0.9]), DoesNotRaise(), @@ -658,7 +642,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.4, 0.6, np.array([0.3831756, 0.9]), DoesNotRaise(), @@ -683,7 +666,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.9, np.array([0.8, 0.9]), DoesNotRaise(), @@ -721,7 +703,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.3, np.array([0.02853919, 0.85, 0.9]), DoesNotRaise(), @@ -759,7 +740,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.1, np.array([0.8, 0.85, 0.9]), DoesNotRaise(), @@ -769,7 +749,6 @@ def test_mask_non_max_suppression( def test_mask_soft_non_max_suppression( predictions: np.ndarray, masks: np.ndarray, - iou_threshold: float, sigma: float, expected_result: Optional[np.ndarray], exception: Exception, @@ -778,7 +757,6 @@ def test_mask_soft_non_max_suppression( result = mask_soft_non_max_suppression( predictions=predictions, masks=masks, - iou_threshold=iou_threshold, sigma=sigma, ) np.testing.assert_almost_equal(result, expected_result, decimal=6)