diff --git a/yolox/tracker/byte_tracker.py b/yolox/tracker/byte_tracker.py
index 2d004599..8bac2846 100644
--- a/yolox/tracker/byte_tracker.py
+++ b/yolox/tracker/byte_tracker.py
@@ -12,7 +12,14 @@
 
 class STrack(BaseTrack):
     shared_kalman = KalmanFilter()
-    def __init__(self, tlwh, score):
+    def __init__(self, tlwh, score, det_idx=None):
+        """
+        Initialize a tracklet
+        Args:
+            tlwh: (np.ndarray) bbox in tlwh format (top-left x, top-left y, width, height)
+            score: (float) bbox detection score
+            det_idx: (int) (Optional) corresponding index in object detection list
+        """
 
         # wait activate
         self._tlwh = np.asarray(tlwh, dtype=np.float)
@@ -23,6 +30,8 @@ def __init__(self, tlwh, score):
         self.score = score
         self.tracklet_len = 0
 
+        self.det_idx = det_idx
+
     def predict(self):
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
@@ -67,7 +76,8 @@ def re_activate(self, new_track, frame_id, new_id=False):
         if new_id:
             self.track_id = self.next_id()
         self.score = new_track.score
-
+        self.det_idx = new_track.det_idx
+
     def update(self, new_track, frame_id):
         """
         Update a matched track
@@ -87,6 +97,8 @@ def update(self, new_track, frame_id):
 
         self.score = new_track.score
 
+        self.det_idx = new_track.det_idx
+
     @property
     # @jit(nopython=True)
     def tlwh(self):
@@ -156,17 +168,33 @@ def __init__(self, args, frame_rate=30):
         self.max_time_lost = self.buffer_size
         self.kalman_filter = KalmanFilter()
 
-    def update(self, output_results, img_info, img_size):
+    def update(self, output_results, img_info, img_size, track_det_idx=False):
+        """
+        Update tracker with detection results
+        Args:
+            output_results: detections as rows of x1,y1,x2,y2,score; must have det_idx as a 6th column when track_det_idx is True
+            img_info: original image information
+            img_size: inference scaled image size
+            track_det_idx: whether to track det_idx (index corresponding to the detection results)
+        """
         self.frame_id += 1
         activated_starcks = []
         refind_stracks = []
         lost_stracks = []
         removed_stracks = []
+
+        if track_det_idx:
+            if output_results.shape[1] == 6:
+                scores = output_results[:, 4]
+                bboxes = output_results[:, :4]  # x1y1x2y2
+                _det_idxs = output_results[:, 5]
+            else:
+                raise ValueError('output_results shape error')
 
         if output_results.shape[1] == 5:
             scores = output_results[:, 4]
             bboxes = output_results[:, :4]
-        else:
+        elif not track_det_idx:
             output_results = output_results.cpu().numpy()
             scores = output_results[:, 4] * output_results[:, 5]
             bboxes = output_results[:, :4]  # x1y1x2y2
@@ -183,11 +211,19 @@ def update(self, output_results, img_info, img_size):
         dets = bboxes[remain_inds]
         scores_keep = scores[remain_inds]
         scores_second = scores[inds_second]
+        if track_det_idx:
+            det_idxs = _det_idxs[remain_inds]
+            det_idxs_second = _det_idxs[inds_second]
 
         if len(dets) > 0:
             '''Detections'''
-            detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
-                          (tlbr, s) in zip(dets, scores_keep)]
+            detections = []
+            for i, (tlbr, s) in enumerate(zip(dets, scores_keep)):
+                if track_det_idx:
+                    di = det_idxs[i]
+                    detections.append(STrack(STrack.tlbr_to_tlwh(tlbr), s, di))
+                else:
+                    detections.append(STrack(STrack.tlbr_to_tlwh(tlbr), s))
         else:
             detections = []
 
@@ -223,8 +259,15 @@ def update(self, output_results, img_info, img_size):
         # association the untrack to the low score detections
         if len(dets_second) > 0:
             '''Detections'''
-            detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
-                          (tlbr, s) in zip(dets_second, scores_second)]
+            #detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, di) for
+            #              (tlbr, s, di) in zip(dets_second, scores_second, det_idxs_second)]
+            detections_second = []
+            for i, (tlbr, s) in enumerate(zip(dets_second, scores_second)):
+                if track_det_idx:
+                    di = det_idxs_second[i]
+                    detections_second.append(STrack(STrack.tlbr_to_tlwh(tlbr), s, di))
+                else:
+                    detections_second.append(STrack(STrack.tlbr_to_tlwh(tlbr), s))
         else:
             detections_second = []
         r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]