|
16 | 16 | import colorcet as cc
|
17 | 17 | import cv2
|
18 | 18 | import numpy as np
|
| 19 | +import pickle |
19 | 20 | from PIL import ImageColor
|
20 | 21 | from pip._internal.operations import freeze
|
21 | 22 | import torch
|
@@ -62,6 +63,149 @@ def show_progress(count, block_size, total_size):
|
62 | 63 | with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
63 | 64 | zip_ref.extractall(target_dir)
|
64 | 65 |
|
| 66 | + |
| 67 | +def benchmark_videos( |
| 68 | + model_path, |
| 69 | + video_path, |
| 70 | + output=None, |
| 71 | + n_frames=1000, |
| 72 | + tf_config=None, |
| 73 | + resize=None, |
| 74 | + pixels=None, |
| 75 | + cropping=None, |
| 76 | + dynamic=(False, 0.5, 10), |
| 77 | + print_rate=False, |
| 78 | + display=False, |
| 79 | + pcutoff=0.5, |
| 80 | + display_radius=3, |
| 81 | + cmap="bmy", |
| 82 | + save_poses=False, |
| 83 | + save_video=False, |
| 84 | +): |
| 85 | + """Analyze videos using DeepLabCut-live exported models. |
| 86 | + Analyze multiple videos and/or multiple options for the size of the video |
| 87 | + by specifying a resizing factor or the number of pixels to use in the image (keeping aspect ratio constant). |
| 88 | + Options to record inference times (to examine inference speed), |
| 89 | + display keypoints to visually check the accuracy, |
| 90 | + or save poses to an hdf5 file as in :function:`deeplabcut.benchmark_videos` and |
| 91 | + create a labeled video as in :function:`deeplabcut.create_labeled_video`. |
| 92 | +
|
| 93 | + Parameters |
| 94 | + ---------- |
| 95 | + model_path : str |
| 96 | + path to exported DeepLabCut model |
| 97 | + video_path : str or list |
| 98 | + path to video file or list of paths to video files |
| 99 | + output : str |
| 100 | + path to directory to save results |
| 101 | + tf_config : :class:`tensorflow.ConfigProto` |
| 102 | + tensorflow session configuration |
| 103 | + resize : int, optional |
| 104 | + resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None |
| 105 | + pixels : int, optional |
| 106 | + downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. by default None |
| 107 | + cropping : list of int |
| 108 | + cropping parameters in pixel number: [x1, x2, y1, y2] |
| 109 | + dynamic: triple containing (state, detectiontreshold, margin) |
| 110 | + If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), |
| 111 | + then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is |
| 112 | + expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. <detectiontreshold). The |
| 113 | + current position is utilized for updating the crop window for the next frame (this is why the margin is important and should be set large |
| 114 | + enough given the movement of the animal) |
| 115 | + n_frames : int, optional |
| 116 | + number of frames to run inference on, by default 1000 |
| 117 | + print_rate : bool, optional |
| 118 | + flat to print inference rate frame by frame, by default False |
| 119 | + display : bool, optional |
| 120 | + flag to display keypoints on images. Useful for checking the accuracy of exported models. |
| 121 | + pcutoff : float, optional |
| 122 | + likelihood threshold to display keypoints |
| 123 | + display_radius : int, optional |
| 124 | + size (radius in pixels) of keypoint to display |
| 125 | + cmap : str, optional |
| 126 | + a string indicating the :package:`colorcet` colormap, `options here <https://colorcet.holoviz.org/>`, by default "bmy" |
| 127 | + save_poses : bool, optional |
| 128 | + flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False |
| 129 | + save_video : bool, optional |
| 130 | + flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False |
| 131 | +
|
| 132 | + Example |
| 133 | + ------- |
| 134 | + Return a vector of inference times for 10000 frames on one video or two videos: |
| 135 | + dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', n_frames=10000) |
| 136 | + dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000) |
| 137 | +
|
| 138 | + Return a vector of inference times, testing full size and resizing images to half the width and height for inference, for two videos |
| 139 | + dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000, resize=[1.0, 0.5]) |
| 140 | +
|
| 141 | + Display keypoints to check the accuracy of an exported model |
| 142 | + dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', display=True) |
| 143 | +
|
| 144 | + Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` |
| 145 | + dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True) |
| 146 | + """ |
| 147 | + # convert video_paths to list |
| 148 | + video_path = video_path if type(video_path) is list else [video_path] |
| 149 | + |
| 150 | + # fix resize |
| 151 | + if pixels: |
| 152 | + pixels = pixels if type(pixels) is list else [pixels] |
| 153 | + resize = [None for p in pixels] |
| 154 | + elif resize: |
| 155 | + resize = resize if type(resize) is list else [resize] |
| 156 | + pixels = [None for r in resize] |
| 157 | + else: |
| 158 | + resize = [None] |
| 159 | + pixels = [None] |
| 160 | + |
| 161 | + # loop over videos |
| 162 | + for video in video_path: |
| 163 | + # initialize full inference times |
| 164 | + inf_times = [] |
| 165 | + im_size_out = [] |
| 166 | + |
| 167 | + for i in range(len(resize)): |
| 168 | + print(f"\nRun {i+1} / {len(resize)}\n") |
| 169 | + |
| 170 | + this_inf_times, this_im_size, meta = benchmark( |
| 171 | + model_path=model_path, |
| 172 | + model_type="base", |
| 173 | + video_path=video, |
| 174 | + tf_config=tf_config, |
| 175 | + resize=resize[i], |
| 176 | + pixels=pixels[i], |
| 177 | + cropping=cropping, |
| 178 | + dynamic=dynamic, |
| 179 | + n_frames=n_frames, |
| 180 | + print_rate=print_rate, |
| 181 | + display=display, |
| 182 | + pcutoff=pcutoff, |
| 183 | + display_radius=display_radius, |
| 184 | + cmap=cmap, |
| 185 | + save_poses=save_poses, |
| 186 | + save_video=save_video, |
| 187 | + save_dir=output, |
| 188 | + ) |
| 189 | + |
| 190 | + inf_times.append(this_inf_times) |
| 191 | + im_size_out.append(this_im_size) |
| 192 | + |
| 193 | + inf_times = np.array(inf_times) |
| 194 | + im_size_out = np.array(im_size_out) |
| 195 | + |
| 196 | + # save results |
| 197 | + if output is not None: |
| 198 | + sys_info = get_system_info() |
| 199 | + save_inf_times( |
| 200 | + sys_info, |
| 201 | + inf_times, |
| 202 | + im_size_out, |
| 203 | + model=os.path.basename(model_path), |
| 204 | + meta=meta, |
| 205 | + output=output, |
| 206 | + ) |
| 207 | + |
| 208 | + |
65 | 209 | def get_system_info() -> dict:
|
66 | 210 | """
|
67 | 211 | Returns a summary of system information relevant to running benchmarking.
|
@@ -128,6 +272,77 @@ def get_system_info() -> dict:
|
128 | 272 | }
|
129 | 273 |
|
130 | 274 |
|
| 275 | +def save_inf_times( |
| 276 | + sys_info, inf_times, im_size, model=None, meta=None, output=None |
| 277 | +): |
| 278 | + """Save inference time data collected using :function:`benchmark` with system information to a pickle file. |
| 279 | + This is primarily used through :function:`benchmark_videos` |
| 280 | +
|
| 281 | +
|
| 282 | + Parameters |
| 283 | + ---------- |
| 284 | + sys_info : tuple |
| 285 | + system information generated by :func:`get_system_info` |
| 286 | + inf_times : :class:`numpy.ndarray` |
| 287 | + array of inference times generated by :func:`benchmark` |
| 288 | + im_size : tuple or :class:`numpy.ndarray` |
| 289 | + image size (width, height) for each benchmark run. If an array, each row corresponds to a row in inf_times |
| 290 | + model: str, optional |
| 291 | + name of model |
| 292 | + meta : dict, optional |
| 293 | + metadata returned by :func:`benchmark` |
| 294 | + output : str, optional |
| 295 | + path to directory to save data. If None, uses pwd, by default None |
| 296 | +
|
| 297 | + Returns |
| 298 | + ------- |
| 299 | + bool |
| 300 | + flag indicating successful save |
| 301 | + """ |
| 302 | + |
| 303 | + output = output if output is not None else os.getcwd() |
| 304 | + model_type = None |
| 305 | + if model is not None: |
| 306 | + if "resnet" in model: |
| 307 | + model_type = "resnet" |
| 308 | + elif "mobilenet" in model: |
| 309 | + model_type = "mobilenet" |
| 310 | + else: |
| 311 | + model_type = None |
| 312 | + |
| 313 | + fn_ind = 0 |
| 314 | + base_name = ( |
| 315 | + f"benchmark_{sys_info['host_name']}_{sys_info['device_type']}_{fn_ind}.pickle" |
| 316 | + ) |
| 317 | + out_file = os.path.normpath(f"{output}/{base_name}") |
| 318 | + while os.path.isfile(out_file): |
| 319 | + fn_ind += 1 |
| 320 | + base_name = f"benchmark_{sys_info['host_name']}_{sys_info['device_type']}_{fn_ind}.pickle" |
| 321 | + out_file = os.path.normpath(f"{output}/{base_name}") |
| 322 | + |
| 323 | + # summary stats (mean inference time & standard error of mean) |
| 324 | + stats = zip( |
| 325 | + np.mean(inf_times, 1), |
| 326 | + np.std(inf_times, 1) * 1.0 / np.sqrt(np.shape(inf_times)[1]), |
| 327 | + ) |
| 328 | + |
| 329 | + data = { |
| 330 | + "model": model, |
| 331 | + "model_type": model_type, |
| 332 | + "im_size": im_size, |
| 333 | + "inference_times": inf_times, |
| 334 | + "stats": stats, |
| 335 | + } |
| 336 | + |
| 337 | + data.update(sys_info) |
| 338 | + if meta: |
| 339 | + data.update(meta) |
| 340 | + |
| 341 | + os.makedirs(os.path.normpath(output), exist_ok=True) |
| 342 | + pickle.dump(data, open(out_file, "wb")) |
| 343 | + |
| 344 | + return True |
| 345 | + |
131 | 346 | def benchmark(
|
132 | 347 | model_path: str,
|
133 | 348 | model_type: str,
|
|
0 commit comments