11from __future__ import annotations
22
33from dataclasses import dataclass
4- from typing import Iterator , List , Optional , Tuple , Union
4+ from typing import Any , Iterator , List , Optional , Tuple , Union
55
66import numpy as np
77
8- from supervision .detection .utils import non_max_suppression
8+ from supervision .detection .utils import non_max_suppression , xywh_to_xyxy
99from supervision .geometry .core import Position
1010
1111
def _validate_xyxy(xyxy: Any, n: int) -> None:
    """
    Check that `xyxy` is a numpy array holding exactly `n` rows of four values.

    Args:
        xyxy (Any): Candidate bounding box array.
        n (int): Expected number of rows.

    Raises:
        ValueError: If `xyxy` is not an `np.ndarray` or its shape differs from `(n, 4)`.
    """
    if isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4):
        return
    raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape")
16+
17+
def _validate_mask(mask: Any, n: int) -> None:
    """
    Check that `mask` is either absent or a 3d numpy array with one mask per detection.

    Only the number of dimensions and the size of the first axis are checked; the two
    trailing axes may have any size.

    Args:
        mask (Any): Candidate mask array, or `None` when no masks are available.
        n (int): Expected size of the first axis.

    Raises:
        ValueError: If `mask` is neither `None` nor a 3d `np.ndarray` whose first axis has length `n`.
    """
    if mask is None:
        return
    if not isinstance(mask, np.ndarray) or mask.ndim != 3 or mask.shape[0] != n:
        # Image-like numpy arrays are row-major, hence (n, H, W) rather than (n, W, H).
        raise ValueError("mask must be None or 3d np.ndarray with (n, H, W) shape")
24+
25+
def _validate_class_id(class_id: Any, n: int) -> None:
    """
    Check that `class_id` is either absent or a 1d numpy array with `n` entries.

    Args:
        class_id (Any): Candidate class id array, or `None`.
        n (int): Expected number of entries.

    Raises:
        ValueError: If `class_id` is neither `None` nor an `np.ndarray` of shape `(n,)`.
    """
    if class_id is None:
        return
    if not isinstance(class_id, np.ndarray) or class_id.shape != (n,):
        raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape")
32+
33+
def _validate_confidence(confidence: Any, n: int) -> None:
    """
    Check that `confidence` is either absent or a 1d numpy array with `n` entries.

    Args:
        confidence (Any): Candidate confidence array, or `None`.
        n (int): Expected number of entries.

    Raises:
        ValueError: If `confidence` is neither `None` nor an `np.ndarray` of shape `(n,)`.
    """
    if confidence is not None:
        has_expected_shape = (
            isinstance(confidence, np.ndarray) and confidence.shape == (n,)
        )
        if not has_expected_shape:
            raise ValueError(
                "confidence must be None or 1d np.ndarray with (n,) shape"
            )
40+
41+
def _validate_tracker_id(tracker_id: Any, n: int) -> None:
    """
    Check that `tracker_id` is either absent or a 1d numpy array with `n` entries.

    Args:
        tracker_id (Any): Candidate tracker id array, or `None`.
        n (int): Expected number of entries.

    Raises:
        ValueError: If `tracker_id` is neither `None` nor an `np.ndarray` of shape `(n,)`.
    """
    if tracker_id is None:
        return

    expected_shape = (n,)
    if isinstance(tracker_id, np.ndarray) and tracker_id.shape == expected_shape:
        return

    raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")
48+
49+
@dataclass
class Detections:
    """
    Data class containing information about the detections in a video frame.

    Attributes:
        xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
        mask (Optional[np.ndarray]): An array of shape `(n, H, W)` containing the segmentation masks.
        class_id (Optional[np.ndarray]): An array of shape `(n,)` containing the class ids of the detections.
        confidence (Optional[np.ndarray]): An array of shape `(n,)` containing the confidence scores of the detections.
        tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
    """

    xyxy: np.ndarray
    # `Optional` comes from `typing`; numpy has no `Optional` attribute, so `np.Optional`
    # cannot be resolved by tools that evaluate the annotation.
    mask: Optional[np.ndarray] = None
    class_id: Optional[np.ndarray] = None
    confidence: Optional[np.ndarray] = None
    tracker_id: Optional[np.ndarray] = None
