forked from Air-IO/Air-IO
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_motion.py
More file actions
152 lines (131 loc) · 6.59 KB
/
inference_motion.py
File metadata and controls
152 lines (131 loc) · 6.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import sys
import torch
import torch.utils.data as Data
import argparse
import pickle
import tqdm
from utils import move_to, save_state
from pyhocon import ConfigFactory
from datasets import collate_fcs, SeqeuncesMotionDataset
from model import net_dict, ONNXWrapper
from utils import *
def inference(network, loader, confs):
    """Run correction inference and gather the network outputs.

    Puts the network into eval mode, pushes every batch through it with
    gradients disabled, and concatenates the per-batch output states along
    the sequence axis.

    Args:
        network: trained motion network exposing ``forward`` and ``get_label``.
        loader: DataLoader yielding ``(data, _, label)`` tuples.
        confs: training config; only ``confs.device`` is read here.

    Returns:
        dict mapping each state name to a tensor concatenated over all batches.
    """
    network.eval()
    collected = {}
    with torch.no_grad():
        for data, _, label in tqdm.tqdm(loader):
            data, label = move_to([data, label], confs.device)
            # Ground-truth rotations for all but the last step, in log-map form.
            gt_rot = label['gt_rot'][:, :-1, :].Log().tensor()
            state = network.forward(data, gt_rot)
            state['ts'] = network.get_label(data['ts'][..., None])[0]
            save_state(collected, state)
        # Stitch per-batch tensors together along the sequence dimension.
        collected = {name: torch.cat(parts, dim=-2) for name, parts in collected.items()}
    return collected
