-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconvert_shards.py
More file actions
31 lines (28 loc) · 1.01 KB
/
convert_shards.py
File metadata and controls
31 lines (28 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import sys
import torch
from safetensors.torch import save_file
import glob
import os
# Default shard directory; can be overridden with the first CLI argument.
DEFAULT_DIR = 'models/hf/gemma-3-4b-it-hqq-int8-int4'


def _extract_state_dict(obj):
    """Unwrap common checkpoint layouts down to a plain name -> value dict.

    Handles a module-like object exposing ``.state_dict()`` and checkpoint
    dicts that nest the weights under ``'state_dict'`` or
    ``'model_state_dict'``; anything else is returned as-is.
    """
    if hasattr(obj, 'state_dict'):
        return obj.state_dict()
    if isinstance(obj, dict):
        for key in ('state_dict', 'model_state_dict'):
            inner = obj.get(key)
            if isinstance(inner, dict):
                return inner
    return obj


def _clean_tensors(state):
    """Keep only tensor entries, dequantized and materialized on CPU.

    Non-tensor entries (optimizer state, step counters, ...) are reported
    and skipped rather than dropped silently, since safetensors can only
    store tensors.
    """
    clean = {}
    for k, v in state.items():
        if not torch.is_tensor(v):
            print(f"  skipping non-tensor entry: {k!r} ({type(v).__name__})")
            continue
        t = v
        # HQQ / quantized tensor subclasses: materialize dense values when
        # possible; best-effort, fall back to the raw tensor on failure.
        if hasattr(t, "dequantize"):
            try:
                t = t.dequantize()
            except Exception:
                pass
        # .clone() breaks storage sharing between tensors, which
        # safetensors' save_file rejects with an error.
        clean[str(k)] = t.detach().cpu().contiguous().clone()
    return clean


def convert_shards(directory):
    """Convert every ``*.bin`` torch checkpoint in *directory* to safetensors.

    Each ``foo.bin`` becomes ``foo.safetensors`` alongside it; the number of
    tensors written per shard is printed after each conversion.
    """
    # sorted() makes the conversion order deterministic across filesystems.
    inputs = sorted(glob.glob(os.path.join(directory, '*.bin')))
    if not inputs:
        print(f"No .bin shards found in {directory}")
        return
    for inp in inputs:
        out = inp.replace('.bin', '.safetensors')
        print(f"Converting {inp} to {out}")
        # SECURITY: weights_only=False unpickles arbitrary code -- only run
        # this on checkpoints you trust. Kept because HQQ checkpoints may
        # contain non-weight objects that weights_only=True would reject.
        obj = torch.load(inp, map_location='cpu', weights_only=False)
        clean = _clean_tensors(_extract_state_dict(obj))
        save_file(clean, out)
        print(len(clean))


if __name__ == "__main__":
    # Optional directory override via argv (previously `sys` was imported
    # but unused); defaults to the original hard-coded path.
    convert_shards(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_DIR)