Skip to content

Commit dba4d9f

Browse files
authored
Merge pull request #58 from roboflow/feature/update_to_support_sam
feature/update_to_support_sam
2 parents bc12a8e + ddc8a8c commit dba4d9f

File tree

12 files changed

+292
-45
lines changed

12 files changed

+292
-45
lines changed

docs/changelog.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
### 0.5.0 <small>April 10, 2023</small>
2+
3+
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.mask` to enable segmentation support.
4+
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `MaskAnnotator` to allow easy `Detections.mask` annotation.
5+
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.from_sam` to enable native Segment Anything Model (SAM) support.
6+
- Changed [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.area` behaviour to work not only with boxes but also with masks.
7+
18
### 0.4.0 <small>April 5, 2023</small>
29

310
- Added [[#46](https://github.com/roboflow/supervision/discussions/48)]: `Detections.empty` to allow easy creation of empty `Detections` objects.

docs/detection/annotate.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
## BoxAnnotator
22

3-
:::supervision.detection.annotate.BoxAnnotator
3+
:::supervision.detection.annotate.BoxAnnotator
4+
5+
## MaskAnnotator
6+
7+
:::supervision.detection.annotate.MaskAnnotator
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## PolygonZone
2+
3+
:::supervision.detection.tools.polygon_zone.PolygonZone
4+
5+
## PolygonZoneAnnotator
6+
7+
:::supervision.detection.tools.polygon_zone.PolygonZoneAnnotator

docs/detection/utils.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,8 @@
88

99
## non_max_suppression
1010

11-
:::supervision.detection.utils.non_max_suppression
11+
:::supervision.detection.utils.non_max_suppression
12+
13+
## mask_to_xyxy
14+
15+
:::supervision.detection.utils.mask_to_xyxy

mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ nav:
2929
- Core: detection/core.md
3030
- Annotate: detection/annotate.md
3131
- Utils: detection/utils.md
32+
- Tools:
33+
- Polygon Zone: detection/tools/polygon_zone.md
3234
- Draw:
3335
- Utils: draw/utils.md
3436
- Annotations:

supervision/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
__version__ = "0.4.0"
1+
__version__ = "0.5.0"
22

33
from supervision.annotation.voc import detections_to_voc_xml
4-
from supervision.detection.annotate import BoxAnnotator
4+
from supervision.detection.annotate import BoxAnnotator, MaskAnnotator
55
from supervision.detection.core import Detections
66
from supervision.detection.line_counter import LineZone, LineZoneAnnotator
7-
from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator
8-
from supervision.detection.utils import generate_2d_mask
7+
from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator
8+
from supervision.detection.utils import generate_2d_mask, mask_to_xyxy
99
from supervision.draw.color import Color, ColorPalette
1010
from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text
1111
from supervision.geometry.core import Point, Position, Rect

supervision/detection/annotate.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,11 @@ def annotate(
5656
np.ndarray: The image with the bounding boxes drawn on it
5757
"""
5858
font = cv2.FONT_HERSHEY_SIMPLEX
59-
for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections):
60-
x1, y1, x2, y2 = xyxy.astype(int)
59+
for i in range(len(detections)):
60+
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
61+
class_id = (
62+
detections.class_id[i] if detections.class_id is not None else None
63+
)
6164
idx = class_id if class_id is not None else i
6265
color = (
6366
self.color.by_idx(idx)
@@ -114,3 +117,58 @@ def annotate(
114117
lineType=cv2.LINE_AA,
115118
)
116119
return scene
120+
121+
122+
class MaskAnnotator:
    """
    A class for overlaying masks on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.default(),
    ):
        self.color: Union[Color, ColorPalette] = color

    def annotate(
        self, scene: np.ndarray, detections: Detections, opacity: float = 0.5
    ) -> np.ndarray:
        """
        Overlays the masks on the given image based on the provided detections, with a specified opacity.

        Parameters:
            scene (np.ndarray): The image on which the masks will be overlaid
            detections (Detections): The detections for which the masks will be overlaid
            opacity (float): The opacity of the masks, between 0 and 1, default is 0.5

        Returns:
            np.ndarray: The image with the masks overlaid. When `detections.mask` is None the input `scene` is returned as is.
        """
        # Nothing to draw without masks; checking once avoids re-testing the
        # same condition for every detection.
        if detections.mask is None:
            return scene

        for i in range(len(detections.xyxy)):
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            # Fall back to the detection index when no class id is available.
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )

            mask = detections.mask[i]
            colored_mask = np.zeros_like(scene, dtype=np.uint8)
            colored_mask[:] = color.as_bgr()
            blended = np.uint8(opacity * colored_mask + (1 - opacity) * scene)

            # The mask gets a trailing axis so it broadcasts over the channel
            # axis of `scene`; pixels outside the mask keep their value.
            scene = np.where(np.expand_dims(mask, axis=-1), blended, scene)

        return scene

supervision/detection/core.py

Lines changed: 112 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,78 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import Iterator, List, Optional, Tuple, Union
4+
from typing import Any, Iterator, List, Optional, Tuple, Union
55

66
import numpy as np
77

8-
from supervision.detection.utils import non_max_suppression
8+
from supervision.detection.utils import non_max_suppression, xywh_to_xyxy
99
from supervision.geometry.core import Position
1010

1111

12+
def _validate_xyxy(xyxy: Any, n: int) -> None:
    """
    Check that `xyxy` is a numpy array holding `n` boxes of 4 values each.

    Raises:
        ValueError: If `xyxy` is not an np.ndarray or its shape is not `(n, 4)`.
    """
    if isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4):
        return
    raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape")
16+
17+
18+
def _validate_mask(mask: Any, n: int) -> None:
    """
    Check that `mask` is either None or a 3d numpy array with `n` entries along its first axis.

    Raises:
        ValueError: If `mask` is neither None nor a 3d np.ndarray whose first dimension equals `n`.
    """
    is_valid = mask is None or (
        isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask.shape[0] == n
    )
    if not is_valid:
        # The message mirrors the other validators ("None or ...") because None is accepted here too.
        raise ValueError("mask must be None or 3d np.ndarray with (n, H, W) shape")
24+
25+
26+
def _validate_class_id(class_id: Any, n: int) -> None:
    """
    Check that `class_id` is either None or a 1d numpy array of length `n`.

    Raises:
        ValueError: If `class_id` is neither None nor an np.ndarray of shape `(n,)`.
    """
    if class_id is None:
        return
    if not isinstance(class_id, np.ndarray) or class_id.shape != (n,):
        raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape")
32+
33+
34+
def _validate_confidence(confidence: Any, n: int) -> None:
    """
    Check that `confidence` is either None or a 1d numpy array of length `n`.

    Raises:
        ValueError: If `confidence` is neither None nor an np.ndarray of shape `(n,)`.
    """
    if confidence is not None:
        has_expected_shape = (
            isinstance(confidence, np.ndarray) and confidence.shape == (n,)
        )
        if not has_expected_shape:
            raise ValueError(
                "confidence must be None or 1d np.ndarray with (n,) shape"
            )
40+
41+
42+
def _validate_tracker_id(tracker_id: Any, n: int) -> None:
    """
    Check that `tracker_id` is either None or a 1d numpy array of length `n`.

    Raises:
        ValueError: If `tracker_id` is neither None nor an np.ndarray of shape `(n,)`.
    """
    expected_shape = (n,)
    acceptable = tracker_id is None or (
        isinstance(tracker_id, np.ndarray) and tracker_id.shape == expected_shape
    )
    if acceptable:
        return
    raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")
48+
49+
1250
@dataclass
1351
class Detections:
1452
"""
1553
Data class containing information about the detections in a video frame.
1654
1755
Attributes:
1856
xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
57+
mask (Optional[np.ndarray]): An array of shape `(n, H, W)` containing the segmentation masks.
1958
class_id (Optional[np.ndarray]): An array of shape `(n,)` containing the class ids of the detections.
2059
confidence (Optional[np.ndarray]): An array of shape `(n,)` containing the confidence scores of the detections.
2160
tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
2261
"""
2362

2463
xyxy: np.ndarray
64+
# Optional comes from typing; numpy has no `Optional` attribute.
mask: Optional[np.ndarray] = None
2565
class_id: Optional[np.ndarray] = None
2666
confidence: Optional[np.ndarray] = None
2767
tracker_id: Optional[np.ndarray] = None
2868

2969
def __post_init__(self):
3070
n = len(self.xyxy)
31-
validators = [
32-
(isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)),
33-
self.class_id is None
34-
or (isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)),
35-
self.confidence is None
36-
or (
37-
isinstance(self.confidence, np.ndarray)
38-
and self.confidence.shape == (n,)
39-
),
40-
self.tracker_id is None
41-
or (
42-
isinstance(self.tracker_id, np.ndarray)
43-
and self.tracker_id.shape == (n,)
44-
),
45-
]
46-
if not all(validators):
47-
raise ValueError(
48-
"xyxy must be 2d np.ndarray with (n, 4) shape, "
49-
"class_id must be None or 1d np.ndarray with (n,) shape, "
50-
"confidence must be None or 1d np.ndarray with (n,) shape, "
51-
"tracker_id must be None or 1d np.ndarray with (n,) shape"
52-
)
71+
_validate_xyxy(xyxy=self.xyxy, n=n)
72+
_validate_mask(mask=self.mask, n=n)
73+
_validate_class_id(class_id=self.class_id, n=n)
74+
_validate_confidence(confidence=self.confidence, n=n)
75+
_validate_tracker_id(tracker_id=self.tracker_id, n=n)
5376

5477
def __len__(self):
5578
"""
@@ -59,13 +82,22 @@ def __len__(self):
5982

6083
def __iter__(
6184
self,
62-
) -> Iterator[Tuple[np.ndarray, Optional[float], int, Optional[Union[str, int]]]]:
85+
) -> Iterator[
86+
Tuple[
87+
np.ndarray,
88+
Optional[np.ndarray],
89+
Optional[float],
90+
Optional[int],
91+
Optional[int],
92+
]
93+
]:
6394
"""
64-
Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
95+
Iterates over the Detections object and yield a tuple of `(xyxy, mask, confidence, class_id, tracker_id)` for each detection.
6596
"""
6697
for i in range(len(self.xyxy)):
6798
yield (
6899
self.xyxy[i],
100+
self.mask[i] if self.mask is not None else None,
69101
self.confidence[i] if self.confidence is not None else None,
70102
self.class_id[i] if self.class_id is not None else None,
71103
self.tracker_id[i] if self.tracker_id is not None else None,
@@ -75,6 +107,12 @@ def __eq__(self, other: Detections):
75107
return all(
76108
[
77109
np.array_equal(self.xyxy, other.xyxy),
110+
any(
111+
[
112+
self.mask is None and other.mask is None,
113+
np.array_equal(self.mask, other.mask),
114+
]
115+
),
78116
any(
79117
[
80118
self.class_id is None and other.class_id is None,
@@ -113,7 +151,7 @@ def from_yolov5(cls, yolov5_results) -> Detections:
113151
>>> from supervision import Detections
114152
115153
>>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
116-
>>> results = model(frame)
154+
>>> results = model(IMAGE)
117155
>>> detections = Detections.from_yolov5(results)
118156
```
119157
"""
@@ -141,8 +179,8 @@ def from_yolov8(cls, yolov8_results) -> Detections:
141179
>>> from supervision import Detections
142180
143181
>>> model = YOLO('yolov8s.pt')
144-
>>> results = model(frame)[0]
145-
>>> detections = Detections.from_yolov8(results)
182+
>>> yolov8_results = model(IMAGE)[0]
183+
>>> detections = Detections.from_yolov8(yolov8_results)
146184
```
147185
"""
148186
return cls(
@@ -201,6 +239,37 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio
201239
class_id=np.array(class_id).astype(int),
202240
)
203241

242+
@classmethod
def from_sam(cls, sam_result: List[dict]) -> Detections:
    """
    Creates a Detections instance from Segment Anything Model (SAM) by Meta AI.

    Args:
        sam_result (List[dict]): The list of mask records generated by SAM. Each record is read for its `area`, `bbox` and `segmentation` entries.

    Returns:
        Detections: A new Detections object with `xyxy` and `mask` set, ordered from the largest to the smallest mask area. An empty `sam_result` yields an empty Detections object.

    Example:
        ```python
        >>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
        >>> import supervision as sv

        >>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
        >>> mask_generator = SamAutomaticMaskGenerator(sam)
        >>> sam_result = mask_generator.generate(IMAGE)
        >>> detections = sv.Detections.from_sam(sam_result=sam_result)
        ```
    """
    # With no records the stacked arrays would have shape (0,), which the
    # xyxy / mask validation rejects, so return an empty object instead.
    if len(sam_result) == 0:
        return cls.empty()

    sorted_generated_masks = sorted(
        sam_result, key=lambda x: x["area"], reverse=True
    )

    xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
    mask = np.array([mask["segmentation"] for mask in sorted_generated_masks])

    # Use `cls` (like the other from_* constructors) so subclasses get
    # instances of their own type.
    return cls(xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask)
272+
204273
@classmethod
205274
def from_coco_annotations(cls, coco_annotation: dict) -> Detections:
206275
xyxy, class_id = [], []
@@ -264,6 +333,20 @@ def __getitem__(self, index: np.ndarray) -> Detections:
264333

265334
@property
def area(self) -> np.ndarray:
    """
    Calculate the area of each detection in the set of object detections. If the `mask` field is defined, the property
    returns the area of each mask. If only boxes are given, the property returns the area of each box.

    Returns:
        np.ndarray: An array containing the area of each detection in the format of `(area_1, area_2, ..., area_n)`, where n is the number of detections.
    """
    if self.mask is None:
        return self.box_area

    # np.array builds an array from the per-mask sums; np.ndarray(...) would
    # interpret the list as a shape and return uninitialised memory.
    return np.array([np.sum(mask) for mask in self.mask])
347+
348+
@property
349+
def box_area(self) -> np.ndarray:
267350
"""
268351
Calculate the area of each bounding box in the set of object detections.
269352

supervision/detection/tools/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)