-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrack.py
More file actions
383 lines (306 loc) · 14.2 KB
/
track.py
File metadata and controls
383 lines (306 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# -*- coding: utf-8 -*-
import torch
import cv2
import numpy as np
import argparse
import os
import math # 确保 math 被导入
import csv
from pathlib import Path
from tqdm import tqdm
from collections import deque
import time
# Import this project's model builder and the TrackNet preprocessing transforms.
from models_factory.builder import build_model
from datasets_factory.transforms.tracknet_transforms import (
Resize, ConcatChannels
)
# --- 1. Model configuration registry ---
# Maps the architecture key chosen with --arch to the nested config dict that
# main() hands to build_model(). Each entry names a model type plus the
# backbone / neck / head components and their channel counts.
MODEL_CONFIGS = {
    'v2': {
        'type': 'TrackNetV2',
        'backbone': {'type': 'TrackNetV2Backbone', 'in_channels': 9},
        'neck': {'type': 'TrackNetV2Neck'},
        'head': {'type': 'TrackNetV2Head', 'in_channels': 64, 'out_channels': 3},
    },
    # v4 swaps in its own backbone and neck but reuses the V2 head.
    'v4': {
        'type': 'TrackNetV4',
        'backbone': {'type': 'TrackNetV4Backbone', 'in_channels': 9},
        'neck': {'type': 'TrackNetV4Neck'},
        'head': {'type': 'TrackNetV2Head', 'in_channels': 64, 'out_channels': 3},
    },
    # v5 reuses the V2 backbone/neck types, with 13 input channels and an
    # R_STRHead output head.
    'v5': {
        'type': 'TrackNetV5',
        'backbone': {'type': 'TrackNetV2Backbone', 'in_channels': 13},
        'neck': {'type': 'TrackNetV2Neck'},
        'head': {'type': 'R_STRHead', 'in_channels': 64, 'out_channels': 3},
    },
}
# --- 2. Helper functions (post-processing kept in line with the metric script) ---
def _heatmap_to_coords(heatmap: np.ndarray, threshold: int = 127):
    """Locate the ball in a single-channel heatmap.

    The heatmap is binarised at ``threshold``, the external contour with the
    largest area is selected, and its centroid is computed from image moments.

    :param heatmap: 2-D heatmap; cast to ``uint8`` if it has another dtype.
    :param threshold: binarisation threshold on the 0-255 scale.
    :return: ``(cx, cy, conf)`` where ``cx``/``cy`` are the integer centroid
        and ``conf`` is the peak heatmap value inside the selected contour
        scaled to 0-1 (rounded to 4 decimals); ``None`` when no contour is
        found or the selected contour's zeroth moment is not positive.
    """
    if heatmap.dtype != np.uint8:
        heatmap = heatmap.astype(np.uint8)

    _retval, binary_map = cv2.threshold(heatmap, threshold, 255, cv2.THRESH_BINARY)
    contours, _hierarchy = cv2.findContours(
        binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return None

    blob = max(contours, key=cv2.contourArea)
    moments = cv2.moments(blob)
    zeroth = moments["m00"]
    if not (zeroth > 0):
        # Centroid is undefined without a positive zeroth moment.
        return None

    centre_x = int(moments["m10"] / zeroth)
    centre_y = int(moments["m01"] / zeroth)

    # Confidence: the brightest heatmap pixel inside the filled blob only,
    # so other blobs in the frame cannot influence it.
    blob_mask = np.zeros(heatmap.shape, dtype=np.uint8)
    cv2.drawContours(blob_mask, [blob], -1, 255, -1)
    # minMaxLoc returns (minVal, maxVal, minLoc, maxLoc).
    peak_value = cv2.minMaxLoc(heatmap, mask=blob_mask)[1]

    return centre_x, centre_y, round(peak_value / 255.0, 4)
def draw_comet_tail(frame, points_deque, head_radius=8):
    """Draw a "comet tail" of filled red dots whose radius shrinks with age.

    No connecting lines are drawn. Dots are painted directly onto ``frame``
    (the image is modified in place) and the same object is returned.

    :param frame: current image frame (BGR).
    :param points_deque: ordered points, oldest first. Each entry is ``None``,
        ``(x, y)`` or ``(x, y, vis)``; entries beyond the third are ignored.
    :param head_radius: radius of the newest (largest) dot.
    :return: ``frame``, after drawing.

    A point is skipped when it is ``None``, when its ``vis`` value is ``0`` or
    ``None``, when either coordinate is ``None``, or when either coordinate is
    NaN/infinite (which could not be converted to a pixel position).
    """
    q_len = len(points_deque)
    if q_len == 0:
        return frame

    for i, pt in enumerate(points_deque):
        # 1. Safety checks: skip missing, invisible or non-finite points.
        if pt is None:
            continue

        # Accept both (x, y) and (x, y, vis/conf).
        if len(pt) >= 3:
            tx, ty, tvis = pt[:3]
            if tvis is None or tvis == 0:
                continue
        else:
            tx, ty = pt

        if tx is None or ty is None:
            continue
        # int(nan) / int(inf) raise, so such points must be dropped here.
        if not (math.isfinite(tx) and math.isfinite(ty)):
            continue

        # 2. Radius decay: i == 0 (oldest) gets the smallest scale and
        #    i == q_len - 1 (newest) gets scale 1.0. Never below 1 pixel.
        scale = (i + 1) / q_len
        current_radius = max(1, int(head_radius * scale))

        # 3. Filled, anti-aliased dot.
        cv2.circle(
            frame,
            (int(tx), int(ty)),
            current_radius,
            (0, 0, 255),  # pure red in BGR
            -1,  # filled
            lineType=cv2.LINE_AA
        )
    return frame
# --- 3. Core per-video pipeline ---
def process_video(video_path: Path, model, device, args, output_root_dir: Path) -> dict:
    """Run ball tracking on one video and write all of its output files.

    Frames are consumed in non-overlapping groups of three: three frames are
    read, stacked into one model input, the model returns three heatmaps, and
    three annotated frames are written. If fewer than three frames remain at
    the end of the video, those trailing frames are not processed.

    Outputs, written to ``output_root_dir / video_path.stem``:
      * ``<stem>_trajectory.mp4`` - resized frames with the comet-tail overlay
      * ``<stem>_comparison.mp4`` - the same frame next to its colourised heatmap
      * ``<stem>_data.csv``       - one row per processed frame, followed by a
        blank line and two summary lines

    :param video_path: video file to process.
    :param model: callable taking the stacked image tensor and returning the
        heatmaps (expected shape ``[1, 3, H, W]``).
    :param device: torch device the input tensor is moved to.
    :param args: parsed CLI arguments; only ``args.threshold`` (0-1) is read.
    :param output_root_dir: directory under which the per-video folder is made.
    :return: dict with ``video_name``, ``detected_frames``, ``total_frames``
        (as reported by the container) and ``detection_ratio``.
    :raises OSError: if the video cannot be opened.
    """
    print(f"\n🏭 Processing video: {video_path.name}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        # Fail loudly instead of silently writing empty outputs.
        raise OSError(f"Cannot open video: {video_path}")

    video_output_dir = output_root_dir / video_path.stem
    video_output_dir.mkdir(parents=True, exist_ok=True)

    input_size = (288, 512)  # (height, width); VideoWriter below takes (width, height)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    trajectory_video_path = video_output_dir / f"{video_path.stem}_trajectory.mp4"
    comparison_video_path = video_output_dir / f"{video_path.stem}_comparison.mp4"
    csv_path = video_output_dir / f"{video_path.stem}_data.csv"

    writer_traj = None
    writer_comp = None
    pbar = None
    csv_data = []
    detected_frames_count = 0
    iteration_count = 0

    try:
        # Original resolution is only used for the timing report.
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        resolution_str = f"{width}x{height}"
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        if fps <= 0:
            # Some backends report 0 for an unknown frame rate; that would
            # yield a zero-length trail buffer and a 0-fps writer.
            fps = 30

        writer_traj = cv2.VideoWriter(
            str(trajectory_video_path), fourcc, fps, (input_size[1], input_size[0])
        )
        writer_comp = cv2.VideoWriter(
            str(comparison_video_path), fourcc, fps, (input_size[1] * 2, input_size[0])
        )

        # The trail keeps roughly one second of history.
        trajectory_points = deque(maxlen=fps)

        # Project preprocessing transforms (resize the three frames, then
        # concatenate them along the channel axis into data_dict['image']).
        resizer = Resize(keys=['path_prev', 'path', 'path_next'], size=input_size)
        concatenator = ConcatChannels(
            keys=['path_prev', 'path', 'path_next'],
            output_key='image'
        )

        # Loop-invariant: the 0-1 CLI threshold on the 0-255 heatmap scale.
        threshold_uint8 = int(args.threshold * 255)

        frame_idx_counter = 0
        pbar = tqdm(total=total_frames, desc=f"Processing {video_path.stem}")
        start_time = time.time()

        while cap.isOpened():
            # 1. Read three frames at once.
            ret1, frame1 = cap.read()
            ret2, frame2 = cap.read()
            ret3, frame3 = cap.read()
            # Any failed read (end of video) ends processing.
            if not (ret1 and ret2 and ret3):
                break

            # 2. Build the model input.
            data_dict = {
                'path_prev': cv2.cvtColor(frame1, cv2.COLOR_BGR2RGB),
                'path': cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB),
                'path_next': cv2.cvtColor(frame3, cv2.COLOR_BGR2RGB),
            }
            data_dict = resizer(data_dict)
            data_dict = concatenator(data_dict)
            resized_frames = [data_dict['path_prev'], data_dict['path'], data_dict['path_next']]
            image_np = data_dict['image']
            # HWC -> CHW, scale to 0-1, add the batch dimension.
            image_tensor = (
                torch.from_numpy(image_np.transpose(2, 0, 1))
                .float().div(255).unsqueeze(0).to(device)
            )

            # 3. Inference for the whole group.
            with torch.no_grad():
                heatmap_preds = model(image_tensor)  # expected [1, 3, H, W]
            # Drop the batch dimension -> (3, H, W) NumPy array.
            heatmaps_np = heatmap_preds.squeeze(0).cpu().numpy()

            # 4. Post-process each of the three frames.
            for i in range(3):
                current_frame_idx = frame_idx_counter + i
                # Ignore indices beyond the frame count reported by the container.
                if current_frame_idx >= total_frames:
                    continue

                # NOTE(review): assumes heatmap values lie in [0, 1] - confirm
                # against the model head; other ranges would wrap in uint8.
                heatmap_uint8 = (heatmaps_np[i] * 255).astype(np.uint8)

                # (A) Extract coordinates.
                coords = _heatmap_to_coords(heatmap_uint8, threshold=threshold_uint8)

                # (B) Record the CSV row and the trail point.
                if coords is not None:
                    detected_frames_count += 1
                    trajectory_points.append(coords)
                    csv_row = {'frame_number': current_frame_idx, 'detected': 1, 'x': coords[0], 'y': coords[1]}
                else:
                    trajectory_points.append(None)
                    csv_row = {'frame_number': current_frame_idx, 'detected': 0, 'x': 0.0, 'y': 0.0}
                csv_data.append(csv_row)

                # (C) Draw and write both videos.
                frame_to_draw = cv2.cvtColor(resized_frames[i], cv2.COLOR_RGB2BGR)
                final_traj_frame = draw_comet_tail(frame_to_draw, trajectory_points)
                writer_traj.write(final_traj_frame)

                heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
                combined_frame = np.hstack((final_traj_frame, heatmap_color))
                writer_comp.write(combined_frame)

            # 5. Advance the counters and the progress bar by one group.
            frame_idx_counter += 3
            iteration_count += 1
            pbar.update(3)

        total_duration = time.time() - start_time
    finally:
        # Always release OpenCV handles and the progress bar, even when a
        # frame fails mid-video and the caller moves on to the next file.
        if pbar is not None:
            pbar.close()
        cap.release()
        if writer_traj is not None:
            writer_traj.release()
        if writer_comp is not None:
            writer_comp.release()

    processed_frames = iteration_count * 3
    avg_frames_per_sec = processed_frames / total_duration if total_duration > 0 else 0
    print(f"⏱️ Processed {processed_frames} frames of {resolution_str} in {total_duration:.2f} seconds. Avg: {avg_frames_per_sec:.2f} frames/sec.")

    # The ratio is taken over the container's reported frame count, not over
    # the number of frames actually processed.
    detection_ratio = (detected_frames_count / total_frames) if total_frames > 0 else 0

    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['frame_number', 'detected', 'x', 'y'])
        writer.writeheader()
        writer.writerows(csv_data)
        # Summary lines are appended after a blank separator line.
        f.write("\n")
        f.write(f"total_detected_frame,{detected_frames_count}\n")
        f.write(f"detection_ratio,{detection_ratio:.4f}\n")

    print(f"✅ Finished processing. Results saved in: {video_output_dir}")

    return {
        'video_name': video_path.name,
        'detected_frames': detected_frames_count,
        'total_frames': total_frames,
        'detection_ratio': round(detection_ratio, 4)
    }
# --- 4. Command-line entry point ---
def main():
    """Batch-run TrackNet inference over every video in a directory.

    Parses the CLI arguments, builds the model selected with ``--arch`` and
    loads its weights, runs ``process_video`` on each ``.mp4``/``.mov`` file
    found directly inside ``input_dir`` (extension matched case-insensitively),
    and finally writes ``_summary_report_<arch>.csv`` with one row per video
    that was processed successfully. All results go to ``input_dir/<arch>``.

    A failure on one video is reported and does not stop the batch.
    """
    parser = argparse.ArgumentParser(description="TrackNet Batch Inference Pipeline")
    parser.add_argument('input_dir', type=str, help='Path to the directory containing input videos.')
    parser.add_argument('weights_path', type=str, help='Path to the model weights (.pth file).')
    parser.add_argument(
        '--arch',
        type=str,
        required=True,
        # Derived from the registry so the CLI and MODEL_CONFIGS cannot drift.
        choices=list(MODEL_CONFIGS),
        help='Model architecture to use (v2, v4, or v5).'
    )
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for inference (e.g., "cuda:0" or "cpu").')
    # The only tunable post-processing parameter.
    parser.add_argument('--threshold', type=float, default=0.5, help='Confidence threshold for detection (0-1).')
    args = parser.parse_args()

    # Validate the input directory before the (slow) model load.
    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        print(f"❌ Input directory not found: {input_dir}")
        return

    # argparse has already restricted --arch to the registry keys.
    model_cfg = MODEL_CONFIGS[args.arch]

    print(f"🚀 Starting Batch Inference Pipeline for [TrackNet {args.arch.upper()}]...")
    # Fall back to CPU whenever CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    model = build_model(model_cfg)
    # NOTE: torch.load unpickles the file - only load checkpoints from a
    # trusted source.
    model.load_state_dict(torch.load(args.weights_path, map_location='cpu'))
    model.to(device).eval()
    print(f"✅ Model loaded from {args.weights_path} and sent to {device}.")

    # Outputs live in a per-architecture folder inside the input directory.
    output_root_dir = input_dir / args.arch
    output_root_dir.mkdir(exist_ok=True)

    print("🔎 Searching for .mp4 and .mov files...")
    supported_suffixes = ('.mp4', '.mov')
    video_files = sorted(
        path for path in input_dir.iterdir()
        if path.is_file() and path.suffix.lower() in supported_suffixes
    )
    if not video_files:
        print(f"❌ No supported video files (.mp4, .mov) found in {input_dir}. Exiting.")
        return
    print(f"Found {len(video_files)} videos to process.")

    # 1. Collect the statistics returned for each video.
    summary_data_list = []
    for video_path in video_files:
        try:
            video_stats = process_video(video_path, model, device, args, output_root_dir)
        except Exception as e:
            # Batch boundary: report the failure and keep going.
            print(f"❌ ERROR processing {video_path.name}: {e}")
            print("Skipping this video and continuing...")
        else:
            if video_stats:
                summary_data_list.append(video_stats)

    # 2. Write the global summary CSV.
    if summary_data_list:
        summary_csv_path = output_root_dir / f"_summary_report_{args.arch}.csv"
        print(f"\n📊 Writing summary report to {summary_csv_path}")
        fieldnames = ['video_name', 'detected_frames', 'total_frames', 'detection_ratio']
        # Chinese column headers shown in the report.
        chinese_header_map = {
            'video_name': '视频名',
            'detected_frames': '检测到的球帧数',
            'total_frames': '视频总帧数',
            'detection_ratio': '检测比率'
        }
        try:
            # utf-8-sig emits a BOM so the Chinese header opens correctly in
            # spreadsheet applications.
            with open(summary_csv_path, 'w', newline='', encoding='utf-8-sig') as f:
                csv.writer(f).writerow([chinese_header_map[field] for field in fieldnames])
                csv.DictWriter(f, fieldnames=fieldnames).writerows(summary_data_list)
        except Exception as e:
            print(f"❌ ERROR writing summary CSV: {e}")

    print(f"\n🎉🎉🎉 All videos processed! Check the results in: {output_root_dir} 🎉🎉🎉")


if __name__ == '__main__':
    main()