Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions chebifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Note: The top-level package __init__.py runs only once,
# even if multiple subpackages are imported later, so the cache object
# created below is a single instance shared by every module that imports
# it from this package.

from ._custom_cache import PerSmilesPerModelLRUCache

# Shared LRU cache of (SMILES, model_name) -> result entries; it keeps at
# most 100 entries in total.
modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
208 changes: 208 additions & 0 deletions chebifier/_custom_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import os
import pickle
import threading
from collections import OrderedDict
from collections.abc import Iterable
from functools import wraps
from typing import Any, Callable


class PerSmilesPerModelLRUCache:
"""
A thread-safe, optionally persistent LRU cache for storing
(SMILES, model_name) → result mappings.
"""

def __init__(self, max_size: int = 100, persist_path: str | None = None):
"""
Initialize the cache.

Args:
max_size (int): Maximum number of items to keep in the cache.
persist_path (str | None): Optional path to persist cache using pickle.
"""
self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict()
self._max_size = max_size
self._lock = threading.Lock()
self._persist_path = persist_path

self.hits = 0
self.misses = 0

if self._persist_path:
self._load_cache()

def get(self, smiles: str, model_name: str) -> Any | None:
"""
Retrieve value from cache if present, otherwise return None.

Args:
smiles (str): SMILES string key.
model_name (str): Model identifier.

Returns:
Any | None: Cached value or None.
"""
key = (smiles, model_name)
with self._lock:
if key in self._cache:
self._cache.move_to_end(key)
self.hits += 1
return self._cache[key]
else:
self.misses += 1
return None

def set(self, smiles: str, model_name: str, value: Any) -> None:
"""
Store value in cache under (smiles, model_name) key.

Args:
smiles (str): SMILES string key.
model_name (str): Model identifier.
value (Any): Value to cache.
"""
assert value is not None, "Value must not be None"
key = (smiles, model_name)
with self._lock:
if key in self._cache:
self._cache.move_to_end(key)
self._cache[key] = value
if len(self._cache) > self._max_size:
self._cache.popitem(last=False)

def clear(self) -> None:
"""
Clear the cache and remove the persistence file if present.
"""
self._save_cache()
with self._lock:
self._cache.clear()
self.hits = 0
self.misses = 0
if self._persist_path and os.path.exists(self._persist_path):
os.remove(self._persist_path)

def stats(self) -> dict[str, int]:
"""
Return cache hit/miss statistics.

Returns:
dict[str, int]: Dictionary with 'hits' and 'misses' keys.
"""
return {"hits": self.hits, "misses": self.misses}

def batch_decorator(self, func: Callable) -> Callable:
"""
Decorator for class methods that accept a batch of SMILES as a list,
and cache predictions per (smiles, model_name) key.

The instance is expected to have a `model_name` attribute.

Args:
func (Callable): The method to decorate.

Returns:
Callable: The wrapped method.
"""

@wraps(func)
def wrapper(instance, smiles_list: list[str]) -> list[Any]:
assert isinstance(smiles_list, list), "smiles_list must be a list."
model_name = getattr(instance, "model_name", None)
assert model_name is not None, "Instance must have a model_name attribute."

missing_smiles: list[str] = []
missing_indices: list[int] = []
ordered_results: list[Any] = [None] * len(smiles_list)

# First: try to fetch all from cache
for idx, smiles in enumerate(smiles_list):
prediction = self.get(smiles=smiles, model_name=model_name)
if prediction is not None:
# For debugging purposes, you can uncomment the print statement below
# print(
# f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache."
# )
ordered_results[idx] = prediction
else:
missing_smiles.append(smiles)
missing_indices.append(idx)

# If some are missing, call original function
if missing_smiles:
new_results = func(instance, tuple(missing_smiles))
assert isinstance(
new_results, Iterable
), "Function must return an Iterable."

# Save to cache and append
for smiles, prediction, missing_idx in zip(
missing_smiles, new_results, missing_indices
):
if prediction is not None:
self.set(smiles, model_name, prediction)
ordered_results[missing_idx] = prediction

return ordered_results

return wrapper

def __len__(self) -> int:
"""
Return number of items in the cache.

Returns:
int: Number of entries in the cache.
"""
with self._lock:
return len(self._cache)

def __repr__(self) -> str:
"""
String representation of the underlying cache.

Returns:
str: String version of the OrderedDict.
"""
return self._cache.__repr__()

def save(self) -> None:
"""
Save the cache to disk, if persistence is enabled.
"""
self._save_cache()

def load(self) -> None:
"""
Load the cache from disk, if persistence is enabled.
"""
self._load_cache()

