-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvideo-detector.py
More file actions
110 lines (83 loc) · 2.86 KB
/
video-detector.py
File metadata and controls
110 lines (83 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# /// script
# dependencies = [
# "datachain[video,audio]",
# "opencv-python",
# "ultralytics",
# ]
# ///
import os
from typing import Iterator
import datachain as dc
from datachain import VideoFile, ImageFile
from datachain.model.ultralytics import YoloBBoxes, YoloSegments, YoloPoses
from pydantic import BaseModel
from ultralytics import YOLO, settings
# NOTE: copy data to your local machine before running locally:
# $ mkdir data/video
# $ datachain cp -r gs://datachain-starss23/video_dev/dev-train-sony/ data/video/
# Running outside Studio means reading from a local copy of the bucket.
local = not dc.is_studio()
# No trailing slash on the bucket: input_path/output_path append "/" or
# "/..." themselves, so a trailing slash here would produce "gs://...//"
# double-slash URIs in the remote case.
bucket = "data/video" if local else "gs://datachain-starss23"
input_path = f"{bucket}/"
output_path = f"{bucket}/temp/video-detector-frames"
detection_dataset = "frames-detector"  # name of the saved results dataset
target_fps = 1  # keep roughly one frame per second of video
model_bbox = "yolo11n.pt"
model_segm = "yolo11n-seg.pt"
model_pose = "yolo11n-pose.pt"
# Upload models to avoid YOLO-downloader issues: pre-fetch the weight files
# from the bucket so ultralytics never triggers its own network download.
if not local:
    # Materialize weights into ultralytics' configured weights directory.
    weights_dir = f"{os.getcwd()}/{settings['weights_dir']}"
    dc.read_storage([
        f"{bucket}/models/{model_bbox}",
        f"{bucket}/models/{model_segm}",
        f"{bucket}/models/{model_pose}",
    ]).to_storage(weights_dir, placement="filename")
    # Repoint the model names at the locally materialized weight files.
    model_bbox = f"{weights_dir}/{model_bbox}"
    model_segm = f"{weights_dir}/{model_segm}"
    model_pose = f"{weights_dir}/{model_pose}"
class YoloDataModel(BaseModel):
    """Combined YOLO inference output for a single frame."""

    bbox: YoloBBoxes     # object-detection bounding boxes
    segm: YoloSegments   # instance-segmentation results
    poses: YoloPoses     # pose/keypoint results
class VideoFrameImage(ImageFile):
    """An extracted frame image, linked back to its source video."""

    num: int         # index of the frame within the sampled sequence
    orig: VideoFile  # the video this frame was extracted from
def extract_frames(file: VideoFile) -> Iterator[VideoFrameImage]:
    """Sample frames from *file* at roughly ``target_fps`` and save as JPEGs.

    Each sampled frame is written to ``output_path`` and yielded as a
    VideoFrameImage carrying its sample index and a reference to the
    source video.
    """
    info = file.get_info()
    # One frame per 1/target_fps seconds. Guard with max(1, ...): if the
    # source fps is below target_fps the original int() division gives 0,
    # which is not a valid frame step.
    step = max(1, int(info.fps / target_fps)) if target_fps else 1
    frames = file.get_frames(step=step)
    for num, frame in enumerate(frames):
        image = frame.save(output_path, format="jpg")
        yield VideoFrameImage(**image.model_dump(), num=num, orig=file)
def process_all(yolo: YOLO, yolo_segm: YOLO, yolo_pose: YOLO, frame: ImageFile) -> YoloDataModel:
    """Run detection, segmentation and pose models on a single frame image."""
    image = frame.read()
    detections = YoloBBoxes.from_results(yolo(image, verbose=False))
    segments = YoloSegments.from_results(yolo_segm(image, verbose=False))
    keypoints = YoloPoses.from_results(yolo_pose(image, verbose=False))
    return YoloDataModel(bbox=detections, segm=segments, poses=keypoints)
def process_bbox(yolo: YOLO, frame: ImageFile) -> YoloBBoxes:
    """Run the bounding-box detector on a single frame image."""
    results = yolo(frame.read(), verbose=False)
    return YoloBBoxes.from_results(results)
# Pipeline: list videos -> sample frames -> run YOLO detection -> save dataset.
chain = (
    dc
    .read_storage(input_path, type="video")
    .filter(dc.C("file.path").glob("*.mp4"))
    .sample(2)  # limit to 2 videos for a quick run
    .settings(parallel=5)
    .gen(frame=extract_frames)
    # Initialize models: once per processing thread
    .setup(
        yolo=lambda: YOLO(model_bbox),
        # yolo_segm=lambda: YOLO(model_segm),
        # yolo_pose=lambda: YOLO(model_pose)
    )
    # Apply yolo detector to frames
    .map(bbox=process_bbox)
    # .map(yolo=process_all)
    .order_by("frame.path", "frame.num")
    .save(detection_dataset)
)
# Print the resulting dataset only when running outside Studio.
if local:
    chain.show()