diff --git a/infer_wild.py b/infer_wild.py
index 17acd19..f3e38e4 100644
--- a/infer_wild.py
+++ b/infer_wild.py
@@ -25,73 +25,95 @@ def parse_args():
     opts = parser.parse_args()
     return opts
 
-opts = parse_args()
-args = get_config(opts.config)
+if __name__ == '__main__':
+    # due to RuntimeError: freeze_support() on Windows: https://github.com/pytorch/pytorch/issues/5858
+    torch.multiprocessing.freeze_support()
 
-model_backbone = load_backbone(args)
-if torch.cuda.is_available():
-    model_backbone = nn.DataParallel(model_backbone)
-    model_backbone = model_backbone.cuda()
+    opts = parse_args()
+    args = get_config(opts.config)
+
+    model_backbone = load_backbone(args)
+    if torch.cuda.is_available():
+        model_backbone = nn.DataParallel(model_backbone)
+        model_backbone = model_backbone.cuda()
+
+    print('Loading checkpoint', opts.evaluate)
+    checkpoint = torch.load(opts.evaluate, map_location=lambda storage, loc: storage)
 
-print('Loading checkpoint', opts.evaluate)
-checkpoint = torch.load(opts.evaluate, map_location=lambda storage, loc: storage)
-model_backbone.load_state_dict(checkpoint['model_pos'], strict=True)
-model_pos = model_backbone
-model_pos.eval()
-testloader_params = {
-          'batch_size': 1,
-          'shuffle': False,
-          'num_workers': 8,
-          'pin_memory': True,
-          'prefetch_factor': 4,
-          'persistent_workers': True,
-          'drop_last': False
-}
+    # Checkpoints saved from an nn.DataParallel model prefix every key with
+    # `module.`; loading them into a bare model fails with
+    # KeyError: 'unexpected key "module.*"'
+    # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3
+    # Strip the prefix only from keys that carry it: slicing every key with
+    # k[7:] would mangle the names in a checkpoint saved without the wrapper.
+    prefix = 'module.'
+    new_state_dict = {
+        (k[len(prefix):] if k.startswith(prefix) else k): v
+        for k, v in checkpoint['model_pos'].items()
+    }
 
-vid = imageio.get_reader(opts.vid_path, 'ffmpeg')
-fps_in = vid.get_meta_data()['fps']
-vid_size = vid.get_meta_data()['size']
-os.makedirs(opts.out_path, exist_ok=True)
-
-if opts.pixel:
-    # Keep relative scale with pixel coornidates
-    wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus)
-else:
-    # Scale to [-1,1]
-    wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus)
-
-test_loader = DataLoader(wild_dataset, **testloader_params)
-
-results_all = []
-with torch.no_grad():
-    for batch_input in tqdm(test_loader):
-        N, T = batch_input.shape[:2]
-        if torch.cuda.is_available():
-            batch_input = batch_input.cuda()
-        if args.no_conf:
-            batch_input = batch_input[:, :, :, :2]
-        if args.flip:
-            batch_input_flip = flip_data(batch_input)
-            predicted_3d_pos_1 = model_pos(batch_input)
-            predicted_3d_pos_flip = model_pos(batch_input_flip)
-            predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back
-            predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2.0
-        else:
-            predicted_3d_pos = model_pos(batch_input)
-        if args.rootrel:
-            predicted_3d_pos[:,:,0,:]=0 # [N,T,17,3]
-        else:
-            predicted_3d_pos[:,0,0,2]=0
-            pass
-        if args.gt_2d:
-            predicted_3d_pos[...,:2] = batch_input[...,:2]
-        results_all.append(predicted_3d_pos.cpu().numpy())
-
-results_all = np.hstack(results_all)
-results_all = np.concatenate(results_all)
-render_and_save(results_all, '%s/X3D.mp4' % (opts.out_path), keep_imgs=False, fps=fps_in)
-if opts.pixel:
-    # Convert to pixel coordinates
-    results_all = results_all * (min(vid_size) / 2.0)
-    results_all[:,:,:2] = results_all[:,:,:2] + np.array(vid_size) / 2.0
-np.save('%s/X3D.npy' % (opts.out_path), results_all)
\ No newline at end of file
+    # On CUDA the model was wrapped in nn.DataParallel above, so its own keys
+    # are `module.`-prefixed; load the stripped keys into the wrapped network.
+    if isinstance(model_backbone, nn.DataParallel):
+        model_backbone.module.load_state_dict(new_state_dict, strict=True)
+    else:
+        model_backbone.load_state_dict(new_state_dict, strict=True)
+    model_pos = model_backbone
+    model_pos.eval()
+    testloader_params = {
+              'batch_size': 1,
+              'shuffle': False,
+              'num_workers': 8,
+              'pin_memory': True,
+              'prefetch_factor': 4,
+              'persistent_workers': True,
+              'drop_last': False
+    }
+
+    vid = imageio.get_reader(opts.vid_path, 'ffmpeg')
+    fps_in = vid.get_meta_data()['fps']
+    vid_size = vid.get_meta_data()['size']
+    os.makedirs(opts.out_path, exist_ok=True)
+
+    if opts.pixel:
+        # Keep relative scale with pixel coordinates
+        wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus)
+    else:
+        # Scale to [-1,1]
+        wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus)
+
+    test_loader = DataLoader(wild_dataset, **testloader_params)
+
+    results_all = []
+    with torch.no_grad():
+        for batch_input in tqdm(test_loader):
+            N, T = batch_input.shape[:2]
+            if torch.cuda.is_available():
+                batch_input = batch_input.cuda()
+            if args.no_conf:
+                batch_input = batch_input[:, :, :, :2]
+            if args.flip:
+                batch_input_flip = flip_data(batch_input)
+                predicted_3d_pos_1 = model_pos(batch_input)
+                predicted_3d_pos_flip = model_pos(batch_input_flip)
+                predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back
+                predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2.0
+            else:
+                predicted_3d_pos = model_pos(batch_input)
+            if args.rootrel:
+                predicted_3d_pos[:,:,0,:]=0 # [N,T,17,3]
+            else:
+                predicted_3d_pos[:,0,0,2]=0
+                pass
+            if args.gt_2d:
+                predicted_3d_pos[...,:2] = batch_input[...,:2]
+            results_all.append(predicted_3d_pos.cpu().numpy())
+
+    results_all = np.hstack(results_all)
+    results_all = np.concatenate(results_all)
+    render_and_save(results_all, '%s/X3D.mp4' % (opts.out_path), keep_imgs=False, fps=fps_in)
+    if opts.pixel:
+        # Convert to pixel coordinates
+        results_all = results_all * (min(vid_size) / 2.0)
+        results_all[:,:,:2] = results_all[:,:,:2] + np.array(vid_size) / 2.0
+    np.save('%s/X3D.npy' % (opts.out_path), results_all)