-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathctoxpred_handler.py
More file actions
111 lines (90 loc) · 4.82 KB
/
ctoxpred_handler.py
File metadata and controls
111 lines (90 loc) · 4.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import joblib
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, Any, List, Optional
import subprocess
from CToxPred.pairwise_correlation import CorrelationThreshold
from CToxPred.utils import compute_fingerprint_features, compute_descriptor_features
from CToxPred.hERG_model import hERGClassifier
from CToxPred.nav15_model import Nav15Classifier
from CToxPred.cav12_model import Cav12Classifier
class CToxPredHandler:
    """Handles interactions with cardiotoxicity models with property selection.

    Wraps the three CToxPred classifiers (hERG, Nav1.5, Cav1.2) behind a
    batch-prediction interface that mirrors the ADMEThyst handler's nested
    result format.
    """

    # Closed set of property names callers may request.
    AVAILABLE_PROPERTIES = ["hERG", "Nav1.5", "Cav1.2"]

    def __init__(self, model_base_path: str = "CToxPred/models") -> None:
        """Eagerly load all models and preprocessing pipelines.

        Args:
            model_base_path: Directory containing the ``model_weights`` and
                ``decriptors_preprocessing`` subdirectories (the spelling
                matches the on-disk layout shipped with CToxPred).
        """
        self.device = torch.device('cpu')
        self.model_base_path = Path(model_base_path)
        # Torch classifiers keyed by property name; entries stay None if
        # loading fails.
        self.models: Dict[str, Any] = {"hERG": None, "Nav1.5": None, "Cav1.2": None}
        # Descriptor preprocessing pipelines. hERG is fingerprint-only, so it
        # has no pipeline entry.
        self.pipelines: Dict[str, Any] = {"Nav1.5": None, "Cav1.2": None}
        self._load_models()

    def _load_models(self) -> None:
        """Load model weights and descriptor pipelines from disk.

        On any failure this logs and leaves the affected entries as ``None``;
        later predictions then surface per-batch errors instead of crashing
        at construction time (behavior preserved from the original handler).
        """
        try:
            # hERG: fingerprint-only classifier (1905 input features, 2 classes).
            self.models["hERG"] = hERGClassifier(1905, 2)
            self.models["hERG"].load(str(self.model_base_path / "model_weights/hERG/_herg_checkpoint.model"))
            # Nav1.5: fingerprints + preprocessed descriptors (2454 features).
            self.pipelines["Nav1.5"] = joblib.load(self.model_base_path / "decriptors_preprocessing/Nav1.5/nav_descriptors_preprocessing_pipeline.sav")
            self.models["Nav1.5"] = Nav15Classifier(2454, 2)
            self.models["Nav1.5"].load(str(self.model_base_path / "model_weights/Nav1.5/_nav15_checkpoint.model"))
            # Cav1.2: fingerprints + preprocessed descriptors (2586 features).
            self.pipelines["Cav1.2"] = joblib.load(self.model_base_path / "decriptors_preprocessing/Cav1.2/cav_descriptors_preprocessing_pipeline.sav")
            self.models["Cav1.2"] = Cav12Classifier(2586, 2)
            self.models["Cav1.2"].load(str(self.model_base_path / "model_weights/Cav1.2/_cav12_checkpoint.model"))
            # Inference only: pin to CPU and switch off training-mode layers.
            for model in self.models.values():
                model.to(self.device).eval()
            print("Successfully initialized all CToxPred models.")
        except Exception as e:
            print(f"ERROR: Failed to load CToxPred models: {e}")

    def process_multiple_properties_batch(self, smiles_list: List[str], property_list: List[str]) -> List[Dict[str, Any]]:
        """Predict the requested properties for a batch of SMILES strings.

        Matches the ADMEThyst interface: returns one result dict per input
        SMILES, each with ``smiles``/``status``/``results``/``error`` keys and
        a nested per-property entry whose ``results`` is the predicted class
        index as a float.

        Args:
            smiles_list: SMILES strings to score.
            property_list: Requested property names; unknown names are ignored.

        Returns:
            A list of result dicts, one per SMILES, in input order. On any
            batch-level failure every entry carries ``status == "error"``.
        """
        if not smiles_list:
            return []
        # Silently drop unknown property names (original behavior).
        valid_props = [p for p in property_list if p in self.AVAILABLE_PROPERTIES]
        if not valid_props:
            # Nothing to predict — skip feature computation entirely and
            # return the same success-shaped entries with empty results.
            return [{"smiles": s, "status": "success", "results": {}, "error": None}
                    for s in smiles_list]
        try:
            fps = compute_fingerprint_features(smiles_list)
            # Descriptors are only consumed by Nav1.5/Cav1.2; computing them
            # for an hERG-only request would be wasted work.
            needs_descriptors = any(p in valid_props for p in ("Nav1.5", "Cav1.2"))
            descs = compute_descriptor_features(smiles_list) if needs_descriptors else None

            # Per-property arrays of predicted class indices, aligned with
            # smiles_list.
            batch_preds: Dict[str, Any] = {}
            with torch.no_grad():
                if "hERG" in valid_props:
                    fps_tensor = torch.from_numpy(fps).float().to(self.device)
                    batch_preds["hERG"] = self.models["hERG"](fps_tensor).argmax(1).cpu().numpy()
                # Nav1.5 and Cav1.2 share the same feature recipe:
                # fingerprints concatenated with pipeline-transformed descriptors.
                for prop in ("Nav1.5", "Cav1.2"):
                    if prop in valid_props:
                        prop_descs = self.pipelines[prop].transform(descs)
                        feats = torch.from_numpy(np.concatenate((fps, prop_descs), axis=1)).float().to(self.device)
                        batch_preds[prop] = self.models[prop](feats).argmax(1).cpu().numpy()

            # Assemble the nested response format.
            final_results = []
            for i, smiles in enumerate(smiles_list):
                res_entry = {
                    "smiles": smiles,
                    "status": "success",
                    "results": {},
                    "error": None
                }
                for prop in valid_props:
                    res_entry["results"][prop] = {
                        "property": prop,
                        "status": "success",
                        "results": float(batch_preds[prop][i]),
                        "error": None
                    }
                final_results.append(res_entry)
            return final_results
        except Exception as e:
            # Whole-batch failure (bad SMILES, unloaded model, ...): report
            # the same error for every input rather than raising.
            return [{"smiles": s, "status": "error", "results": {}, "error": str(e)}
                    for s in smiles_list]

    def process_multiple_properties(self, smiles: str, property_list: List[str]) -> Dict[str, Any]:
        """Process a single SMILES; thin wrapper over the batch method."""
        results = self.process_multiple_properties_batch([smiles], property_list)
        return results[0] if results else {}