Skip to content

Commit 0c9e492

Browse files
authored
Merge pull request #15 from ChEB-AI/feature/result-cache
Global Cache per smiles per model
2 parents b355319 + 606ebda commit 0c9e492

File tree

8 files changed

+423
-31
lines changed

8 files changed

+423
-31
lines changed

chebifier/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Note: The top-level package __init__.py runs only once,
2+
# even if multiple subpackages are imported later.
3+
4+
from ._custom_cache import PerSmilesPerModelLRUCache
5+
6+
# Single module-level cache instance, created when this package is first
# imported. It is configured to hold at most 100 entries; modules that import
# this name all share the same object.
modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)

chebifier/_custom_cache.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import os
2+
import pickle
3+
import threading
4+
from collections import OrderedDict
5+
from collections.abc import Iterable
6+
from functools import wraps
7+
from typing import Any, Callable
8+
9+
10+
class PerSmilesPerModelLRUCache:
11+
"""
12+
A thread-safe, optionally persistent LRU cache for storing
13+
(SMILES, model_name) → result mappings.
14+
"""
15+
16+
def __init__(self, max_size: int = 100, persist_path: str | None = None):
17+
"""
18+
Initialize the cache.
19+
20+
Args:
21+
max_size (int): Maximum number of items to keep in the cache.
22+
persist_path (str | None): Optional path to persist cache using pickle.
23+
"""
24+
self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict()
25+
self._max_size = max_size
26+
self._lock = threading.Lock()
27+
self._persist_path = persist_path
28+
29+
self.hits = 0
30+
self.misses = 0
31+
32+
if self._persist_path:
33+
self._load_cache()
34+
35+
def get(self, smiles: str, model_name: str) -> Any | None:
36+
"""
37+
Retrieve value from cache if present, otherwise return None.
38+
39+
Args:
40+
smiles (str): SMILES string key.
41+
model_name (str): Model identifier.
42+
43+
Returns:
44+
Any | None: Cached value or None.
45+
"""
46+
key = (smiles, model_name)
47+
with self._lock:
48+
if key in self._cache:
49+
self._cache.move_to_end(key)
50+
self.hits += 1
51+
return self._cache[key]
52+
else:
53+
self.misses += 1
54+
return None
55+
56+
def set(self, smiles: str, model_name: str, value: Any) -> None:
57+
"""
58+
Store value in cache under (smiles, model_name) key.
59+
60+
Args:
61+
smiles (str): SMILES string key.
62+
model_name (str): Model identifier.
63+
value (Any): Value to cache.
64+
"""
65+
assert value is not None, "Value must not be None"
66+
key = (smiles, model_name)
67+
with self._lock:
68+
if key in self._cache:
69+
self._cache.move_to_end(key)
70+
self._cache[key] = value
71+
if len(self._cache) > self._max_size:
72+
self._cache.popitem(last=False)
73+
74+
def clear(self) -> None:
75+
"""
76+
Clear the cache and remove the persistence file if present.
77+
"""
78+
self._save_cache()
79+
with self._lock:
80+
self._cache.clear()
81+
self.hits = 0
82+
self.misses = 0
83+
if self._persist_path and os.path.exists(self._persist_path):
84+
os.remove(self._persist_path)
85+
86+
def stats(self) -> dict[str, int]:
87+
"""
88+
Return cache hit/miss statistics.
89+
90+
Returns:
91+
dict[str, int]: Dictionary with 'hits' and 'misses' keys.
92+
"""
93+
return {"hits": self.hits, "misses": self.misses}
94+
95+
def batch_decorator(self, func: Callable) -> Callable:
96+
"""
97+
Decorator for class methods that accept a batch of SMILES as a list,
98+
and cache predictions per (smiles, model_name) key.
99+
100+
The instance is expected to have a `model_name` attribute.
101+
102+
Args:
103+
func (Callable): The method to decorate.
104+
105+
Returns:
106+
Callable: The wrapped method.
107+
"""
108+
109+
@wraps(func)
110+
def wrapper(instance, smiles_list: list[str]) -> list[Any]:
111+
assert isinstance(smiles_list, list), "smiles_list must be a list."
112+
model_name = getattr(instance, "model_name", None)
113+
assert model_name is not None, "Instance must have a model_name attribute."
114+
115+
missing_smiles: list[str] = []
116+
missing_indices: list[int] = []
117+
ordered_results: list[Any] = [None] * len(smiles_list)
118+
119+
# First: try to fetch all from cache
120+
for idx, smiles in enumerate(smiles_list):
121+
prediction = self.get(smiles=smiles, model_name=model_name)
122+
if prediction is not None:
123+
# For debugging purposes, you can uncomment the print statement below
124+
# print(
125+
# f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache."
126+
# )
127+
ordered_results[idx] = prediction
128+
else:
129+
missing_smiles.append(smiles)
130+
missing_indices.append(idx)
131+
132+
# If some are missing, call original function
133+
if missing_smiles:
134+
new_results = func(instance, tuple(missing_smiles))
135+
assert isinstance(
136+
new_results, Iterable
137+
), "Function must return an Iterable."
138+
139+
# Save to cache and append
140+
for smiles, prediction, missing_idx in zip(
141+
missing_smiles, new_results, missing_indices
142+
):
143+
if prediction is not None:
144+
self.set(smiles, model_name, prediction)
145+
ordered_results[missing_idx] = prediction
146+
147+
return ordered_results
148+
149+
return wrapper
150+
151+
def __len__(self) -> int:
152+
"""
153+
Return number of items in the cache.
154+
155+
Returns:
156+
int: Number of entries in the cache.
157+
"""
158+
with self._lock:
159+
return len(self._cache)
160+
161+
def __repr__(self) -> str:
162+
"""
163+
String representation of the underlying cache.
164+
165+
Returns:
166+
str: String version of the OrderedDict.
167+
"""
168+
return self._cache.__repr__()
169+
170+
def save(self) -> None:
171+
"""
172+
Save the cache to disk, if persistence is enabled.
173+
"""
174+
self._save_cache()
175+
176+
def load(self) -> None:
177+
"""
178+
Load the cache from disk, if persistence is enabled.
179+
"""
180+
self._load_cache()
181+
182+
def _save_cache(self) -> None:
183+
"""
184+
Serialize the cache to disk using pickle.
185+
"""
186+
if self._persist_path:
187+
try:
188+
with open(self._persist_path, "wb") as f:
189+
pickle.dump(self._cache, f)
190+
except Exception as e:
191+
print(f"[Cache Save Error] {e}")
192+
193+
def _load_cache(self) -> None:
194+
"""
195+
Load the cache from disk, if the file exists and is non-empty.
196+
"""
197+
if (
198+
self._persist_path
199+
and os.path.exists(self._persist_path)
200+
and os.path.getsize(self._persist_path) > 0
201+
):
202+
try:
203+
with open(self._persist_path, "rb") as f:
204+
loaded = pickle.load(f)
205+
if isinstance(loaded, OrderedDict):
206+
self._cache = loaded
207+
except Exception as e:
208+
print(f"[Cache Load Error] {e}")

