Skip to content

Commit fb6fdb7

Browse files
committed
lint fixes
1 parent b0b3113 commit fb6fdb7

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

chebai/preprocessing/reader.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ def _back_to_smiles(smiles_encoded):
219219
token_file = self.reader.token_path
220220
token_coding = {}
221221
counter = 0
222-
smiles_decoded = ''
223-
222+
smiles_decoded = ""
223+
224224
# todo: for now just copied over from a notebook but ideally do this using the cache
225-
with open(token_file, 'r') as file:
225+
with open(token_file, "r") as file:
226226
for line in file:
227227
token_coding[counter] = line.strip()
228228
counter += 1
@@ -232,6 +232,7 @@ def _back_to_smiles(smiles_encoded):
232232

233233
return smiles_decoded
234234

235+
235236
class DeepChemDataReader(ChemDataReader):
236237
"""
237238
Data reader for chemical data using DeepSMILES tokens.

chebai/result/utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def _run_batch(batch, model, collate):
6969
)
7070
return preds, labels
7171

72+
7273
def _run_batch_give_attention(batch, model, collate):
7374
collated = collate(batch)
7475
collated.x = collated.to_x(model.device)
@@ -82,6 +83,7 @@ def _run_batch_give_attention(batch, model, collate):
8283
)
8384
return preds, labels, model_output
8485

86+
8587
def _concat_tuple(l_):
8688
if isinstance(l_[0], tuple):
8789
print(l_[0])
@@ -170,7 +172,7 @@ def evaluate_model(
170172
test_labels = _concat_tuple(labels_list)
171173
return test_preds, test_labels
172174
return test_preds, None
173-
elif len(preds_list) > 0 :
175+
elif len(preds_list) > 0:
174176
if preds_list[0] is not None:
175177
torch.save(
176178
_concat_tuple(preds_list),
@@ -280,6 +282,7 @@ def evaluate_model_regression(
280282
)
281283
return torch.cat(preds_list_all), torch.cat(labels_list_all)
282284

285+
283286
def evaluate_model_regression_attention(
284287
model: ChebaiBaseNet,
285288
data_module: XYBaseDataModule,
@@ -336,7 +339,9 @@ def evaluate_model_regression_attention(
336339
skip_existing_preds
337340
and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
338341
):
339-
preds, labels, model_output = _run_batch_give_attention(data_list[i : i + batch_size], model, collate)
342+
preds, labels, model_output = _run_batch_give_attention(
343+
data_list[i : i + batch_size], model, collate
344+
)
340345
preds_list.append(preds)
341346
labels_list.append(labels)
342347
preds_list_all.append(preds)
@@ -477,6 +482,7 @@ def evaluate_model_regression(
477482
)
478483
return torch.cat(preds_list_all), torch.cat(labels_list_all)
479484

485+
480486
def evaluate_model_regression_attention(
481487
model: ChebaiBaseNet,
482488
data_module: XYBaseDataModule,
@@ -533,7 +539,9 @@ def evaluate_model_regression_attention(
533539
skip_existing_preds
534540
and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
535541
):
536-
preds, labels, model_output = _run_batch_give_attention(data_list[i : i + batch_size], model, collate)
542+
preds, labels, model_output = _run_batch_give_attention(
543+
data_list[i : i + batch_size], model, collate
544+
)
537545
preds_list.append(preds)
538546
labels_list.append(labels)
539547
preds_list_all.append(preds)
@@ -575,7 +583,12 @@ def evaluate_model_regression_attention(
575583
_concat_tuple(labels_list),
576584
os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
577585
)
578-
return torch.cat(preds_list_all), torch.cat(labels_list_all), features_list_all, attention_list_all
586+
return (
587+
torch.cat(preds_list_all),
588+
torch.cat(labels_list_all),
589+
features_list_all,
590+
attention_list_all,
591+
)
579592

580593

581594
def load_results_from_buffer(

0 commit comments

Comments
 (0)