Skip to content

Commit 795e67a

Browse files
committed
Merge branch 'dev' into fix/model-inference-dependencies
2 parents 925eea5 + 016b5ea commit 795e67a

File tree

6 files changed

+364
-53
lines changed

6 files changed

+364
-53
lines changed

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,3 +819,168 @@ p
819819
[16N]
820820
[17N]
821821
[14N]
822+
[Pb+2]
823+
[AlH4-]
824+
[BH4-]
825+
[Pt-2]
826+
[Cl+2]
827+
[I+3]
828+
[Br+2]
829+
[Cl+3]
830+
[Os-2]
831+
[Cr-2]
832+
[Hg-2]
833+
[PH]
834+
[Br+3]
835+
[I+2]
836+
[AsH2]
837+
[SH]
838+
[W-2]
839+
[Cd-2]
840+
[Ir-2]
841+
[Ru-2]
842+
[Rh-2]
843+
[Ag-2]
844+
[Be-2]
845+
[TeH2+]
846+
[13c]
847+
[13cH]
848+
[PH4]
849+
[AsH4]
850+
[As-2]
851+
[SbH3+]
852+
[SbH4]
853+
[BiH3]
854+
[BH3-]
855+
[GeH3]
856+
[GeH2]
857+
[SiH2-]
858+
[SiH2+]
859+
[SnH2]
860+
[SnH3]
861+
[SnH]
862+
[PbH]
863+
[PbH3]
864+
[Al-2]
865+
[B+2]
866+
[N+2]
867+
[SbH]
868+
[SbH2]
869+
[InH2]
870+
[GaH2]
871+
[TlH2]
872+
[Au+2]
873+
[sH+]
874+
[Hg+2]
875+
[Si-2]
876+
[Sn-2]
877+
[Pb-2]
878+
[AsH3]
879+
[Cr+2]
880+
[Ag+2]
881+
[V-2]
882+
[Ce-2]
883+
[13C@]
884+
[*+2]
885+
[He+2]
886+
[4He+2]
887+
[3He+2]
888+
[Eu+2]
889+
[Ge+2]
890+
[Os+2]
891+
[Y+2]
892+
[Gd+2]
893+
[La+2]
894+
[Se+2]
895+
[NH-2]
896+
[TeH2-]
897+
[AlH3-]
898+
[SbH3-]
899+
[AsH3-]
900+
[BiH3-]
901+
[PH3-]
902+
[CH2-2]
903+
[AsH4+]
904+
[AlH3+]
905+
[BiH3+]
906+
[FH+]
907+
[CH3+]
908+
[Te-2]
909+
[OH]
910+
[CH3]
911+
[18OH2]
912+
[OH3+]
913+
[OH4+2]
914+
[SH3]
915+
[SH3+]
916+
[SH3-]
917+
[SH4]
918+
[SeH2]
919+
[SeH-]
920+
[SeH3+]
921+
[SeH3-]
922+
[SeH3]
923+
[SeH+]
924+
[TeH2]
925+
[TeH-]
926+
[TeH3-]
927+
[TeH3+]
928+
[TeH+]
929+
[TeH3]
930+
[TeH4]
931+
[PoH2]
932+
[NH2]
933+
[NH+2]
934+
[PH5]
935+
[PH4+]
936+
[PH-2]
937+
[PH4-]
938+
[PH+2]
939+
[AsH2+]
940+
[AsH2-]
941+
[AsH+2]
942+
[AsH-2]
943+
[AsH5]
944+
[SbH3]
945+
[SbH4+]
946+
[SbH5]
947+
[BiH4+]
948+
[BiH5]
949+
[BiH4-]
950+
[BH2]
951+
[BH2+]
952+
[BH2-]
953+
[BH-2]
954+
[BH+2]
955+
[GeH4]
956+
[GeH3+]
957+
[GeH3-]
958+
[SiH3-]
959+
[SiH3+]
960+
[SiH+]
961+
[SiH4]
962+
[HeH+2]
963+
[HeH+]
964+
[AlH]
965+
[AlH+]
966+
[SnH4]
967+
[SnH3-]
968+
[SnH3+]
969+
[PbH4]
970+
[PbH3-]
971+
[PbH3+]
972+
[BeH4-2]
973+
[BeH]
974+
[BeH+]
975+
[BeH-]
976+
[BeH2]
977+
[AtH]
978+
[InH3]
979+
[GaH3]
980+
[TlH3]
981+
[IH3]
982+
[FeH6-4]
983+
[FH2+]
984+
[ClH2+]
985+
[BrH2+]
986+
[IH2+]

