-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathconvert.py
More file actions
33 lines (27 loc) · 868 Bytes
/
convert.py
File metadata and controls
33 lines (27 loc) · 868 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pickle as pkl
import sys
import torch
if __name__ == "__main__":
input = sys.argv[1]
obj = torch.load(input, map_location="cpu")
obj = obj["model"]
newmodel = {}
for k, v in obj.items():
if not k.startswith("encoder_q.") and not k.startswith("network"):
continue
old_k = k
if k.startswith("encoder_q."):
k = k.replace("encoder_q.", "")
elif k.startswith("network"):
k = k.replace("network.", "")
print(old_k, "->", k)
newmodel[k] = v.numpy()
res = {
"model": newmodel,
"__author__": "MOCO" if k.startswith("encoder_q.") else "CLS",
"matching_heuristics": True
}
with open(sys.argv[2], "wb") as f:
pkl.dump(res, f)