Skip to content

Commit 0e2e7c7

Browse files
AlexBodnerSkalskiP
andauthored
added track-id filter to demo (#293)
* added track-id filter to demo * Fix demo track-id filter bug, add example variants, and apply code review fixes --------- Co-authored-by: SkalskiP <piotr.skalski92@gmail.com>
1 parent fbe654f commit 0e2e7c7

File tree

1 file changed

+120
-10
lines changed

1 file changed

+120
-10
lines changed

demo/app.py

Lines changed: 120 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
55
# ------------------------------------------------------------------------
66

7-
"""Gradio app for the trackers library — run object tracking on uploaded videos."""
8-
97
from __future__ import annotations
108

119
import os
10+
import sys
1211
import tempfile
1312
from pathlib import Path
1413

@@ -162,6 +161,26 @@ def _format_labels(
162161
0.1,
163162
0.6,
164163
[],
164+
"",
165+
True,
166+
True,
167+
False,
168+
False,
169+
True,
170+
False,
171+
],
172+
[
173+
"https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
174+
"rfdetr-small",
175+
"bytetrack",
176+
0.2,
177+
30,
178+
0.3,
179+
3,
180+
0.1,
181+
0.6,
182+
["person"],
183+
"",
165184
True,
166185
True,
167186
False,
@@ -180,6 +199,7 @@ def _format_labels(
180199
0.3,
181200
0.6,
182201
[],
202+
"",
183203
True,
184204
True,
185205
False,
@@ -188,21 +208,22 @@ def _format_labels(
188208
True,
189209
],
190210
[
191-
"https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/cars-1280x720-1.mp4",
192-
"rfdetr-small",
193-
"bytetrack",
211+
"https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/apples-1280x720-2.mp4",
212+
"rfdetr-nano",
213+
"sort",
194214
0.2,
195215
30,
196216
0.3,
197217
3,
198218
0.1,
199219
0.6,
200-
["car"],
220+
[],
221+
"",
201222
True,
202223
True,
203-
False,
204224
True,
205225
False,
226+
True,
206227
False,
207228
],
208229
[
@@ -216,6 +237,7 @@ def _format_labels(
216237
0.1,
217238
0.6,
218239
[],
240+
"",
219241
True,
220242
True,
221243
False,
@@ -234,11 +256,50 @@ def _format_labels(
234256
0.1,
235257
0.6,
236258
[],
259+
"",
237260
True,
238261
True,
239262
False,
240263
False,
241264
True,
265+
True,
266+
],
267+
[
268+
"https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-2.mp4",
269+
"rfdetr-seg-small",
270+
"bytetrack",
271+
0.2,
272+
30,
273+
0.3,
274+
3,
275+
0.1,
276+
0.6,
277+
[],
278+
"1",
279+
True,
280+
True,
281+
False,
282+
False,
283+
True,
284+
True,
285+
],
286+
[
287+
"https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/suitcases-1280x720-4.mp4",
288+
"rfdetr-small",
289+
"sort",
290+
0.2,
291+
30,
292+
0.3,
293+
3,
294+
0.1,
295+
0.6,
296+
[],
297+
"",
298+
True,
299+
True,
300+
True,
301+
False,
302+
True,
242303
False,
243304
],
244305
[
@@ -252,6 +313,7 @@ def _format_labels(
252313
0.1,
253314
0.6,
254315
[],
316+
"",
255317
True,
256318
True,
257319
True,
@@ -291,6 +353,32 @@ def _resolve_class_filter(
291353
return class_filter if class_filter else None
292354

293355

356+
def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None:
357+
"""Resolve a comma-separated string of track IDs to a list of integers.
358+
359+
Args:
360+
track_ids_arg: Comma-separated string (e.g. `"1,3,5"`). `None` or
361+
empty string means no filter.
362+
363+
Returns:
364+
List of integer track IDs, or `None` when no valid filter remains.
365+
"""
366+
if not track_ids_arg:
367+
return None
368+
369+
track_ids: list[int] = []
370+
for token in track_ids_arg.split(","):
371+
token = token.strip()
372+
try:
373+
track_ids.append(int(token))
374+
except ValueError:
375+
print(
376+
f"Warning: '{token}' is not a valid track ID, skipping.",
377+
file=sys.stderr,
378+
)
379+
return track_ids if track_ids else None
380+
381+
294382
def track(
295383
video_path: str,
296384
model_id: str,
@@ -302,6 +390,7 @@ def track(
302390
minimum_iou_threshold: float,
303391
high_conf_det_threshold: float,
304392
classes: list[str] | None = None,
393+
track_ids: str = "",
305394
show_boxes: bool = True,
306395
show_ids: bool = True,
307396
show_labels: bool = False,
@@ -318,14 +407,17 @@ def track(
318407
if duration > MAX_DURATION_SECONDS:
319408
raise gr.Error(
320409
f"Video is {duration:.1f}s long. "
321-
f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
410+
f"Maximum allowed duration is {MAX_DURATION_SECONDS}s. "
411+
f"Please use the trim tool in the Input Video player to shorten it."
322412
)
323413

324414
detection_model = LOADED_MODELS[model_id]
325415
class_names = getattr(detection_model, "class_names", [])
326416

327417
class_filter = _resolve_class_filter(classes, class_names)
328418

419+
track_id_filter = _resolve_track_id_filter(track_ids)
420+
329421
tracker: ByteTrackTracker | SORTTracker
330422
if tracker_type == "bytetrack":
331423
tracker = ByteTrackTracker(
@@ -385,6 +477,11 @@ def track(
385477

386478
tracked = tracker.update(detections)
387479

480+
if track_id_filter is not None and len(tracked) > 0:
481+
if tracked.tracker_id is not None:
482+
mask = np.isin(tracked.tracker_id, track_id_filter)
483+
tracked = tracked[mask]
484+
388485
annotated = frame.copy()
389486
if trace_annotator is not None:
390487
annotated = trace_annotator.annotate(annotated, tracked)
@@ -409,9 +506,10 @@ def track(
409506
with gr.Blocks(title="Trackers Playground 🔥") as demo:
410507
gr.Markdown(
411508
"# Trackers Playground 🔥\n\n"
412-
"Upload a video, detect COCO objects with "
509+
"Upload a video, detect objects with "
413510
"[RF-DETR](https://github.com/roboflow-ai/rf-detr) and track them with "
414-
"[Trackers](https://github.com/roboflow/trackers)."
511+
"[Trackers](https://github.com/roboflow/trackers). This demo uses models "
512+
"pretrained on 80 COCO classes, but Trackers works with any detection model."
415513
)
416514

417515
with gr.Row():
@@ -450,6 +548,16 @@ def track(
450548
label="Filter Classes",
451549
info="Only track selected classes. None selected means all.",
452550
)
551+
track_id_filter = gr.Textbox(
552+
value="",
553+
label="Filter IDs",
554+
info=(
555+
"Only display tracks with specific track IDs "
556+
"(comma-separated, e.g. 1,3,5). "
557+
"Leave empty for all."
558+
),
559+
placeholder="e.g. 1,3,5",
560+
)
453561

454562
with gr.Column():
455563
gr.Markdown("### Tracker")
@@ -542,6 +650,7 @@ def track(
542650
min_iou_slider,
543651
high_conf_slider,
544652
class_filter,
653+
track_id_filter,
545654
show_boxes_checkbox,
546655
show_ids_checkbox,
547656
show_labels_checkbox,
@@ -565,6 +674,7 @@ def track(
565674
min_iou_slider,
566675
high_conf_slider,
567676
class_filter,
677+
track_id_filter,
568678
show_boxes_checkbox,
569679
show_ids_checkbox,
570680
show_labels_checkbox,

0 commit comments

Comments
 (0)