-
Notifications
You must be signed in to change notification settings - Fork 4
/
convert.py
65 lines (62 loc) · 2.83 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import inspect
import sys

import torch

from models.yolo import Model
# Convert a training checkpoint (best.pt) into a checkpoint for the GELAN model
# described by CFG: every tensor of the GELAN model's state dict is copied from
# its counterpart in the training checkpoint, and the result is saved in half
# precision.
#
# Key mapping (GELAN key -> training-checkpoint key), implemented by source_key:
#   * layers with index < BACKBONE_LAYERS:  model.<i>.X      -> model.<i+1>.X
#   * remaining (head) layers:              model.<i>.cv2.X  -> model.<i+HEAD_OFFSET>.cv4.X
#                                           model.<i>.cv3.X  -> model.<i+HEAD_OFFSET>.cv5.X
#                                           model.<i>.dfl.X  -> model.<i+HEAD_OFFSET>.dfl2.X
# NOTE(review): the offsets are specific to this gelan-s-hyper / training-config
# pair; presumably the training model carries extra (auxiliary) layers that the
# GELAN config omits. Confirm against both YAMLs before reusing with another config.

CFG = "./models/detect/gelan-s-hyper.yaml"
WEIGHTS = "./runs/train/exp/weights/best.pt"
OUTPUT = "./yolov9-s-hyper-converted.pt"

# GELAN layers below this index are read from checkpoint layer ``index + 1``.
BACKBONE_LAYERS = 32
# Added to a head layer's index to locate its source layer in the checkpoint.
HEAD_OFFSET = 16
# (GELAN submodule name, checkpoint submodule name) pairs for head layers.
HEAD_RENAMES = (("cv2", "cv4"), ("cv3", "cv5"), ("dfl", "dfl2"))


def layer_index(key):
    """Return ``i`` for a state-dict key of the form ``model.<i>.<rest>``, else None."""
    parts = key.split(".", 2)
    if len(parts) == 3 and parts[0] == "model" and parts[1].isdecimal():
        return int(parts[1])
    return None


def source_key(key, shift_below=BACKBONE_LAYERS, head_offset=HEAD_OFFSET,
               head_renames=HEAD_RENAMES):
    """Map a GELAN state-dict key to the matching training-checkpoint key.

    Args:
        key: state-dict key of the GELAN model, e.g. ``"model.3.cv1.conv.weight"``.
        shift_below: layers with a smaller index are read from ``index + 1``.
        head_offset: added to a head layer's index to find its source layer.
        head_renames: ``(gelan_name, checkpoint_name)`` pairs for head submodules.

    Returns:
        The checkpoint key, or None when ``key`` has no counterpart.
    """
    idx = layer_index(key)
    if idx is None:
        return None
    rest = key.split(".", 2)[2]
    if idx < shift_below:
        return f"model.{idx + 1}.{rest}"
    for dst_name, src_name in head_renames:
        if rest.startswith(dst_name + "."):
            return f"model.{idx + head_offset}.{src_name}.{rest[len(dst_name) + 1:]}"
    return None


def transfer_weights(model, src_model, key_map=source_key):
    """Overwrite ``model``'s parameters and buffers in place from ``src_model``.

    Args:
        model: module whose state-dict tensors are overwritten.
        src_model: module providing the source tensors.
        key_map: maps a ``model`` key to a ``src_model`` key, or None to skip it.

    Returns:
        List of ``model`` keys that ``key_map`` could not map; those tensors keep
        their current values.

    Raises:
        KeyError: a mapped key does not exist in ``src_model``.
        ValueError: source and destination tensor shapes differ.
    """
    src_state = src_model.state_dict()  # build once, not once per key
    unmapped = []
    with torch.no_grad():
        # state_dict() tensors share storage with the module, so copy_ writes through.
        for key, tensor in model.state_dict().items():
            src_key = key_map(key)
            if src_key is None:
                unmapped.append(key)
                continue
            src_tensor = src_state[src_key]
            if src_tensor.shape != tensor.shape:
                raise ValueError(
                    f"shape mismatch: {key} {tuple(tensor.shape)} vs "
                    f"{src_key} {tuple(src_tensor.shape)}")
            # copy_ also converts dtype (e.g. a half-precision checkpoint into fp32).
            tensor.copy_(src_tensor)
    return unmapped


def load_checkpoint(path):
    """Load a pickled training checkpoint onto the CPU.

    The checkpoint holds a model object (read via ``.state_dict()``, ``.names``,
    ``.nc``), not only tensors, so it must be fully unpickled. Newer torch
    defaults ``weights_only=True``, which rejects such objects. ``False`` is
    passed only when this torch version accepts the argument.
    SECURITY: unpickling can execute arbitrary code; only load trusted files.
    """
    kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def main(cfg=CFG, weights=WEIGHTS, output=OUTPUT):
    """Build the GELAN model from ``cfg``, fill it from ``weights``, save to ``output``."""
    device = torch.device("cpu")
    ckpt = load_checkpoint(weights)
    src = ckpt["model"]

    # Size the head for the checkpoint's class count so tensor shapes match.
    model = Model(cfg, ch=3, nc=src.nc, anchors=3).to(device)
    model.eval()
    model.names = src.names
    model.nc = src.nc

    unmapped = transfer_weights(model, src)
    if unmapped:
        print(f"WARNING: {len(unmapped)} tensor(s) have no counterpart in {weights} "
              f"and keep their initial values: {unmapped}", file=sys.stderr)

    m_ckpt = {'model': model.half(),
              'optimizer': None,
              'best_fitness': None,
              'ema': None,
              'updates': None,
              'opt': None,
              'git': None,
              'date': None,
              'epoch': -1}
    torch.save(m_ckpt, output)


if __name__ == "__main__":
    main()