-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathconfig.py
More file actions
77 lines (58 loc) · 3.01 KB
/
config.py
File metadata and controls
77 lines (58 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
from registry.registry import ModelRegistry, DatasetRegistry
class Config:
    """Global, class-level configuration populated from command-line flags.

    Every option is a class attribute holding its default value. Calling
    ``Config.parse_config()`` parses the command line once and overwrites
    these attributes on the class itself, so settings are read as
    ``Config.<name>`` without creating an instance.
    """

    # --- Data and placement config ---
    train_dir = '/host-dir/querysat'
    data_dir = '/host-dir/data'
    force_data_gen = False
    ckpt_count = 3
    eager = False
    restore = None
    label = ""

    # --- Training and task selection config ---
    train_steps = 10000
    warmup = 0.0
    learning_rate = 0.0002
    model = 'querysat'  # querysat, neurocore, neurocore_query
    task = 'ksat'  # ksat, kcolor, 3sat, clique, sha2019

    # --- Supported training and evaluation modes ---
    train = False
    evaluate = False
    evaluate_round_gen = False
    evaluate_variable_gen = False
    make_cactus = False
    make_scatter = False

    # --- Internal config variables ---
    # Set once parse_config() has run, to prevent parsing twice. The double
    # underscore makes Python mangle this to ``_Config__arguments_parsed``.
    __arguments_parsed = False

    @classmethod
    def parse_config(cls, args=None):
        """Parse command-line flags and store each one as a class attribute.

        Args:
            args: Optional list of argument strings. ``None`` (the default)
                makes argparse read ``sys.argv[1:]``; pass an explicit list
                to configure programmatically (e.g. from tests or notebooks).

        Raises:
            RuntimeError: If arguments have already been parsed.
        """
        if cls.__arguments_parsed:
            raise RuntimeError("Arguments already parsed!")
        namespace = cls.__argument_parser().parse_args(args)
        # Overwrite the class-level defaults with the parsed values.
        for key, value in vars(namespace).items():
            setattr(cls, key, value)
        cls.__arguments_parsed = True

    @classmethod
    def __argument_parser(cls):
        """Build an ArgumentParser whose defaults mirror the class attributes.

        ``--model`` and ``--task`` accept only names registered in
        ``ModelRegistry`` / ``DatasetRegistry``. Given without a value
        (``nargs='?'``), they fall back to the class default via ``const``.

        Note: the ``store_true`` flags can only switch an option on. If one of
        those class defaults were ever set to True, the CLI could not turn it
        off.
        """
        config_parser = argparse.ArgumentParser()
        config_parser.add_argument('--train_dir', type=str, default=cls.train_dir)
        config_parser.add_argument('--data_dir', type=str, default=cls.data_dir)
        # Read the class default like every other option, so changing
        # Config.restore is reflected on the command line too.
        config_parser.add_argument('--restore', type=str, default=cls.restore)
        config_parser.add_argument('--label', type=str, default=cls.label)
        config_parser.add_argument('--ckpt_count', type=int, default=cls.ckpt_count)
        config_parser.add_argument('--eager', action='store_true', default=cls.eager)
        config_parser.add_argument('--train_steps', type=int, default=cls.train_steps)
        config_parser.add_argument('--warmup', type=float, default=cls.warmup)
        config_parser.add_argument('--learning_rate', type=float, default=cls.learning_rate)
        config_parser.add_argument('--model', type=str, default=cls.model, const=cls.model, nargs='?',
                                   choices=ModelRegistry().registered_names)
        config_parser.add_argument('--task', type=str, default=cls.task, const=cls.task, nargs='?',
                                   choices=DatasetRegistry().registered_names)
        config_parser.add_argument('--force_data_gen', action='store_true', default=cls.force_data_gen)
        config_parser.add_argument('--train', action='store_true', default=cls.train)
        config_parser.add_argument('--evaluate', action='store_true', default=cls.evaluate)
        config_parser.add_argument('--evaluate_round_gen', action='store_true', default=cls.evaluate_round_gen)
        config_parser.add_argument('--evaluate_variable_gen', action='store_true',
                                   default=cls.evaluate_variable_gen)
        config_parser.add_argument('--make_cactus', action='store_true', default=cls.make_cactus)
        config_parser.add_argument('--make_scatter', action='store_true', default=cls.make_scatter)
        return config_parser