if __name__ == '__main__':
    # ---- Command-line arguments ----
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/EuRoC/motion_body.conf', help='config file path')
    parser.add_argument('--load', type=str, default=None, help='path for specific model check point, Default is the best model')
    parser.add_argument("--device", type=str, default="cuda:0", help="cuda or cpu")
    parser.add_argument('--batch_size', type=int, default=1, help='batch size.')
    parser.add_argument('--seqlen', type=int, default=1000, help='window size.')
    # NOTE(review): default=True combined with action="store_true" means this
    # flag is always True -- passing --whole is a no-op and there is no way to
    # disable it from the command line, so the non-whole ("infevaluate")
    # branch later in this script is unreachable. Confirm whether the default
    # should be False.
    parser.add_argument('--whole', default=True, action="store_true", help='estimate the whole seq')
    args = parser.parse_args(); print(args)

    # ---- Configuration ----
    conf = ConfigFactory.parse_file(args.config)
    conf.train.device = args.device
    # Experiment name is the config file's base name (without extension).
    conf_name = os.path.split(args.config)[-1].split(".")[0]
    # All outputs for this run go under <exp_dir>/<config name>.
    conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name)
    conf['device'] = args.device
    dataset_conf = conf.dataset.inference
    # The trained network uses double precision but ONNX export requires
    # float32 tensors. We keep the network in double for downstream PyTorch
    # inference and temporarily cast to float32 when exporting to ONNX.
    network = net_dict[conf.train.network](conf.train).to(args.device).double()
    save_folder = os.path.join(conf.general.exp_dir, "evaluate")
    os.makedirs(save_folder, exist_ok=True)
    # Resolve the checkpoint: best model by default, or the file named by --load.
    if args.load is None:
        ckpt_path = os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt")
    else:
        ckpt_path = os.path.join(conf.general.exp_dir, "ckpt", args.load)
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=torch.device(args.device),weights_only=True)
        print("loaded state dict %s in epoch %i"%(ckpt_path, checkpoint["epoch"]))
        network.load_state_dict(checkpoint["model_state_dict"])
        # ---- Save model in .pt format ----
        pt_path = os.path.join(conf.general.exp_dir, "model_export.pt")
        torch.save(network.state_dict(), pt_path)
        print(f"Saved PyTorch weights to {pt_path}")
        # ---- Export model to ONNX format ----
        # Prepare a dummy batch from the dataset for ONNX tracing
        data_conf = dataset_conf.data_list[0]
        path = data_conf.data_drive[0]
        eval_dataset = SeqeuncesMotionDataset(data_set_config=dataset_conf, data_path=path, data_root=data_conf["data_root"])
        if 'collate' in conf.dataset.keys():
            collate_fn = collate_fcs[conf.dataset.collate.type]
        else:
            collate_fn = collate_fcs['base']
        eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=args.batch_size,
                                      shuffle=False, collate_fn=collate_fn, drop_last=False)
        data, _, label = next(iter(eval_loader))
        # Prepare float32 dummy inputs for ONNX tracing
        dummy_data = {k: v.to(args.device).float() for k, v in data.items()}
        dummy_rot = label['gt_rot'][:, :-1, :].Log().tensor().to(args.device).float()
        # Wrap the network so that ONNX sees explicit tensor arguments instead
        # of a nested dictionary.
        # NOTE: network.float() casts the module's parameters in place, so the
        # wrapper traces the float32 version of this same network object.
        wrapper = ONNXWrapper(network.float())
        onnx_path = os.path.join(conf.general.exp_dir, "model_export.onnx")
        torch.onnx.export(
            wrapper,
            (dummy_data['acc'], dummy_data['gyro'], dummy_rot),
            onnx_path,
            input_names=['acc', 'gyro', 'rot'],
            output_names=['net_vel', 'cov'],
            # Mark batch and sequence axes dynamic so the exported graph
            # accepts variable batch sizes and sequence lengths.
            dynamic_axes={
                'acc': {0: 'batch', 1: 'seq'},
                'gyro': {0: 'batch', 1: 'seq'},
                'rot': {0: 'batch', 1: 'seq'},
                'net_vel': {0: 'batch', 1: 'seq'},
                'cov': {0: 'batch', 1: 'seq'},
            },
        )
        print(f"Exported ONNX model to {onnx_path}")
        # Restore original precision for subsequent PyTorch inference
        network.double()
    else:
        raise KeyError(f"No model loaded {ckpt_path}")
    # NOTE(review): this unconditional exit means the inference/pickle section
    # below never executes -- confirm whether the script is intended to stop
    # after exporting, or whether sys.exit() should be removed.
    sys.exit()
if 'collate' in conf.dataset.keys():
collate_fn = collate_fcs[conf.dataset.collate.type]
else:
collate_fn = collate_fcs['base']
cov_result, rmse = [], []
net_out_result = {}
evals = {}
dataset_conf.data_list[0]["window_size"] = args.seqlen
dataset_conf.data_list[0]["step_size"] = args.seqlen
for data_conf in dataset_conf.data_list:
for path in data_conf.data_drive:
if args.whole:
dataset_conf["mode"] = "inference"
else:
dataset_conf["mode"] = "infevaluate"
dataset_conf["exp_dir"] = conf.general.exp_dir
eval_dataset = SeqeuncesMotionDataset(data_set_config=dataset_conf, data_path=path, data_root=data_conf["data_root"])
eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=args.batch_size,
shuffle=False, collate_fn=collate_fn, drop_last = False)
inference_state = inference(network=network, loader = eval_loader, confs=conf.train)
if not "cov" in inference_state.keys():
inference_state["cov"] = torch.zeros_like(inference_state["net_vel"])
inference_state['ts'] = inference_state['ts']
inference_state['net_vel'] = inference_state['net_vel'][0] #TODO: batch size != 1
net_out_result[path] = inference_state
net_result_path = os.path.join(conf.general.exp_dir, 'net_output.pickle')
print("save netout, ", net_result_path)
with open(net_result_path, 'wb') as handle:
pickle.dump(net_out_result, handle, protocol=pickle.HIGHEST_PROTOCOL)