def _save_cache(self) -> None:
"""
Serialize the cache to disk using pickle.
"""
if self._persist_path:
try:
with open(self._persist_path, "wb") as f:
pickle.dump(self._cache, f)
except Exception as e:
print(f"[Cache Save Error] {e}")

def _load_cache(self) -> None:
"""
Load the cache from disk, if the file exists and is non-empty.
"""
if (
self._persist_path
and os.path.exists(self._persist_path)
and os.path.getsize(self._persist_path) > 0
):
try:
with open(self._persist_path, "rb") as f:
loaded = pickle.load(f)
if isinstance(loaded, OrderedDict):
self._cache = loaded
except Exception as e:
print(f"[Cache Load Error] {e}")
10 changes: 3 additions & 7 deletions chebifier/prediction_models/base_predictor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from abc import ABC

from functools import lru_cache
from chebifier import modelwise_smiles_lru_cache


class BasePredictor(ABC):
Expand All @@ -23,17 +23,13 @@ def __init__(

self._description = kwargs.get("description", None)

@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
# list is not hashable, so we convert it to a tuple (useful for caching)
return self.predict_smiles_tuple(tuple(smiles_list))

@lru_cache(maxsize=100)
def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict:
raise NotImplementedError()

def predict_smiles(self, smiles: str) -> dict:
# by default, use list-based prediction
return self.predict_smiles_tuple((smiles,))[0]
return self.predict_smiles_list([smiles])[0]

@property
def info_text(self):
Expand Down
8 changes: 4 additions & 4 deletions chebifier/prediction_models/c3p_predictor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import lru_cache
from typing import Optional, List
from pathlib import Path
from typing import List, Optional

from c3p import classifier as c3p_classifier

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor


Expand All @@ -24,8 +24,8 @@ def __init__(
self.chemical_classes = chemical_classes
self.chebi_graph = kwargs.get("chebi_graph", None)

@lru_cache(maxsize=100)
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
result_list = c3p_classifier.classify(
list(smiles_list),
self.program_directory,
Expand Down
17 changes: 9 additions & 8 deletions chebifier/prediction_models/chebi_lookup.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from functools import lru_cache
import json
import os
from typing import Optional

from chebifier.prediction_models import BasePredictor
import os
import networkx as nx
from rdkit import Chem
import json

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor
from chebifier.utils import load_chebi_graph


class ChEBILookupPredictor(BasePredictor):

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -67,7 +67,6 @@ def build_smiles_lookup(self):
)
return smiles_lookup

@lru_cache(maxsize=100)
def predict_smiles(self, smiles: str) -> Optional[dict]:
if not smiles:
return None
Expand All @@ -94,7 +93,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
else:
return None

def predict_smiles_tuple(self, smiles_list: list[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
predictions = []
for smiles in smiles_list:
predictions.append(self.predict_smiles(smiles))
Expand Down Expand Up @@ -145,7 +145,8 @@ def explain_smiles(self, smiles: str) -> dict:
# Example usage
smiles_list = [
"CCO",
"C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
"C1=CC=CC=C1",
"*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
] # SMILES with 251 matches in ChEBI
predictions = predictor.predict_smiles_list(smiles_list)
print(predictions)
15 changes: 7 additions & 8 deletions chebifier/prediction_models/chemlog_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
)
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
from chemlog_extra.alg_classification.by_element_classification import (
XMolecularEntityClassifier,
OrganoXCompoundClassifier,
XMolecularEntityClassifier,
)
from functools import lru_cache

from chebifier import modelwise_smiles_lru_cache

from .base_predictor import BasePredictor

Expand Down Expand Up @@ -47,15 +48,15 @@


class ChemlogExtraPredictor(BasePredictor):

CHEMLOG_CLASSIFIER = None

def __init__(self, model_name: str, **kwargs):
super().__init__(model_name, **kwargs)
self.chebi_graph = kwargs.get("chebi_graph", None)
self.classifier = self.CHEMLOG_CLASSIFIER()

def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
res = self.classifier.classify(mol_list)
if self.chebi_graph is not None:
Expand All @@ -72,12 +73,10 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:


class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):

CHEMLOG_CLASSIFIER = XMolecularEntityClassifier


class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):

CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier


Expand All @@ -97,7 +96,6 @@ def __init__(self, model_name: str, **kwargs):
# fmt: on
print(f"Initialised ChemLog model {self.model_name}")

@lru_cache(maxsize=100)
def predict_smiles(self, smiles: str) -> Optional[dict]:
mol = _smiles_to_mol(smiles)
if mol is None:
Expand All @@ -122,7 +120,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
for label in self.peptide_labels + pos_labels
}

def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
results = []
for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
results.append(self.predict_smiles(smiles))
Expand Down
Loading