2868
2969 def __post_init__ (self ):
3070 n = len (self .xyxy )
31- validators = [
32- (isinstance (self .xyxy , np .ndarray ) and self .xyxy .shape == (n , 4 )),
33- self .class_id is None
34- or (isinstance (self .class_id , np .ndarray ) and self .class_id .shape == (n ,)),
35- self .confidence is None
36- or (
37- isinstance (self .confidence , np .ndarray )
38- and self .confidence .shape == (n ,)
39- ),
40- self .tracker_id is None
41- or (
42- isinstance (self .tracker_id , np .ndarray )
43- and self .tracker_id .shape == (n ,)
44- ),
45- ]
46- if not all (validators ):
47- raise ValueError (
48- "xyxy must be 2d np.ndarray with (n, 4) shape, "
49- "class_id must be None or 1d np.ndarray with (n,) shape, "
50- "confidence must be None or 1d np.ndarray with (n,) shape, "
51- "tracker_id must be None or 1d np.ndarray with (n,) shape"
52- )
71+ _validate_xyxy (xyxy = self .xyxy , n = n )
72+ _validate_mask (mask = self .mask , n = n )
73+ _validate_class_id (class_id = self .class_id , n = n )
74+ _validate_confidence (confidence = self .confidence , n = n )
75+ _validate_tracker_id (tracker_id = self .tracker_id , n = n )
5376
5477 def __len__ (self ):
5578 """
@@ -59,13 +82,22 @@ def __len__(self):
5982
6083 def __iter__ (
6184 self ,
62- ) -> Iterator [Tuple [np .ndarray , Optional [float ], int , Optional [Union [str , int ]]]]:
85+ ) -> Iterator [
86+ Tuple [
87+ np .ndarray ,
88+ Optional [np .ndarray ],
89+ Optional [float ],
90+ Optional [int ],
91+ Optional [int ],
92+ ]
93+ ]:
6394 """
64- Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
95+ Iterates over the Detections object and yield a tuple of `(xyxy, mask, confidence, class_id, tracker_id)` for each detection.
6596 """
6697 for i in range (len (self .xyxy )):
6798 yield (
6899 self .xyxy [i ],
100+ self .mask [i ] if self .mask is not None else None ,
69101 self .confidence [i ] if self .confidence is not None else None ,
70102 self .class_id [i ] if self .class_id is not None else None ,
71103 self .tracker_id [i ] if self .tracker_id is not None else None ,
@@ -75,6 +107,12 @@ def __eq__(self, other: Detections):
75107 return all (
76108 [
77109 np .array_equal (self .xyxy , other .xyxy ),
110+ any (
111+ [
112+ self .mask is None and other .mask is None ,
113+ np .array_equal (self .mask , other .mask ),
114+ ]
115+ ),
78116 any (
79117 [
80118 self .class_id is None and other .class_id is None ,
@@ -113,7 +151,7 @@ def from_yolov5(cls, yolov5_results) -> Detections:
113151 >>> from supervision import Detections
114152
115153 >>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
116- >>> results = model(frame )
154+ >>> results = model(IMAGE )
117155 >>> detections = Detections.from_yolov5(results)
118156 ```
119157 """
@@ -141,8 +179,8 @@ def from_yolov8(cls, yolov8_results) -> Detections:
141179 >>> from supervision import Detections
142180
143181 >>> model = YOLO('yolov8s.pt')
144- >>> results = model(frame )[0]
145- >>> detections = Detections.from_yolov8(results )
182+ >>> yolov8_results = model(IMAGE )[0]
183+ >>> detections = Detections.from_yolov8(yolov8_results )
146184 ```
147185 """
148186 return cls (
@@ -201,6 +239,37 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio
201239 class_id = np .array (class_id ).astype (int ),
202240 )
203241
242+ @classmethod
243+ def from_sam (cls , sam_result : List [dict ]) -> Detections :
244+ """
245+ Creates a Detections instance from Segment Anything Model (SAM) by Meta AI.
246+
247+ Args:
248+ sam_result (List[dict]): The output Results instance from SAM
249+
250+ Returns:
251+ Detections: A new Detections object.
252+
253+ Example:
254+ ```python
255+ >>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
256+ >>> import supervision as sv
257+
258+ >>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
259+ >>> mask_generator = SamAutomaticMaskGenerator(sam)
260+ >>> sam_result = mask_generator.generate(IMAGE)
261+ >>> detections = sv.Detections.from_sam(sam_result=sam_result)
262+ ```
263+ """
264+ sorted_generated_masks = sorted (
265+ sam_result , key = lambda x : x ["area" ], reverse = True
266+ )
267+
268+ xywh = np .array ([mask ["bbox" ] for mask in sorted_generated_masks ])
269+ mask = np .array ([mask ["segmentation" ] for mask in sorted_generated_masks ])
270+
271+ return Detections (xyxy = xywh_to_xyxy (boxes_xywh = xywh ), mask = mask )
272+
204273 @classmethod
205274 def from_coco_annotations (cls , coco_annotation : dict ) -> Detections :
206275 xyxy , class_id = [], []
@@ -264,6 +333,20 @@ def __getitem__(self, index: np.ndarray) -> Detections:
264333
265334 @property
266335 def area (self ) -> np .ndarray :
336+ """
337+ Calculate the area of each detection in the set of object detections. If masks field is defined property
338+ returns are of each mask. If only box is given property return area of each box.
339+
340+ Returns:
341+ np.ndarray: An array of floats containing the area of each detection in the format of `(area_1, area_2, ..., area_n)`, where n is the number of detections.
342+ """
343+ if self .mask is not None :
344+ return np .ndarray ([np .sum (mask ) for mask in self .mask ])
345+ else :
346+ return self .box_area
347+
348+ @property
349+ def box_area (self ) -> np .ndarray :
267350 """
268351 Calculate the area of each bounding box in the set of object detections.
269352
0 commit comments