@@ -69,6 +69,7 @@ def _run_batch(batch, model, collate):
     )
     return preds, labels
 
+
 def _run_batch_give_attention(batch, model, collate):
     collated = collate(batch)
     collated.x = collated.to_x(model.device)
@@ -82,6 +83,7 @@ def _run_batch_give_attention(batch, model, collate):
     )
     return preds, labels, model_output
 
+
 def _concat_tuple(l_):
     if isinstance(l_[0], tuple):
         print(l_[0])
@@ -170,7 +172,7 @@ def evaluate_model(
             test_labels = _concat_tuple(labels_list)
             return test_preds, test_labels
         return test_preds, None
-    elif len(preds_list) > 0:
+    elif len(preds_list) > 0:
         if preds_list[0] is not None:
             torch.save(
                 _concat_tuple(preds_list),
@@ -280,6 +282,7 @@ def evaluate_model_regression(
     )
     return torch.cat(preds_list_all), torch.cat(labels_list_all)
 
+
 def evaluate_model_regression_attention(
     model: ChebaiBaseNet,
     data_module: XYBaseDataModule,
@@ -336,7 +339,9 @@ def evaluate_model_regression_attention(
             skip_existing_preds
             and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
         ):
-            preds, labels, model_output = _run_batch_give_attention(data_list[i : i + batch_size], model, collate)
+            preds, labels, model_output = _run_batch_give_attention(
+                data_list[i : i + batch_size], model, collate
+            )
             preds_list.append(preds)
             labels_list.append(labels)
             preds_list_all.append(preds)
@@ -477,6 +482,7 @@ def evaluate_model_regression(
     )
     return torch.cat(preds_list_all), torch.cat(labels_list_all)
 
+
 def evaluate_model_regression_attention(
     model: ChebaiBaseNet,
     data_module: XYBaseDataModule,
@@ -533,7 +539,9 @@ def evaluate_model_regression_attention(
             skip_existing_preds
             and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
         ):
-            preds, labels, model_output = _run_batch_give_attention(
+                data_list[i : i + batch_size], model, collate
+            )
             preds_list.append(preds)
             labels_list.append(labels)
             preds_list_all.append(preds)
@@ -575,7 +583,12 @@ def evaluate_model_regression_attention(
                 _concat_tuple(labels_list),
                 os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
             )
-    return torch.cat(preds_list_all), torch.cat(labels_list_all), features_list_all, attention_list_all
+    return (
+        torch.cat(preds_list_all),
+        torch.cat(labels_list_all),
+        features_list_all,
+        attention_list_all,
+    )
 
 
 def load_results_from_buffer(
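
The evaluation functions touched above buffer intermediate results into buffer_dir as preds{save_ind:03d}.pt / labels{save_ind:03d}.pt files via torch.save, and load_results_from_buffer is the counterpart that reads them back. Below is a minimal sketch of that read-back step, assuming only the file-naming pattern visible in the diff; the helper name load_buffered_results and its signature are illustrative, not the repository's actual API.

# Illustrative sketch: reload buffered evaluation tensors written as
# preds{NNN:03d}.pt / labels{NNN:03d}.pt and concatenate them in order.
import os

import torch


def load_buffered_results(buffer_dir):
    preds_parts, labels_parts = [], []
    for name in sorted(os.listdir(buffer_dir)):
        path = os.path.join(buffer_dir, name)
        if name.startswith("preds") and name.endswith(".pt"):
            preds_parts.append(torch.load(path))
        elif name.startswith("labels") and name.endswith(".pt"):
            labels_parts.append(torch.load(path))
    test_preds = torch.cat(preds_parts) if preds_parts else None
    test_labels = torch.cat(labels_parts) if labels_parts else None
    return test_preds, test_labels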