-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
113 lines (93 loc) · 3.76 KB
/
Copy pathmain.py
File metadata and controls
113 lines (93 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import traceback
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
import models
import tasks
import utils.callbacks
import utils.data
import utils.email
import utils.logging
# Per-dataset CSV locations, keyed by the value of the ``--data`` option.
# Each entry provides the path passed as ``feat_path`` ("feat") and the
# path passed as ``adj_path`` ("adj") when the data module is built.
DATA_PATHS = {
    "shenzhen": {
        "feat": "data/sz_speed.csv",
        "adj": "data/sz_adj.csv",
    },
    "losloop": {
        "feat": "data/los_speed.csv",
        "adj": "data/los_adj.csv",
    },
}
def get_model(args, dm):
    """Instantiate the model selected by ``args.model_name``.

    Args:
        args: Parsed command-line namespace. ``model_name`` and
            ``hidden_dim`` are always read; ``seq_len`` is read only for
            ``"GCN"``.
        dm: Data module exposing ``adj``. ``"GCN"`` and ``"TGCN"`` receive
            ``dm.adj`` itself; ``"GRU"`` receives ``dm.adj.shape[0]`` as its
            input dimension.

    Returns:
        A ``models.GCN``, ``models.GRU`` or ``models.TGCN`` instance.

    Raises:
        ValueError: If ``args.model_name`` is not ``"GCN"``, ``"GRU"`` or
            ``"TGCN"``. Previously this case silently returned ``None``.
    """
    name = args.model_name
    if name == "GCN":
        return models.GCN(adj=dm.adj, input_dim=args.seq_len, output_dim=args.hidden_dim)
    if name == "GRU":
        return models.GRU(input_dim=dm.adj.shape[0], hidden_dim=args.hidden_dim)
    if name == "TGCN":
        return models.TGCN(adj=dm.adj, hidden_dim=args.hidden_dim)
    # Fail here, with the offending value, instead of handing ``None`` on.
    raise ValueError(f"Unsupported model_name {name!r}; expected one of 'GCN', 'GRU', 'TGCN'")
def get_task(args, model, dm):
    """Build the forecasting task that wraps ``model``.

    The task class is looked up on the ``tasks`` module under the name
    ``<Settings>ForecastTask``, where ``<Settings>`` is
    ``args.settings.capitalize()``. It is constructed with ``model``,
    ``dm.feat_max_val`` and every attribute of ``args`` expanded as keyword
    arguments.

    Args:
        args: Parsed command-line namespace; ``settings`` selects the class.
        model: Model instance passed through as ``model=``.
        dm: Data module exposing ``feat_max_val``.

    Returns:
        The constructed task instance.
    """
    task_cls_name = f"{args.settings.capitalize()}ForecastTask"
    task_cls = getattr(tasks, task_cls_name)
    return task_cls(model=model, feat_max_val=dm.feat_max_val, **vars(args))
def get_callbacks(args):
    """Return the list of callbacks to hand to the trainer.

    Args:
        args: Parsed command-line namespace. Accepted for a uniform
            signature; this function does not read it.

    Returns:
        A two-element list: a ``ModelCheckpoint`` followed by a
        ``PlotValidationPredictionsCallback``, both monitoring
        ``"train_loss"``.
    """
    monitored_metric = "train_loss"
    return [
        pl.callbacks.ModelCheckpoint(monitor=monitored_metric),
        utils.callbacks.PlotValidationPredictionsCallback(monitor=monitored_metric),
    ]
def main_supervised(args):
    """Run the supervised pipeline: build everything, fit, then validate.

    Args:
        args: Parsed command-line namespace. ``args.data`` selects the entry
            of ``DATA_PATHS``; the whole namespace is also expanded as
            keyword arguments into the data module and forwarded to the
            model, task, callback and trainer factories.

    Returns:
        Whatever ``trainer.validate(datamodule=dm)`` returns.
    """
    paths = DATA_PATHS[args.data]
    dm = utils.data.SpatioTemporalCSVDataModule(
        feat_path=paths["feat"],
        adj_path=paths["adj"],
        **vars(args),
    )
    task = get_task(args, get_model(args, dm), dm)
    trainer_callbacks = get_callbacks(args)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=trainer_callbacks)
    trainer.fit(task, dm)
    return trainer.validate(datamodule=dm)
def main(args):
    """Log the run configuration and dispatch to ``main_<settings>``.

    Args:
        args: Parsed command-line namespace. ``args.settings`` is appended
            to ``"main_"`` to pick the module-level function to call.

    Returns:
        The value returned by the selected ``main_<settings>`` function.
    """
    rank_zero_info(vars(args))
    # Resolved by name from module globals, so a new setting only needs a
    # matching ``main_<settings>`` function.
    entry_point = globals()["main_" + args.settings]
    return entry_point(args)
if __name__ == "__main__":
    # Local import so this block is self-contained: ``sys.exit`` replaces the
    # ``exit`` helper, which is injected by the ``site`` module for
    # interactive use and is not available when ``site`` is skipped.
    import sys

    # ---- Stage 1: options that do not depend on the chosen model/settings.
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--data", type=str, help="The name of the dataset", choices=("shenzhen", "losloop"), default="losloop"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="The name of the model for spatiotemporal prediction",
        choices=("GCN", "GRU", "TGCN"),
        default="GCN",
    )
    parser.add_argument(
        "--settings",
        type=str,
        help="The type of settings, e.g. supervised learning",
        choices=("supervised",),
        default="supervised",
    )
    parser.add_argument("--log_path", type=str, default=None, help="Path to the output console log file")
    parser.add_argument("--send_email", "--email", action="store_true", help="Send email when finished")

    # ---- Stage 2: peek at --settings / --model_name (ignoring options that
    # are not registered yet) so the matching data module, model and task
    # classes can add their own arguments before the final parse.
    temp_args, _ = parser.parse_known_args()
    settings_prefix = temp_args.settings.capitalize()
    parser = getattr(utils.data, settings_prefix + "DataModule").add_data_specific_arguments(parser)
    parser = getattr(models, temp_args.model_name).add_model_specific_arguments(parser)
    parser = getattr(tasks, settings_prefix + "ForecastTask").add_task_specific_arguments(parser)

    args = parser.parse_args()

    # ---- Logging setup.
    utils.logging.format_logger(pl._logger)
    if args.log_path is not None:
        utils.logging.output_logger_to_file(pl._logger, args.log_path)

    # Shared suffix of both notification subjects.
    run_name = "-".join([args.settings, args.model_name, args.data])

    try:
        results = main(args)
    except BaseException:
        # Deliberately as broad as the former bare ``except:`` so that
        # KeyboardInterrupt and SystemExit are also reported as failures.
        traceback.print_exc()
        if args.send_email:
            utils.email.send_email(traceback.format_exc(), "[Email Bot][❌] " + run_name)
        sys.exit(-1)

    if args.send_email:
        utils.email.send_experiment_results_email(args, results, subject="[Email Bot][✅] " + run_name)