Skip to content

Commit a23d733

Browse files
author
David Josef Emmerichs
committed
add hierarchical configuration via runcon
1 parent 87e5675 commit a23d733

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

pcdet/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from typing import List
23

34
import yaml
45
from easydict import EasyDict
@@ -123,6 +124,31 @@ def rc_to_ed(rc_cfg: Config) -> EasyDict:
123124
return ed_cfg
124125

125126

127+
def modify_rc_cfg(cfg: Config, modify_cfgs: List[Path]) -> Config:
128+
from copy import deepcopy
129+
cfg = deepcopy(cfg)
130+
for m in modify_cfgs:
131+
cfg.rupdate(Config.from_file(m))
132+
cfg = cfg.resolve_transforms()
133+
return cfg
134+
135+
136+
def create_cfg_from_sets(
137+
cfg_file: Path,
138+
modify_cfgs: List[Path],
139+
set_cfgs: List[str],
140+
cfg: EasyDict = None,
141+
) -> EasyDict:
142+
if cfg is None:
143+
cfg = EasyDict()
144+
cfg_from_yaml_file(cfg_file, cfg)
145+
cfg = ed_to_rc(cfg)
146+
cfg = modify_rc_cfg(cfg, modify_cfgs)
147+
cfg = rc_to_ed(cfg)
148+
cfg_from_list(set_cfgs, cfg)
149+
return cfg
150+
151+
126152
cfg = EasyDict()
127153
cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve()
128154
cfg.LOCAL_RANK = 0

tools/train.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch.nn as nn
1111
from tensorboardX import SummaryWriter
1212

13-
from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
13+
from pcdet.config import cfg, log_config_to_file, create_cfg_from_sets
1414
from pcdet.datasets import build_dataloader
1515
from pcdet.models import build_network, model_fn_decorator
1616
from pcdet.utils import common_utils
@@ -19,6 +19,8 @@
1919

2020

2121
def parse_config():
22+
global cfg
23+
2224
parser = argparse.ArgumentParser(description='arg parser')
2325
parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')
2426

@@ -38,6 +40,8 @@ def parse_config():
3840
parser.add_argument('--merge_all_iters_to_one_epoch', action='store_true', default=False, help='')
3941
parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
4042
help='set extra config keys if needed')
43+
parser.add_argument('--modify', type=str, dest='modify_cfgs', default=None, nargs='*',
44+
help='specify extra modifier configs')
4145

4246
parser.add_argument('--max_waiting_mins', type=int, default=0, help='max waiting minutes')
4347
parser.add_argument('--start_epoch', type=int, default=0, help='')
@@ -52,16 +56,18 @@ def parse_config():
5256

5357

5458
args = parser.parse_args()
59+
if args.set_cfgs is None:
60+
args.set_cfgs = []
61+
if args.modify_cfgs is None:
62+
args.modify_cfgs = []
63+
64+
cfg = create_cfg_from_sets(args.cfg_file, args.modify_cfgs, args.set_cfgs, cfg)
5565

56-
cfg_from_yaml_file(args.cfg_file, cfg)
5766
cfg.TAG = Path(args.cfg_file).stem
5867
cfg.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1]) # remove 'cfgs' and 'xxxx.yaml'
5968

6069
args.use_amp = args.use_amp or cfg.OPTIMIZATION.get('USE_AMP', False)
6170

62-
if args.set_cfgs is not None:
63-
cfg_from_list(args.set_cfgs, cfg)
64-
6571
return args, cfg
6672

6773

0 commit comments

Comments
 (0)