1515
1616from oddkiva import DATA_DIR_PATH
1717from oddkiva .sara .dataset .colors import generate_label_colors
18- from oddkiva .brahma .torch import DEFAULT_DEVICE
1918from oddkiva .brahma .torch .backbone .repvgg import RepVggBlock
2019from oddkiva .brahma .torch .utils .freeze import freeze_batch_norm
2120from oddkiva .brahma .torch .object_detection .detr .architectures .\
@@ -37,31 +36,48 @@ def optimize_repvgg_layer_for_inference(m: nn.Module):
3736class ModelConfig :
3837 CKPT_DIRPATH = (DATA_DIR_PATH / 'trained_models' / 'rtdetrv2_r50' /
3938 'train' / 'coco' / 'ckpts' )
39+ CKPT_RESUME_DIRPATH = (DATA_DIR_PATH / 'trained_models' / 'rtdetrv2_r50' /
40+ 'train' / 'coco' / 'ckpts-resume' )
4041 LABELS_FILEPATH = (DATA_DIR_PATH / 'model-weights' / 'rtdetrv2' /
4142 'labels.txt' )
43+
44+ CKPT_DIRPATH .exists ()
45+ CKPT_RESUME_DIRPATH .exists ()
46+ LABELS_FILEPATH .exists ()
47+
4248 W_INFER = 640
4349 H_INFER = 640
50+ CONFIDENCE_THRESHOLD = 0.4
4451
4552 RUN_ON_CPU = False
46- EPOCH = 0
47- STEPS = 1000
48- CONFIDENCE_THRESHOLD = 0.4
53+ USE_RESUME_CKPT = True
54+
55+ RESUME_ITER = 10
56+ EPOCH = 3
57+ STEPS = 2000
58+
4959
5060 @staticmethod
5161 def load () -> tuple [nn .Module , list [str ], torch .device ]:
52- assert ModelConfig .CKPT_DIRPATH .exists ()
53- assert ModelConfig .LABELS_FILEPATH .exists ()
5462
5563 # This is by design so that we can keep training with the GPU...
5664 if ModelConfig .RUN_ON_CPU :
5765 device = torch .device ('cpu' )
5866 else :
5967 device = torch .device ('cuda:1' )
6068
61- CKPT_FP = (
62- ModelConfig .CKPT_DIRPATH /
63- f'ckpt_epoch_{ ModelConfig .EPOCH } _step_{ ModelConfig .STEPS } .pth'
64- )
69+ if ModelConfig .USE_RESUME_CKPT :
70+ filename = '{}-ckpt_epoch_{}_step_{}.pth' .format (
71+ ModelConfig .RESUME_ITER ,
72+ ModelConfig .EPOCH ,
73+ ModelConfig .STEPS
74+ )
75+ CKPT_FP = ModelConfig .CKPT_RESUME_DIRPATH / filename
76+ else :
77+ CKPT_FP = (
78+ ModelConfig .CKPT_DIRPATH /
79+ f'ckpt_epoch_{ ModelConfig .EPOCH } _step_{ ModelConfig .STEPS } .pth'
80+ )
6581 assert CKPT_FP .exists ()
6682
6783 # THE MODEL
0 commit comments