diff --git a/detect.py b/detect.py index 5b94027c..68a6dfdc 100644 --- a/detect.py +++ b/detect.py @@ -65,8 +65,8 @@ def main(_argv): batch_data = tf.constant(images_data) pred_bbox = infer(batch_data) for key, value in pred_bbox.items(): - boxes = value[:, :, 0:4] - pred_conf = value[:, :, 4:] + boxes = value[..., :4] + pred_conf = value[..., 4:] boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression( boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),