chebai/preprocessing/datasets/chebi.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,7 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
309309
data = pd.DataFrame(data)
310310
data = data[~data["SMILES"].isnull()]
311311
data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]]
312-
# This filters the DataFrame to include only the rows where at least one value in the row from 4th column
313-
# onwards is True/non-zero.
314-
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
312+
315313
return data
316314

317315
# ------------------------------ Phase: Setup data -----------------------------------
@@ -712,18 +710,24 @@ class ChEBIOverXPartial(ChEBIOverX):
712710
top_class_id (int): The ID of the top class from which to extract subclasses.
713711
"""
714712

715-
def __init__(self, top_class_id: int, **kwargs):
713+
def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs):
716714
"""
717715
Initializes the ChEBIOverXPartial dataset.
718716
719717
Args:
720718
top_class_id (int): The ID of the top class from which to extract subclasses.
721719
**kwargs: Additional keyword arguments passed to the superclass initializer.
720+
external_data_ratio (float): How much external data (i.e., samples where top_class_id
721+
is no positive label) to include in the dataset. 0 means no external data, 1 means
722+
the maximum amount (i.e., the complete ChEBI dataset).
722723
"""
723724
if "top_class_id" not in kwargs:
724725
kwargs["top_class_id"] = top_class_id
726+
if "external_data_ratio" not in kwargs:
727+
kwargs["external_data_ratio"] = external_data_ratio
725728

726729
self.top_class_id: int = top_class_id
730+
self.external_data_ratio: float = external_data_ratio
727731
super().__init__(**kwargs)
728732

729733
@property
@@ -737,7 +741,7 @@ def processed_dir_main(self) -> str:
737741
return os.path.join(
738742
self.base_dir,
739743
self._name,
740-
f"partial_{self.top_class_id}",
744+
f"partial_{self.top_class_id}_ext_ratio_{self.external_data_ratio:.2f}",
741745
"processed",
742746
)
743747

@@ -756,9 +760,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph":
756760
descendants of the top class ID.
757761
"""
758762
g = super()._extract_class_hierarchy(chebi_path)
759-
g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id])
763+
top_class_successors = list(g.successors(self.top_class_id)) + [
764+
self.top_class_id
765+
]
766+
external_nodes = list(set(n for n in g.nodes if n not in top_class_successors))
767+
if 0 < self.external_data_ratio < 1:
768+
n_external_nodes = int(
769+
len(top_class_successors)
770+
* self.external_data_ratio
771+
/ (1 - self.external_data_ratio)
772+
)
773+
print(
774+
f"Extracting {n_external_nodes} external nodes from the ChEBI dataset (ratio: {self.external_data_ratio:.2f})"
775+
)
776+
external_nodes = external_nodes[: int(n_external_nodes)]
777+
elif self.external_data_ratio == 0:
778+
external_nodes = []
779+
780+
g = g.subgraph(top_class_successors + external_nodes)
781+
print(
782+
f"Subgraph contains {len(g.nodes)} nodes, of which {len(top_class_successors)} are subclasses of the top class ID {self.top_class_id}."
783+
)
760784
return g
761785

786+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
787+
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
788+
smiles = nx.get_node_attributes(g, "smiles")
789+
nodes = list(
790+
sorted(
791+
{
792+
node
793+
for node in g.nodes
794+
if sum(
795+
1 if smiles[s] is not None else 0 for s in g.successors(node)
796+
)
797+
>= self.THRESHOLD
798+
and (
799+
self.top_class_id in g.predecessors(node)
800+
or node == self.top_class_id
801+
)
802+
}
803+
)
804+
)
805+
filename = "classes.txt"
806+
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
807+
fout.writelines(str(node) + "\n" for node in nodes)
808+
return nodes
809+
762810

