
Commit c47019a

Merge pull request #8 from ChEB-AI/feature/inconsistency-resolution

Move inconsistency resolution and chebi graph building to chebifier

2 parents 47126b8 + e789470, commit c47019a

File tree: 7 files changed, +279 -41 lines

README.md
Lines changed: 3 additions & 3 deletions

@@ -6,6 +6,9 @@ A web application for the ensemble is available at https://chebifier.hastingslab
 
 ## Installation
 
+Note: `chebai-graph` and its dependencies cannot be installed automatically. To install it, follow
+the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph). Other dependencies are installed automatically.
+
 You can get the package from PyPI:
 ```bash
 pip install chebifier
@@ -21,9 +24,6 @@ cd python-chebifier
 pip install -e .
 ```
 
-`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow
-the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph).
-
 ## Usage
 
 ### Command Line Interface
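
The moved note does not change the commands themselves, only where the caveat sits in the README; a short recap of the install flow it describes:

```bash
# Base install from PyPI; all dependencies except chebai-graph resolve automatically.
pip install chebifier

# chebai-graph (needed for Graph Neural Network models) is not pulled in
# automatically; follow the instructions in
# https://github.com/ChEB-AI/python-chebai-graph to install it.
```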

chebifier/ensemble/base_ensemble.py
Lines changed: 10 additions & 33 deletions

@@ -3,8 +3,8 @@
 
 import torch
 import tqdm
-from chebai.preprocessing.datasets.chebi import ChEBIOver50
-from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
+from chebifier.inconsistency_resolution import PredictionSmoother
+from chebifier.utils import load_chebi_graph, get_disjoint_files
 
 from chebifier.check_env import check_package_installed
 from chebifier.prediction_models.base_predictor import BasePredictor
@@ -21,32 +21,8 @@ def __init__(
         # Deferred Import: To avoid circular import error
         from chebifier.model_registry import MODEL_TYPES
 
-        self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
-        self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
-        local_disjoint_files = [
-            os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv"),
-        ]
-        self.disjoint_files = []
-        for file in local_disjoint_files:
-            if os.path.isfile(file):
-                self.disjoint_files.append(file)
-            else:
-                print(
-                    f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
-                )
-                from chebifier.hugging_face import download_model_files
-
-                self.disjoint_files.append(
-                    download_model_files(
-                        {
-                            "repo_id": "chebai/chebifier",
-                            "repo_type": "dataset",
-                            "files": {"disjoint_file": os.path.basename(file)},
-                        }
-                    )["disjoint_file"]
-                )
+        self.chebi_graph = load_chebi_graph()
+        self.disjoint_files = get_disjoint_files()
 
         self.models = []
         self.positive_prediction_threshold = 0.5
@@ -72,7 +48,7 @@ def __init__(
 
         if resolve_inconsistencies:
             self.smoother = PredictionSmoother(
-                self.chebi_dataset,
+                self.chebi_graph,
                 label_names=None,
                 disjoint_files=self.disjoint_files,
             )
@@ -203,10 +179,11 @@ def predict_smiles_list(
                 "Warning: No classes have been predicted for the given SMILES list."
             )
             # save predictions
-            torch.save(ordered_predictions, preds_file)
-            with open(predicted_classes_file, "w") as f:
-                for cls in predicted_classes:
-                    f.write(f"{cls}\n")
+            if load_preds_if_possible:
+                torch.save(ordered_predictions, preds_file)
+                with open(predicted_classes_file, "w") as f:
+                    for cls in predicted_classes:
+                        f.write(f"{cls}\n")
             predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
         else:
            print(
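
Taken together, these hunks replace the chebai-backed setup with two helper calls; a minimal sketch of the new initialization path, using only names introduced in this commit:

```python
# Sketch of the simplified setup after this commit (all names are from the diff).
from chebifier.inconsistency_resolution import PredictionSmoother
from chebifier.utils import get_disjoint_files, load_chebi_graph

chebi_graph = load_chebi_graph()       # pickled networkx DiGraph, fetched from Hugging Face
disjoint_files = get_disjoint_files()  # local CSVs if present, else downloaded

smoother = PredictionSmoother(
    chebi_graph,
    label_names=None,  # as in the diff; set_label_names(...) can populate this later
    disjoint_files=disjoint_files,
)
```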

chebifier/inconsistency_resolution.py
Lines changed: 124 additions & 0 deletions (new file)

import csv
import os
import torch
from pathlib import Path


def get_disjoint_groups(disjoint_files):
    if disjoint_files is None:
        disjoint_files = os.path.join("data", "chebi-disjoints.owl")
    disjoint_pairs, disjoint_groups = [], []
    for file in disjoint_files:
        if isinstance(file, Path):
            file = str(file)
        if file.endswith(".csv"):
            with open(file, "r") as f:
                reader = csv.reader(f)
                disjoint_pairs += [line for line in reader]
        elif file.endswith(".owl"):
            with open(file, "r") as f:
                plaintext = f.read()
                segments = plaintext.split("<")
                disjoint_pairs = []
                left = None
                for seg in segments:
                    if seg.startswith("rdf:Description ") or seg.startswith(
                        "owl:Class"
                    ):
                        left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
                    elif seg.startswith("owl:disjointWith"):
                        right = int(
                            seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
                        )
                        disjoint_pairs.append([left, right])

                disjoint_groups = []
                for seg in plaintext.split("<rdf:Description>"):
                    if "owl;AllDisjointClasses" in seg:
                        classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
                        classes = [int(c.split('"')[0]) for c in classes]
                        disjoint_groups.append(classes)
        else:
            raise NotImplementedError(
                "Unsupported disjoint file format: " + file.split(".")[-1]
            )

    disjoint_all = disjoint_pairs + disjoint_groups
    # one disjointness is commented out in the owl-file
    # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
    if [22729, 51880] in disjoint_all:
        disjoint_all.remove([22729, 51880])
    # print(f"Found {len(disjoint_all)} disjoint groups")
    return disjoint_all


class PredictionSmoother:
    """Removes implication and disjointness violations from predictions"""

    def __init__(self, chebi_graph, label_names=None, disjoint_files=None):
        self.chebi_graph = chebi_graph
        self.set_label_names(label_names)
        self.disjoint_groups = get_disjoint_groups(disjoint_files)

    def set_label_names(self, label_names):
        if label_names is not None:
            self.label_names = label_names
            chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
            self.label_successors = torch.zeros(
                (len(self.label_names), len(self.label_names)), dtype=torch.bool
            )
            for i, label in enumerate(self.label_names):
                self.label_successors[i, i] = 1
                for p in chebi_subgraph.successors(label):
                    if p in self.label_names:
                        self.label_successors[i, self.label_names.index(p)] = 1
            self.label_successors = self.label_successors.unsqueeze(0)

    def __call__(self, preds):
        if preds.shape[1] == 0:
            # no labels predicted
            return preds
        # preds shape: (n_samples, n_labels)
        preds_sum_orig = torch.sum(preds)
        # step 1: apply implications: for each class, set prediction to max of itself and all successors
        preds = preds.unsqueeze(1)
        preds_masked_succ = torch.where(self.label_successors, preds, 0)
        # preds_masked_succ shape: (n_samples, n_labels, n_labels)

        preds = preds_masked_succ.max(dim=2).values
        if torch.sum(preds) != preds_sum_orig:
            print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
        preds_sum_orig = torch.sum(preds)
        # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
        preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
        for disj_group in self.disjoint_groups:
            disj_group = [
                self.label_names.index(g) for g in disj_group if g in self.label_names
            ]
            if len(disj_group) > 1:
                old_preds = preds[:, disj_group]
                disj_max = torch.max(preds[:, disj_group], dim=1)
                for i, row in enumerate(preds):
                    for l_ in range(len(preds[i])):
                        if l_ in disj_group and l_ != disj_group[disj_max.indices[i]]:
                            preds[i, l_] = preds_bounded[i, l_]
                samples_changed = 0
                for i, row in enumerate(preds[:, disj_group]):
                    if any(r != o for r, o in zip(row, old_preds[i])):
                        samples_changed += 1
                if samples_changed != 0:
                    print(
                        f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
                    )
        if torch.sum(preds) != preds_sum_orig:
            print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
        preds_sum_orig = torch.sum(preds)
        # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
        preds = preds.unsqueeze(1)
        preds_masked_predec = torch.where(
            torch.transpose(self.label_successors, 1, 2), preds, 1
        )
        preds = preds_masked_predec.min(dim=2).values
        if torch.sum(preds) != preds_sum_orig:
            print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
        return preds
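
To see the three steps in action, here is a toy run of the smoother above, under assumed inputs: a hypothetical two-class hierarchy in which "A" is the parent of "B" (edges point parent -> child, matching the graph built in chebifier.utils) and no disjointness axioms:

```python
import networkx as nx
import torch

from chebifier.inconsistency_resolution import PredictionSmoother

# Hypothetical two-node hierarchy: "A" is the parent of "B".
graph = nx.DiGraph([("A", "B")])
smoother = PredictionSmoother(graph, label_names=["A", "B"], disjoint_files=[])

# Predicting B (0.9) without A (0.2) violates the implication "B is-a A";
# step 1 lifts A to max(A, B) = 0.9, and steps 2 and 3 change nothing here.
preds = torch.tensor([[0.2, 0.9]])
print(smoother(preds))  # tensor([[0.9000, 0.9000]])
```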

chebifier/prediction_models/c3p_predictor.py
Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
                     chebi_id
                 ] = result.is_match
                 if result.is_match and self.chebi_graph is not None:
-                    for parent in list(self.chebi_graph.predecessors(int(chebi_id))):
+                    for parent in list(self.chebi_graph.predecessors(chebi_id)):
                         result_reformatted[smiles_list.index(result.input_smiles)][
                             str(parent)
                         ] = 1

chebifier/prediction_models/chemlog_predictor.py
Lines changed: 2 additions & 2 deletions

@@ -63,7 +63,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
             sample_additions = dict()
             for cls in sample:
                 if sample[cls] == 1:
-                    successors = list(self.chebi_graph.predecessors(int(cls)))
+                    successors = list(self.chebi_graph.predecessors(cls))
                     if successors:
                         for succ in successors:
                             sample_additions[str(succ)] = 1
@@ -114,7 +114,7 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
        indirect_pos_labels = [
            str(pr)
            for label in pos_labels
-            for pr in self.chebi_graph.predecessors(int(label))
+            for pr in self.chebi_graph.predecessors(label)
        ]
        pos_labels = list(set(pos_labels + indirect_pos_labels))
        return {
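
Both predictor changes above are the same fix: the graph now built in chebifier.utils keeps ChEBI IDs as strings, so the int(...) casts had to go. A minimal illustration with hypothetical IDs:

```python
import networkx as nx

# String node IDs, edge pointing parent -> child as in the ChEBI graph.
graph = nx.DiGraph([("24431", "33285")])

print(list(graph.predecessors("33285")))  # ['24431'] -- found with a string key
# graph.predecessors(33285) would raise NetworkXError: the int is a different (absent) node
```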

chebifier/utils.py
Lines changed: 131 additions & 0 deletions (new file)

import os

import networkx as nx
import requests
import fastobo
from chebifier.hugging_face import download_model_files
import pickle


def load_chebi_graph(filename=None):
    """Load ChEBI graph from Hugging Face (if filename is None) or local file"""
    if filename is None:
        print("Loading ChEBI graph from Hugging Face...")
        file = download_model_files(
            {
                "repo_id": "chebai/chebifier",
                "repo_type": "dataset",
                "files": {"f": "chebi_graph.pkl"},
            }
        )["f"]
    else:
        print(f"Loading ChEBI graph from local {filename}...")
        file = filename
    return pickle.load(open(file, "rb"))


def term_callback(doc):
    """Similar to the chebai function, but reduced to the necessary fields. Also, ChEBI IDs are strings"""
    parents = []
    name = None
    smiles = None
    for clause in doc:
        if isinstance(clause, fastobo.term.PropertyValueClause):
            t = clause.property_value
            if str(t.relation) == "http://purl.obolibrary.org/obo/chebi/smiles":
                assert smiles is None
                smiles = t.value
        # in older chebi versions, smiles strings are synonyms
        # e.g. synonym: "[F-].[Na+]" RELATED SMILES [ChEBI]
        elif isinstance(clause, fastobo.term.SynonymClause):
            if "SMILES" in clause.raw_value():
                assert smiles is None
                smiles = clause.raw_value().split('"')[1]
        elif isinstance(clause, fastobo.term.IsAClause):
            chebi_id = str(clause.term)
            chebi_id = chebi_id[chebi_id.index(":") + 1 :]
            parents.append(chebi_id)
        elif isinstance(clause, fastobo.term.NameClause):
            name = str(clause.name)

        if isinstance(clause, fastobo.term.IsObsoleteClause):
            if clause.obsolete:
                # if the term document contains clause as obsolete as true, skips this document.
                return False
    chebi_id = str(doc.id)
    chebi_id = chebi_id[chebi_id.index(":") + 1 :]
    return {
        "id": chebi_id,
        "parents": parents,
        "name": name,
        "smiles": smiles,
    }


def build_chebi_graph(chebi_version=241):
    """Creates a networkx graph for the ChEBI hierarchy. Usually, you don't want to call this function directly, but rather use the `load_chebi_graph` function."""
    chebi_path = os.path.join("data", f"chebi_v{chebi_version}", "chebi.obo")
    os.makedirs(os.path.join("data", f"chebi_v{chebi_version}"), exist_ok=True)
    if not os.path.exists(chebi_path):
        url = f"http://purl.obolibrary.org/obo/chebi/{chebi_version}/chebi.obo"
        r = requests.get(url, allow_redirects=True)
        open(chebi_path, "wb").write(r.content)
    with open(chebi_path, encoding="utf-8") as chebi:
        chebi = "\n".join(line for line in chebi if not line.startswith("xref:"))

    elements = []
    for term_doc in fastobo.loads(chebi):
        if (
            term_doc
            and isinstance(term_doc.id, fastobo.id.PrefixedIdent)
            and term_doc.id.prefix == "CHEBI"
        ):
            term_dict = term_callback(term_doc)
            if term_dict:
                elements.append(term_dict)

    g = nx.DiGraph()
    for n in elements:
        g.add_node(n["id"], **n)

    # Only take the edges which connect the existing nodes, to avoid internal creation of obsolete nodes
    # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142
    g.add_edges_from(
        [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)]
    )
    return nx.transitive_closure_dag(g)


def get_disjoint_files():
    """Gets local disjointness files if they are present in the right location, otherwise downloads them from Hugging Face."""
    local_disjoint_files = [
        os.path.join("data", "disjoint_chebi.csv"),
        os.path.join("data", "disjoint_additional.csv"),
    ]
    disjoint_files = []
    for file in local_disjoint_files:
        if os.path.isfile(file):
            disjoint_files.append(file)
        else:
            print(
                f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
            )

            disjoint_files.append(
                download_model_files(
                    {
                        "repo_id": "chebai/chebifier",
                        "repo_type": "dataset",
                        "files": {"disjoint_file": os.path.basename(file)},
                    }
                )["disjoint_file"]
            )
    return disjoint_files


if __name__ == "__main__":
    # chebi_graph = build_chebi_graph(chebi_version=241)
    # save the graph to a file
    # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
    chebi_graph = load_chebi_graph()
    print(chebi_graph)
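
The commented-out lines in the __main__ block sketch the intended build-and-cache cycle. Spelled out (assuming network access for the chebi.obo download and a writable working directory):

```python
import pickle

from chebifier.utils import build_chebi_graph, load_chebi_graph

# One-off: download and parse chebi.obo, build the transitively closed hierarchy...
chebi_graph = build_chebi_graph(chebi_version=241)
# ...cache it locally...
pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
# ...and on later runs load the cached pickle instead of rebuilding.
chebi_graph = load_chebi_graph(filename="chebi_graph.pkl")
```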
