Skip to content

Commit 835deb3

Browse files
committed
add set_global_seed in main.py
1 parent a0b4fa2 commit 835deb3

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
torch:
22
device: 'cpu'
3+
seed: 42
34

45
prepare_data:
56
train_data:

main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from argparse import ArgumentParser
23
from collections import Counter
34

@@ -18,15 +19,20 @@
1819
)
1920
from pytorch_ner.save import save_model
2021
from pytorch_ner.train import train
21-
from pytorch_ner.utils import str_to_class
22+
from pytorch_ner.utils import set_global_seed, str_to_class
2223

2324

2425
def main(path_to_config: str):
2526

2627
with open(path_to_config, mode="r") as fp:
2728
config = yaml.safe_load(fp)
2829

30+
# check existence of save path_to_folder
31+
if os.path.exists(config["save"]["path_to_folder"]):
32+
raise FileExistsError("save directory already exists")
33+
2934
device = torch.device(config["torch"]["device"])
35+
set_global_seed(config["torch"]["seed"])
3036

3137
# LOAD DATA
3238

pytorch_ner/save.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def save_model(
1818
config: Dict,
1919
export_onnx: bool = False,
2020
):
21-
# make empty dir
21+
# check existence of save path_to_folder
2222
if os.path.exists(path_to_folder):
2323
raise FileExistsError("save directory already exists")
2424
mkdir(path_to_folder)

0 commit comments

Comments
 (0)