1+ from __future__ import annotations
2+
3+ from dataclasses import dataclass
14from typing import List , Optional , Union
25
36import cv2
47import numpy as np
58
69from supervision .draw .color import Color , ColorPalette
10+ from supervision .geometry .core import Position
711
812
13+ @dataclass
914class Detections :
10- def __init__ (
11- self ,
12- xyxy : np .ndarray ,
13- confidence : np .ndarray ,
14- class_id : np .ndarray ,
15- tracker_id : Optional [np .ndarray ] = None ,
16- ):
17- """
18- Data class containing information about the detections in a video frame.
15+ """
16+ Data class containing information about the detections in a video frame.
1917
20- Attributes:
21- xyxy (np.ndarray): An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]
22- confidence (np.ndarray): An array of shape (n,) containing the confidence scores of the detections.
23- class_id (np.ndarray): An array of shape (n,) containing the class ids of the detections.
24- tracker_id (Optional[np.ndarray]): An array of shape (n,) containing the tracker ids of the detections.
25- """
26- self .xyxy : np .ndarray = xyxy
27- self .confidence : np .ndarray = confidence
28- self .class_id : np .ndarray = class_id
29- self .tracker_id : Optional [np .ndarray ] = tracker_id
18+ Attributes:
19+ xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
20+ confidence (np.ndarray): An array of shape `(n,)` containing the confidence scores of the detections.
21+ class_id (np.ndarray): An array of shape `(n,)` containing the class ids of the detections.
22+ tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
23+ """
3024
25+ xyxy : np .ndarray
26+ confidence : np .ndarray
27+ class_id : np .ndarray
28+ tracker_id : Optional [np .ndarray ] = None
29+
30+ def __post_init__ (self ):
3131 n = len (self .xyxy )
3232 validators = [
3333 (isinstance (self .xyxy , np .ndarray ) and self .xyxy .shape == (n , 4 )),
@@ -55,7 +55,7 @@ def __len__(self):
5555
5656 def __iter__ (self ):
5757 """
58- Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection.
58+ Iterates over the Detections object and yield a tuple of ` (xyxy, confidence, class_id, tracker_id)` for each detection.
5959 """
6060 for i in range (len (self .xyxy )):
6161 yield (
@@ -66,37 +66,68 @@ def __iter__(self):
6666 )
6767
6868 @classmethod
69- def from_yolov5 (cls , yolov5_output : np . ndarray ):
69+ def from_yolov5 (cls , yolov5_detections ):
7070 """
71- Creates a Detections instance from a YOLOv5 output tensor
71+ Creates a Detections instance from a YOLOv5 output Detections
7272
7373 Attributes:
74- yolov5_output (np.ndarray ): The output tensor from YOLOv5
74+ yolov5_detections (yolov5.models.common.Detections ): The output Detections instance from YOLOv5
7575
7676 Returns:
7777
7878 Example:
7979 ```python
80- >>> from supervision.tools.detections import Detections
80+ >>> import torch
81+ >>> from supervision import Detections
8182
82- >>> detections = Detections.from_yolov5(yolov5_output)
83+ >>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
84+ >>> results = model(frame)
85+ >>> detections = Detections.from_yolov5(results)
8386 ```
8487 """
85- xyxy = yolov5_output [:, :4 ]
86- confidence = yolov5_output [:, 4 ]
87- class_id = yolov5_output [:, 5 ].astype (int )
88- return cls (xyxy , confidence , class_id )
88+ yolov5_detections_predictions = yolov5_detections .pred [0 ].cpu ().cpu ().numpy ()
89+ return cls (
90+ xyxy = yolov5_detections_predictions [:, :4 ],
91+ confidence = yolov5_detections_predictions [:, 4 ],
92+ class_id = yolov5_detections_predictions [:, 5 ].astype (int ),
93+ )
8994
90- def filter (self , mask : np .ndarray , inplace : bool = False ) -> Optional [np .ndarray ]:
95+ @classmethod
96+ def from_yolov8 (cls , yolov8_results ):
97+ """
98+ Creates a Detections instance from a YOLOv8 output Results
99+
100+ Attributes:
101+ yolov8_results (ultralytics.yolo.engine.results.Results): The output Results instance from YOLOv8
102+
103+ Returns:
104+
105+ Example:
106+ ```python
107+ >>> from ultralytics import YOLO
108+ >>> from supervision import Detections
109+
110+ >>> model = YOLO('yolov8s.pt')
111+ >>> results = model(frame)
112+ >>> detections = Detections.from_yolov8(results)
113+ ```
114+ """
115+ return cls (
116+ xyxy = yolov8_results .boxes .xyxy .cpu ().numpy (),
117+ confidence = yolov8_results .boxes .conf .cpu ().numpy (),
118+ class_id = yolov8_results .boxes .cls .cpu ().numpy ().astype (int ),
119+ )
120+
121+ def filter (self , mask : np .ndarray , inplace : bool = False ) -> Optional [Detections ]:
91122 """
92123 Filter the detections by applying a mask.
93124
94125 Attributes:
95- mask (np.ndarray): A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections
126+ mask (np.ndarray): A mask of shape ` (n,)` containing a boolean value for each detection indicating if it should be included in the filtered detections
96127 inplace (bool): If True, the original data will be modified and self will be returned.
97128
98129 Returns:
99- Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise.
130+ Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to ` False`. ` None` otherwise.
100131 """
101132 if inplace :
102133 self .xyxy = self .xyxy [mask ]
@@ -116,11 +147,49 @@ def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray
116147 else None ,
117148 )
118149
150+ def get_anchor_coordinates (self , anchor : Position ) -> np .ndarray :
151+ """
152+ Returns the bounding box coordinates for a specific anchor.
153+
154+ Properties:
155+ anchor (Position): Position of bounding box anchor for which to return the coordinates.
156+
157+ Returns:
158+ np.ndarray: An array of shape `(n, 2)` containing the bounding box anchor coordinates in format `[x, y]`.
159+ """
160+ if anchor == Position .CENTER :
161+ return np .array (
162+ [
163+ (self .xyxy [:, 0 ] + self .xyxy [:, 2 ]) / 2 ,
164+ (self .xyxy [:, 1 ] + self .xyxy [:, 3 ]) / 2 ,
165+ ]
166+ ).transpose ()
167+ elif anchor == Position .BOTTOM_CENTER :
168+ return np .array (
169+ [(self .xyxy [:, 0 ] + self .xyxy [:, 2 ]) / 2 , self .xyxy [:, 3 ]]
170+ ).transpose ()
171+
172+ raise ValueError (f"{ anchor } is not supported." )
173+
174+ def __getitem__ (self , index : np .ndarray ) -> Detections :
175+ if isinstance (index , np .ndarray ) and index .dtype == np .bool :
176+ return Detections (
177+ xyxy = self .xyxy [index ],
178+ confidence = self .confidence [index ],
179+ class_id = self .class_id [index ],
180+ tracker_id = self .tracker_id [index ]
181+ if self .tracker_id is not None
182+ else None ,
183+ )
184+ raise TypeError (
185+ f"Detections.__getitem__ not supported for index of type { type (index )} ."
186+ )
187+
119188
120189class BoxAnnotator :
121190 def __init__ (
122191 self ,
123- color : Union [Color , ColorPalette ],
192+ color : Union [Color , ColorPalette ] = ColorPalette . default () ,
124193 thickness : int = 2 ,
125194 text_color : Color = Color .black (),
126195 text_scale : float = 0.5 ,
@@ -148,35 +217,46 @@ def __init__(
148217
149218 def annotate (
150219 self ,
151- frame : np .ndarray ,
220+ scene : np .ndarray ,
152221 detections : Detections ,
153222 labels : Optional [List [str ]] = None ,
223+ skip_label : bool = False ,
154224 ) -> np .ndarray :
155225 """
156226 Draws bounding boxes on the frame using the detections provided.
157227
158- Attributes :
159- frame (np.ndarray): The image on which the bounding boxes will be drawn
228+ Parameters :
229+ scene (np.ndarray): The image on which the bounding boxes will be drawn
160230 detections (Detections): The detections for which the bounding boxes will be drawn
161231 labels (Optional[List[str]]): An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label.
162-
232+ skip_label (bool): Is set to True, skips bounding box label annotation.
163233 Returns:
164234 np.ndarray: The image with the bounding boxes drawn on it
165235 """
166236 font = cv2 .FONT_HERSHEY_SIMPLEX
167237 for i , (xyxy , confidence , class_id , tracker_id ) in enumerate (detections ):
238+ x1 , y1 , x2 , y2 = xyxy .astype (int )
168239 color = (
169240 self .color .by_idx (class_id )
170241 if isinstance (self .color , ColorPalette )
171242 else self .color
172243 )
244+ cv2 .rectangle (
245+ img = scene ,
246+ pt1 = (x1 , y1 ),
247+ pt2 = (x2 , y2 ),
248+ color = color .as_bgr (),
249+ thickness = self .thickness ,
250+ )
251+ if skip_label :
252+ continue
253+
173254 text = (
174255 f"{ confidence :0.2f} "
175256 if (labels is None or len (detections ) != len (labels ))
176257 else labels [i ]
177258 )
178259
179- x1 , y1 , x2 , y2 = xyxy .astype (int )
180260 text_width , text_height = cv2 .getTextSize (
181261 text = text ,
182262 fontFace = font ,
@@ -194,21 +274,14 @@ def annotate(
194274 text_background_y2 = y1
195275
196276 cv2 .rectangle (
197- img = frame ,
198- pt1 = (x1 , y1 ),
199- pt2 = (x2 , y2 ),
200- color = color .as_bgr (),
201- thickness = self .thickness ,
202- )
203- cv2 .rectangle (
204- img = frame ,
277+ img = scene ,
205278 pt1 = (text_background_x1 , text_background_y1 ),
206279 pt2 = (text_background_x2 , text_background_y2 ),
207280 color = color .as_bgr (),
208281 thickness = cv2 .FILLED ,
209282 )
210283 cv2 .putText (
211- img = frame ,
284+ img = scene ,
212285 text = text ,
213286 org = (text_x , text_y ),
214287 fontFace = font ,
@@ -217,4 +290,4 @@ def annotate(
217290 thickness = self .text_thickness ,
218291 lineType = cv2 .LINE_AA ,
219292 )
220- return frame
293+ return scene
0 commit comments