chebifier/prediction_models/base_predictor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from abc import ABC
33

4-
from functools import lru_cache
4+
from chebifier import modelwise_smiles_lru_cache
55

66

77
class BasePredictor(ABC):
@@ -23,17 +23,13 @@ def __init__(
2323

2424
self._description = kwargs.get("description", None)
2525

26+
@modelwise_smiles_lru_cache.batch_decorator
2627
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
27-
# list is not hashable, so we convert it to a tuple (useful for caching)
28-
return self.predict_smiles_tuple(tuple(smiles_list))
29-
30-
@lru_cache(maxsize=100)
31-
def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict:
3228
raise NotImplementedError()
3329

3430
def predict_smiles(self, smiles: str) -> dict:
3531
# by default, use list-based prediction
36-
return self.predict_smiles_tuple((smiles,))[0]
32+
return self.predict_smiles_list([smiles])[0]
3733

3834
@property
3935
def info_text(self):

chebifier/prediction_models/c3p_predictor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from functools import lru_cache
2-
from typing import Optional, List
31
from pathlib import Path
2+
from typing import List, Optional
43

54
from c3p import classifier as c3p_classifier
65

6+
from chebifier import modelwise_smiles_lru_cache
77
from chebifier.prediction_models import BasePredictor
88

99

@@ -24,8 +24,8 @@ def __init__(
2424
self.chemical_classes = chemical_classes
2525
self.chebi_graph = kwargs.get("chebi_graph", None)
2626

27-
@lru_cache(maxsize=100)
28-
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
27+
@modelwise_smiles_lru_cache.batch_decorator
28+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
2929
result_list = c3p_classifier.classify(
3030
list(smiles_list),
3131
self.program_directory,

chebifier/prediction_models/chebi_lookup.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
from functools import lru_cache
1+
import json
2+
import os
23
from typing import Optional
34

4-
from chebifier.prediction_models import BasePredictor
5-
import os
65
import networkx as nx
76
from rdkit import Chem
8-
import json
7+
8+
from chebifier import modelwise_smiles_lru_cache
9+
from chebifier.prediction_models import BasePredictor
910
from chebifier.utils import load_chebi_graph
1011

1112

1213
class ChEBILookupPredictor(BasePredictor):
13-
1414
def __init__(
1515
self,
1616
model_name: str,
@@ -67,7 +67,6 @@ def build_smiles_lookup(self):
6767
)
6868
return smiles_lookup
6969

70-
@lru_cache(maxsize=100)
7170
def predict_smiles(self, smiles: str) -> Optional[dict]:
7271
if not smiles:
7372
return None
@@ -94,7 +93,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
9493
else:
9594
return None
9695

97-
def predict_smiles_tuple(self, smiles_list: list[str]) -> list:
96+
@modelwise_smiles_lru_cache.batch_decorator
97+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
9898
predictions = []
9999
for smiles in smiles_list:
100100
predictions.append(self.predict_smiles(smiles))
@@ -145,7 +145,8 @@ def explain_smiles(self, smiles: str) -> dict:
145145
# Example usage
146146
smiles_list = [
147147
"CCO",
148-
"C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
148+
"C1=CC=CC=C1",
149+
"*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
149150
] # SMILES with 251 matches in ChEBI
150151
predictions = predictor.predict_smiles_list(smiles_list)
151152
print(predictions)

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
)
1313
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
1414
from chemlog_extra.alg_classification.by_element_classification import (
15-
XMolecularEntityClassifier,
1615
OrganoXCompoundClassifier,
16+
XMolecularEntityClassifier,
1717
)
18-
from functools import lru_cache
18+
19+
from chebifier import modelwise_smiles_lru_cache
1920

2021
from .base_predictor import BasePredictor
2122

@@ -47,15 +48,15 @@
4748

4849

4950
class ChemlogExtraPredictor(BasePredictor):
50-
5151
CHEMLOG_CLASSIFIER = None
5252

5353
def __init__(self, model_name: str, **kwargs):
5454
super().__init__(model_name, **kwargs)
5555
self.chebi_graph = kwargs.get("chebi_graph", None)
5656
self.classifier = self.CHEMLOG_CLASSIFIER()
5757

58-
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
58+
@modelwise_smiles_lru_cache.batch_decorator
59+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
5960
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
6061
res = self.classifier.classify(mol_list)
6162
if self.chebi_graph is not None:
@@ -72,12 +73,10 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
7273

7374

7475
class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
75-
7676
CHEMLOG_CLASSIFIER = XMolecularEntityClassifier
7777

7878

7979
class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):
80-
8180
CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier
8281

8382

@@ -97,7 +96,6 @@ def __init__(self, model_name: str, **kwargs):
9796
# fmt: on
9897
print(f"Initialised ChemLog model {self.model_name}")
9998

100-
@lru_cache(maxsize=100)
10199
def predict_smiles(self, smiles: str) -> Optional[dict]:
102100
mol = _smiles_to_mol(smiles)
103101
if mol is None:
@@ -122,7 +120,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
122120
for label in self.peptide_labels + pos_labels
123121
}
124122

125-
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
123+
@modelwise_smiles_lru_cache.batch_decorator
124+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
126125
results = []
127126
for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
128127
results.append(self.predict_smiles(smiles))

0 commit comments

Comments
 (0)