15
15
from abc import ABC
16
16
from collections import OrderedDict
17
17
from itertools import cycle , permutations , product
18
- from typing import Any , Generator , Optional , Union
18
+ from typing import TYPE_CHECKING , Any , Dict , Generator , List , Optional , Union
19
19
20
- import fastobo
21
- import networkx as nx
22
20
import pandas as pd
23
- import requests
24
21
import torch
25
22
from rdkit import Chem
26
23
27
24
from chebai .preprocessing import reader as dr
28
25
from chebai .preprocessing .datasets .base import XYBaseDataModule , _DynamicDataset
29
26
27
+ if TYPE_CHECKING :
28
+ import fastobo
29
+ import networkx as nx
30
+
30
31
# exclude some entities from the dataset because they violate disjointness axioms
31
32
CHEBI_BLACKLIST = [
32
33
194026 ,
@@ -236,6 +237,8 @@ def _load_chebi(self, version: int) -> str:
236
237
Returns:
237
238
str: The file path of the loaded ChEBI ontology.
238
239
"""
240
+ import requests
241
+
239
242
chebi_name = self .raw_file_names_dict ["chebi" ]
240
243
chebi_path = os .path .join (self .raw_dir , chebi_name )
241
244
if not os .path .isfile (chebi_path ):
@@ -247,7 +250,7 @@ def _load_chebi(self, version: int) -> str:
247
250
open (chebi_path , "wb" ).write (r .content )
248
251
return chebi_path
249
252
250
- def _extract_class_hierarchy (self , data_path : str ) -> nx .DiGraph :
253
+ def _extract_class_hierarchy (self , data_path : str ) -> " nx.DiGraph" :
251
254
"""
252
255
Extracts the class hierarchy from the ChEBI ontology.
253
256
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -259,6 +262,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
259
262
Returns:
260
263
nx.DiGraph: The class hierarchy.
261
264
"""
265
+ import fastobo
266
+ import networkx as nx
267
+
262
268
with open (data_path , encoding = "utf-8" ) as chebi :
263
269
chebi = "\n " .join (line for line in chebi if not line .startswith ("xref:" ))
264
270
@@ -286,7 +292,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
286
292
print ("Compute transitive closure" )
287
293
return nx .transitive_closure_dag (g )
288
294
289
- def _graph_to_raw_dataset (self , g : nx .DiGraph ) -> pd .DataFrame :
295
+ def _graph_to_raw_dataset (self , g : " nx.DiGraph" ) -> pd .DataFrame :
290
296
"""
291
297
Converts the graph to a raw dataset.
292
298
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -298,6 +304,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
298
304
Returns:
299
305
pd.DataFrame: The raw dataset created from the graph.
300
306
"""
307
+ import networkx as nx
308
+
301
309
smiles = nx .get_node_attributes (g , "smiles" )
302
310
names = nx .get_node_attributes (g , "name" )
303
311
@@ -696,7 +704,7 @@ def _name(self) -> str:
696
704
"""
697
705
return f"ChEBI{ self .THRESHOLD } "
698
706
699
- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> list :
707
+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
700
708
"""
701
709
Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
702
710
@@ -721,6 +729,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
721
729
- The `THRESHOLD` attribute should be defined in the subclass of this class.
722
730
- Nodes without a 'smiles' attribute are ignored in the successor count.
723
731
"""
732
+ import networkx as nx
733
+
724
734
smiles = nx .get_node_attributes (g , "smiles" )
725
735
nodes = list (
726
736
sorted (
@@ -859,7 +869,7 @@ def processed_dir_main(self) -> str:
859
869
"processed" ,
860
870
)
861
871
862
- def _extract_class_hierarchy (self , chebi_path : str ) -> nx .DiGraph :
872
+ def _extract_class_hierarchy (self , chebi_path : str ) -> " nx.DiGraph" :
863
873
"""
864
874
Extracts a subset of ChEBI based on subclasses of the top class ID.
865
875
@@ -897,8 +907,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
897
907
)
898
908
return g
899
909
900
- def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> list :
910
+ def select_classes (self , g : " nx.DiGraph" , * args , ** kwargs ) -> List :
901
911
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
912
+ import networkx as nx
913
+
902
914
smiles = nx .get_node_attributes (g , "smiles" )
903
915
nodes = list (
904
916
sorted (
@@ -958,7 +970,7 @@ def chebi_to_int(s: str) -> int:
958
970
return int (s [s .index (":" ) + 1 :])
959
971
960
972
961
- def term_callback (doc : fastobo .term .TermFrame ) -> Union [dict , bool ]:
973
+ def term_callback (doc : " fastobo.term.TermFrame" ) -> Union [Dict , bool ]:
962
974
"""
963
975
Extracts information from a ChEBI term document.
964
976
This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
@@ -975,6 +987,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
975
987
- "name": The name of the ChEBI term.
976
988
- "smiles": The SMILES string associated with the ChEBI term, if available.
977
989
"""
990
+ import fastobo
991
+
978
992
parts = set ()
979
993
parents = []
980
994
name = None
0 commit comments