Skip to content

Commit ce69189

Browse files
committed
MAINT: tweaks.
1 parent 85ef630 commit ce69189

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

python/oddkiva/brahma/torch/tasks/object_detection/tools/check_rtdetrv2.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from oddkiva import DATA_DIR_PATH
1717
from oddkiva.sara.dataset.colors import generate_label_colors
18-
from oddkiva.brahma.torch import DEFAULT_DEVICE
1918
from oddkiva.brahma.torch.backbone.repvgg import RepVggBlock
2019
from oddkiva.brahma.torch.utils.freeze import freeze_batch_norm
2120
from oddkiva.brahma.torch.object_detection.detr.architectures.\
@@ -37,31 +36,48 @@ def optimize_repvgg_layer_for_inference(m: nn.Module):
3736
class ModelConfig:
3837
CKPT_DIRPATH = (DATA_DIR_PATH / 'trained_models' / 'rtdetrv2_r50' /
3938
'train' / 'coco' / 'ckpts')
39+
CKPT_RESUME_DIRPATH = (DATA_DIR_PATH / 'trained_models' / 'rtdetrv2_r50' /
40+
'train' / 'coco' / 'ckpts-resume')
4041
LABELS_FILEPATH = (DATA_DIR_PATH / 'model-weights' / 'rtdetrv2' /
4142
'labels.txt')
43+
44+
CKPT_DIRPATH.exists()
45+
CKPT_RESUME_DIRPATH.exists()
46+
LABELS_FILEPATH.exists()
47+
4248
W_INFER = 640
4349
H_INFER = 640
50+
CONFIDENCE_THRESHOLD = 0.4
4451

4552
RUN_ON_CPU = False
46-
EPOCH = 0
47-
STEPS = 1000
48-
CONFIDENCE_THRESHOLD = 0.4
53+
USE_RESUME_CKPT = True
54+
55+
RESUME_ITER = 10
56+
EPOCH = 3
57+
STEPS = 2000
58+
4959

5060
@staticmethod
5161
def load() -> tuple[nn.Module, list[str], torch.device]:
52-
assert ModelConfig.CKPT_DIRPATH.exists()
53-
assert ModelConfig.LABELS_FILEPATH.exists()
5462

5563
# This is by design so that we can keep training with the GPU...
5664
if ModelConfig.RUN_ON_CPU:
5765
device = torch.device('cpu')
5866
else:
5967
device = torch.device('cuda:1')
6068

61-
CKPT_FP = (
62-
ModelConfig.CKPT_DIRPATH /
63-
f'ckpt_epoch_{ModelConfig.EPOCH}_step_{ModelConfig.STEPS}.pth'
64-
)
69+
if ModelConfig.USE_RESUME_CKPT:
70+
filename = '{}-ckpt_epoch_{}_step_{}.pth'.format(
71+
ModelConfig.RESUME_ITER,
72+
ModelConfig.EPOCH,
73+
ModelConfig.STEPS
74+
)
75+
CKPT_FP = ModelConfig.CKPT_RESUME_DIRPATH / filename
76+
else:
77+
CKPT_FP = (
78+
ModelConfig.CKPT_DIRPATH /
79+
f'ckpt_epoch_{ModelConfig.EPOCH}_step_{ModelConfig.STEPS}.pth'
80+
)
6581
assert CKPT_FP.exists()
6682

6783
# THE MODEL

python/oddkiva/brahma/torch/tasks/object_detection/tools/train_rtdetrv2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def get_cuda_memory_usage():
4848
stdout=subprocess.PIPE
4949
)
5050
mb_used = result.stdout.decode('utf-8').strip().split('\n')
51-
mb_used = [f'[GPU:{id}] {mb}' for id, mb in enumerate(mb_used)]
52-
mb_used = "\n".join(mb_used)
51+
mb_used = [f'GPU{id}: {mb}' for id, mb in enumerate(mb_used)]
52+
mb_used = ", ".join(mb_used)
5353
return mb_used
5454

5555

@@ -226,7 +226,7 @@ def train_for_one_epoch(
226226

227227
if gpu_id == 0:
228228
logger.info(format_msg((
229-
f'[E:{epoch:0>2},S:{step:0>5}] Memory usage:\n'
229+
f'[E:{epoch:0>2},S:{step:0>5}] Memory: '
230230
f'{get_cuda_memory_usage()}'
231231
)))
232232

0 commit comments

Comments
 (0)