763811
class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50):
764812
"""
@@ -854,7 +902,7 @@ def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]:
854902

855903

856904
atom_index = (
857-
"\*",
905+
r"\*",
858906
"H",
859907
"He",
860908
"Li",
@@ -1485,3 +1533,15 @@ def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]:
14851533
]
14861534

14871535
JCI_500_COLUMNS_INT = [int(n.split(":")[-1]) for n in JCI_500_COLUMNS]
1536+
1537+
if __name__ == "__main__":
1538+
data_module_05 = ChEBIOver50Partial(
1539+
chebi_version=241,
1540+
splits_file_path=os.path.join(
1541+
"data", "chebi_v241", "ChEBI50", "splits_80_10_10.csv"
1542+
),
1543+
top_class_id=22712,
1544+
external_data_ratio=0.5,
1545+
)
1546+
data_module_05.prepare_data()
1547+
data_module_05.setup()

chebai/preprocessing/datasets/pubchem.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,13 @@ def setup_processed(self):
154154
print("Load data from file", filename)
155155
data = self._load_data_from_file(filename)
156156
print("Create splits")
157-
train, test = train_test_split(data, train_size=self.train_split)
157+
train, test = train_test_split(
158+
data, train_size=1 - (self.validation_split + self.test_split)
159+
)
158160
del data
159-
test, val = train_test_split(test, train_size=self.train_split)
161+
test, val = train_test_split(
162+
test, train_size=self.test_split / (self.validation_split + self.test_split)
163+
)
160164
torch.save(train, os.path.join(self.processed_dir, "train.pt"))
161165
torch.save(test, os.path.join(self.processed_dir, "test.pt"))
162166
torch.save(val, os.path.join(self.processed_dir, "validation.pt"))
@@ -179,6 +183,21 @@ def processed_file_names(self) -> List[str]:
179183
"""
180184
return ["test.pt", "train.pt", "validation.pt"]
181185

186+
def _set_processed_data_props(self):
187+
"""
188+
Self-supervised learning with PubChem does not use this metadata, therefore set them to zero.
189+
190+
Sets:
191+
- self._num_of_labels: 0
192+
- self._feature_vector_size: 0.
193+
"""
194+
195+
self._num_of_labels = 0
196+
self._feature_vector_size = 0
197+
198+
print(f"Number of labels for loaded data: {self._num_of_labels}")
199+
print(f"Feature vector size: {self._feature_vector_size}")
200+
182201
def _perform_data_preparation(self, *args, **kwargs):
183202
"""
184203
Checks for raw data and downloads if necessary.

chebai/preprocessing/reader.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Dict, List, Optional
77

88
from pysmiles.read_smiles import _tokenize
9+
from rdkit import Chem
910

1011
from chebai.preprocessing.collate import DefaultCollator, RaggedCollator
1112

@@ -173,21 +174,35 @@ class ChemDataReader(TokenIndexerReader):
173174

174175
COLLATOR = RaggedCollator
175176

177+
def __init__(self, canonicalize_smiles=True, *args, **kwargs) -> None:
178+
super().__init__(*args, **kwargs)
179+
self.canonicalize_smiles = canonicalize_smiles
180+
print(f"Using SMILES canonicalization: {self.canonicalize_smiles}")
181+
176182
@classmethod
177183
def name(cls) -> str:
178184
"""Returns the name of the data reader."""
179185
return "smiles_token"
180186

181187
def _read_data(self, raw_data: str) -> List[int]:
182188
"""
183-
Reads and tokenizes raw SMILES data into a list of token indices.
189+
Reads and tokenizes raw SMILES data into a list of token indices. Canonicalizes the SMILES string using RDKit.
184190
185191
Args:
186192
raw_data (str): The raw SMILES string to be tokenized.
187193
188194
Returns:
189195
List[int]: A list of integers representing the indices of the SMILES tokens.
190196
"""
197+
if self.canonicalize_smiles:
198+
try:
199+
mol = Chem.MolFromSmiles(raw_data.strip())
200+
if mol is not None:
201+
raw_data = Chem.MolToSmiles(mol, canonical=True)
202+
except Exception as e:
203+
print(f"RDKit failed to process {raw_data}")
204+
print(f"\t{e}")
205+
191206
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
192207

193208

0 commit comments

Comments
 (0)