@@ -52,12 +52,13 @@ def main(path_to_config: str):
5252 verbose = config ["prepare_data" ]["val_data" ]["verbose" ],
5353 )
5454
55- test_token_seq , test_label_seq = prepare_conll_data_format (
56- path = config ["prepare_data" ]["test_data" ]["path" ],
57- sep = config ["prepare_data" ]["test_data" ]["sep" ],
58- lower = config ["prepare_data" ]["test_data" ]["lower" ],
59- verbose = config ["prepare_data" ]["test_data" ]["verbose" ],
60- )
55+ if "test_data" in config ["prepare_data" ]:
56+ test_token_seq , test_label_seq = prepare_conll_data_format (
57+ path = config ["prepare_data" ]["test_data" ]["path" ],
58+ sep = config ["prepare_data" ]["test_data" ]["sep" ],
59+ lower = config ["prepare_data" ]["test_data" ]["lower" ],
60+ verbose = config ["prepare_data" ]["test_data" ]["verbose" ],
61+ )
6162
6263 # token2idx / label2idx
6364
@@ -91,13 +92,14 @@ def main(path_to_config: str):
9192 preprocess = config ["dataloader" ]["preprocess" ],
9293 )
9394
94- testset = NERDataset (
95- token_seq = test_token_seq ,
96- label_seq = test_label_seq ,
97- token2idx = token2idx ,
98- label2idx = label2idx ,
99- preprocess = config ["dataloader" ]["preprocess" ],
100- )
95+ if "test_data" in config ["prepare_data" ]:
96+ testset = NERDataset (
97+ token_seq = test_token_seq ,
98+ label_seq = test_label_seq ,
99+ token2idx = token2idx ,
100+ label2idx = label2idx ,
101+ preprocess = config ["dataloader" ]["preprocess" ],
102+ )
101103
102104 # collators
103105
@@ -113,11 +115,12 @@ def main(path_to_config: str):
113115 percentile = 100 , # hardcoded
114116 )
115117
116- test_collator = NERCollator (
117- token_padding_value = token2idx [config ["dataloader" ]["token_padding" ]],
118- label_padding_value = label2idx [config ["dataloader" ]["label_padding" ]],
119- percentile = 100 , # hardcoded
120- )
118+ if "test_data" in config ["prepare_data" ]:
119+ test_collator = NERCollator (
120+ token_padding_value = token2idx [config ["dataloader" ]["token_padding" ]],
121+ label_padding_value = label2idx [config ["dataloader" ]["label_padding" ]],
122+ percentile = 100 , # hardcoded
123+ )
121124
122125 # dataloaders
123126
@@ -136,12 +139,13 @@ def main(path_to_config: str):
136139 collate_fn = val_collator ,
137140 )
138141
139- testloader = DataLoader (
140- dataset = testset ,
141- batch_size = 1 , # hardcoded
142- shuffle = False , # hardcoded
143- collate_fn = test_collator ,
144- )
142+ if "test_data" in config ["prepare_data" ]:
143+ testloader = DataLoader (
144+ dataset = testset ,
145+ batch_size = 1 , # hardcoded
146+ shuffle = False , # hardcoded
147+ collate_fn = test_collator ,
148+ )
145149
146150 # INIT MODEL
147151
@@ -208,7 +212,7 @@ def main(path_to_config: str):
208212 model = model ,
209213 trainloader = trainloader ,
210214 valloader = valloader ,
211- testloader = testloader ,
215+ testloader = testloader if "test_data" in config [ "prepare_data" ] else None ,
212216 criterion = criterion ,
213217 optimizer = optimizer ,
214218 device = device ,
0 commit comments