diff --git a/chebifier/__init__.py b/chebifier/__init__.py
index e69de29..aa1e6ec 100644
--- a/chebifier/__init__.py
+++ b/chebifier/__init__.py
@@ -0,0 +1,6 @@
+# Note: The top-level package __init__.py runs only once per process, even if
+# multiple subpackages are imported later, so the cache below is shared.
+
+from ._custom_cache import PerSmilesPerModelLRUCache
+
+modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
diff --git a/chebifier/_custom_cache.py b/chebifier/_custom_cache.py
new file mode 100644
index 0000000..38b500f
--- /dev/null
+++ b/chebifier/_custom_cache.py
@@ -0,0 +1,210 @@
+import os
+import pickle
+import threading
+from collections import OrderedDict
+from collections.abc import Iterable
+from functools import wraps
+from typing import Any, Callable
+
+
+class PerSmilesPerModelLRUCache:
+    """
+    A thread-safe, optionally persistent LRU cache for storing
+    (SMILES, model_name) → result mappings.
+    """
+
+    def __init__(self, max_size: int = 100, persist_path: str | None = None):
+        """
+        Initialize the cache.
+
+        Args:
+            max_size (int): Maximum number of items to keep in the cache.
+            persist_path (str | None): Optional path to persist the cache using pickle.
+        """
+        self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict()
+        self._max_size = max_size
+        self._lock = threading.Lock()
+        self._persist_path = persist_path
+
+        self.hits = 0
+        self.misses = 0
+
+        if self._persist_path:
+            self._load_cache()
+
+    def get(self, smiles: str, model_name: str) -> Any | None:
+        """
+        Retrieve a value from the cache if present, otherwise return None.
+
+        Args:
+            smiles (str): SMILES string key.
+            model_name (str): Model identifier.
+
+        Returns:
+            Any | None: Cached value or None.
+        """
+        key = (smiles, model_name)
+        with self._lock:
+            if key in self._cache:
+                self._cache.move_to_end(key)
+                self.hits += 1
+                return self._cache[key]
+            else:
+                self.misses += 1
+                return None
+
+    def set(self, smiles: str, model_name: str, value: Any) -> None:
+        """
+        Store a value in the cache under the (smiles, model_name) key.
+
+        Args:
+            smiles (str): SMILES string key.
+            model_name (str): Model identifier.
+            value (Any): Value to cache.
+        """
+        assert value is not None, "Value must not be None"
+        key = (smiles, model_name)
+        with self._lock:
+            if key in self._cache:
+                self._cache.move_to_end(key)
+            self._cache[key] = value
+            if len(self._cache) > self._max_size:
+                self._cache.popitem(last=False)
+
+    def clear(self) -> None:
+        """
+        Clear the cache and remove the persistence file if present.
+        """
+        with self._lock:
+            self._cache.clear()
+            self.hits = 0
+            self.misses = 0
+            if self._persist_path and os.path.exists(self._persist_path):
+                os.remove(self._persist_path)
+
+    def stats(self) -> dict[str, int]:
+        """
+        Return cache hit/miss statistics.
+
+        Returns:
+            dict[str, int]: Dictionary with 'hits' and 'misses' keys.
+        """
+        return {"hits": self.hits, "misses": self.misses}
+
+    def batch_decorator(self, func: Callable) -> Callable:
+        """
+        Decorator for class methods that accept a batch of SMILES as a list
+        and cache predictions per (smiles, model_name) key.
+
+        The instance is expected to have a `model_name` attribute.
+
+        Args:
+            func (Callable): The method to decorate.
+
+        Returns:
+            Callable: The wrapped method.
+        """
+
+        @wraps(func)
+        def wrapper(instance, smiles_list: list[str]) -> list[Any]:
+            assert isinstance(smiles_list, list), "smiles_list must be a list."
+            model_name = getattr(instance, "model_name", None)
+            assert model_name is not None, "Instance must have a model_name attribute."
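+            # Two-phase strategy: first serve whatever we can from the cache,
+            # then run the wrapped function once on the remaining misses and
+            # merge both result sets back into the original input order.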
+
+            missing_smiles: list[str] = []
+            missing_indices: list[int] = []
+            ordered_results: list[Any] = [None] * len(smiles_list)
+
+            # First pass: try to fetch everything from the cache
+            for idx, smiles in enumerate(smiles_list):
+                prediction = self.get(smiles=smiles, model_name=model_name)
+                if prediction is not None:
+                    # For debugging purposes, you can uncomment the print statement below
+                    # print(
+                    #     f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} retrieved from cache."
+                    # )
+                    ordered_results[idx] = prediction
+                else:
+                    missing_smiles.append(smiles)
+                    missing_indices.append(idx)
+
+            # Second pass: call the original function only for the cache misses
+            if missing_smiles:
+                new_results = func(instance, list(missing_smiles))
+                assert isinstance(
+                    new_results, Iterable
+                ), "Function must return an Iterable."
+
+                # Cache each new prediction and slot it back into its original position
+                for smiles, prediction, missing_idx in zip(
+                    missing_smiles, new_results, missing_indices
+                ):
+                    if prediction is not None:
+                        self.set(smiles, model_name, prediction)
+                    ordered_results[missing_idx] = prediction
+
+            return ordered_results
+
+        return wrapper
+
+    def __len__(self) -> int:
+        """
+        Return number of items in the cache.
+
+        Returns:
+            int: Number of entries in the cache.
+        """
+        with self._lock:
+            return len(self._cache)
+
+    def __repr__(self) -> str:
+        """
+        String representation of the underlying cache.
+
+        Returns:
+            str: String version of the OrderedDict.
+        """
+        return repr(self._cache)
+
+    def save(self) -> None:
+        """
+        Save the cache to disk, if persistence is enabled.
+        """
+        self._save_cache()
+
+    def load(self) -> None:
+        """
+        Load the cache from disk, if persistence is enabled.
+        """
+        self._load_cache()
+
+    def _save_cache(self) -> None:
+        """
+        Serialize the cache to disk using pickle.
+        """
+        if self._persist_path:
+            try:
+                with open(self._persist_path, "wb") as f:
+                    pickle.dump(self._cache, f)
+            except Exception as e:
+                print(f"[Cache Save Error] {e}")
+
+    def _load_cache(self) -> None:
+        """
+        Load the cache from disk, if the file exists and is non-empty.
+ """ + if ( + self._persist_path + and os.path.exists(self._persist_path) + and os.path.getsize(self._persist_path) > 0 + ): + try: + with open(self._persist_path, "rb") as f: + loaded = pickle.load(f) + if isinstance(loaded, OrderedDict): + self._cache = loaded + except Exception as e: + print(f"[Cache Load Error] {e}") diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index ba1412d..a175366 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,7 +1,7 @@ import json from abc import ABC -from functools import lru_cache +from chebifier import modelwise_smiles_lru_cache class BasePredictor(ABC): @@ -23,17 +23,13 @@ def __init__( self._description = kwargs.get("description", None) + @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> dict: - # list is not hashable, so we convert it to a tuple (useful for caching) - return self.predict_smiles_tuple(tuple(smiles_list)) - - @lru_cache(maxsize=100) - def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict: raise NotImplementedError() def predict_smiles(self, smiles: str) -> dict: # by default, use list-based prediction - return self.predict_smiles_tuple((smiles,))[0] + return self.predict_smiles_list([smiles])[0] @property def info_text(self): diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index 00c71f7..dc4704d 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -1,9 +1,9 @@ -from functools import lru_cache -from typing import Optional, List from pathlib import Path +from typing import List, Optional from c3p import classifier as c3p_classifier +from chebifier import modelwise_smiles_lru_cache from chebifier.prediction_models import BasePredictor @@ -24,8 +24,8 @@ def __init__( self.chemical_classes = chemical_classes self.chebi_graph = kwargs.get("chebi_graph", None) - @lru_cache(maxsize=100) - def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: + @modelwise_smiles_lru_cache.batch_decorator + def predict_smiles_list(self, smiles_list: list[str]) -> list: result_list = c3p_classifier.classify( list(smiles_list), self.program_directory, diff --git a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index 2f6a7b0..d145e24 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -1,16 +1,16 @@ -from functools import lru_cache +import json +import os from typing import Optional -from chebifier.prediction_models import BasePredictor -import os import networkx as nx from rdkit import Chem -import json + +from chebifier import modelwise_smiles_lru_cache +from chebifier.prediction_models import BasePredictor from chebifier.utils import load_chebi_graph class ChEBILookupPredictor(BasePredictor): - def __init__( self, model_name: str, @@ -67,7 +67,6 @@ def build_smiles_lookup(self): ) return smiles_lookup - @lru_cache(maxsize=100) def predict_smiles(self, smiles: str) -> Optional[dict]: if not smiles: return None @@ -94,7 +93,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]: else: return None - def predict_smiles_tuple(self, smiles_list: list[str]) -> list: + @modelwise_smiles_lru_cache.batch_decorator + def predict_smiles_list(self, smiles_list: list[str]) -> list: predictions = [] for smiles in smiles_list: 
             predictions.append(self.predict_smiles(smiles))
@@ -145,7 +145,8 @@ def explain_smiles(self, smiles: str) -> dict:
     # Example usage
     smiles_list = [
         "CCO",
-        "C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
+        "C1=CC=CC=C1",
+        "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
     ]  # SMILES with 251 matches in ChEBI
     predictions = predictor.predict_smiles_list(smiles_list)
     print(predictions)
diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py
index 8232641..99fa3b9 100644
--- a/chebifier/prediction_models/chemlog_predictor.py
+++ b/chebifier/prediction_models/chemlog_predictor.py
@@ -12,10 +12,11 @@
 )
 from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
 from chemlog_extra.alg_classification.by_element_classification import (
-    XMolecularEntityClassifier,
     OrganoXCompoundClassifier,
+    XMolecularEntityClassifier,
 )
-from functools import lru_cache
+
+from chebifier import modelwise_smiles_lru_cache
 
 from .base_predictor import BasePredictor
 
@@ -47,7 +48,6 @@
 
 
 class ChemlogExtraPredictor(BasePredictor):
-
     CHEMLOG_CLASSIFIER = None
 
     def __init__(self, model_name: str, **kwargs):
@@ -55,7 +55,8 @@
         self.chebi_graph = kwargs.get("chebi_graph", None)
         self.classifier = self.CHEMLOG_CLASSIFIER()
 
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
         res = self.classifier.classify(mol_list)
         if self.chebi_graph is not None:
@@ -72,12 +73,10 @@
 
 
 class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
-
     CHEMLOG_CLASSIFIER = XMolecularEntityClassifier
 
 
 class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):
-
     CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier
 
 
@@ -97,7 +96,6 @@
         # fmt: on
         print(f"Initialised ChemLog model {self.model_name}")
 
-    @lru_cache(maxsize=100)
     def predict_smiles(self, smiles: str) -> Optional[dict]:
         mol = _smiles_to_mol(smiles)
         if mol is None:
@@ -122,7 +120,8 @@
             for label in self.peptide_labels + pos_labels
         }
 
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         results = []
         for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
             results.append(self.predict_smiles(smiles))
diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py
index e7d72c9..79dcad9 100644
--- a/chebifier/prediction_models/nn_predictor.py
+++ b/chebifier/prediction_models/nn_predictor.py
@@ -1,10 +1,10 @@
-from functools import lru_cache
-
 import numpy as np
 import torch
 import tqdm
 from rdkit import Chem
 
+from chebifier import modelwise_smiles_lru_cache
+
 from .base_predictor import BasePredictor
 
 
@@ -52,8 +52,8 @@ def read_smiles(self, smiles):
         d = reader.to_data(dict(features=smiles, labels=None))
         return d
 
-    @lru_cache(maxsize=100)
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         """Returns a list with the length of smiles_list, each element is
         either None (=failure) or a dictionary of classes and predicted values."""
         token_dicts = []
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..eaa98f4
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,182 @@
+import os
+import tempfile
+import unittest
+
+from chebifier import PerSmilesPerModelLRUCache
+
+g_cache = PerSmilesPerModelLRUCache(max_size=100, persist_path=None)
+
+
+class DummyPredictor:
+    def __init__(self, model_name: str):
+        """
+        Dummy predictor for testing the cache decorator.
+        :param model_name: Name of the model instance (used for key separation).
+        """
+        self.model_name = model_name
+
+    @g_cache.batch_decorator
+    def predict(self, smiles_list: list[str]) -> list[str]:
+        """
+        Dummy predict method to simulate model inference.
+        Returns a list of predictions in a predictable format.
+        """
+        # Simple, predictable dummy function for tests
+        return [f"{self.model_name}_P{i}" for i in range(len(smiles_list))]
+
+
+class TestPerSmilesPerModelLRUCache(unittest.TestCase):
+    def setUp(self) -> None:
+        """
+        Set up a temporary cache file and cache instance before each test.
+        """
+        # Create a temp file for persistence tests
+        self.temp_file = tempfile.NamedTemporaryFile(delete=False)
+        self.temp_file.close()
+        self.cache = PerSmilesPerModelLRUCache(
+            max_size=3, persist_path=self.temp_file.name
+        )
+
+    def tearDown(self) -> None:
+        """
+        Clean up the temporary file after each test.
+        """
+        if os.path.exists(self.temp_file.name):
+            os.remove(self.temp_file.name)
+
+    def test_cache_miss_and_set_get(self) -> None:
+        """
+        Test cache miss on initial get, then set and confirm hit.
+        """
+        # Initially empty
+        self.assertEqual(len(self.cache), 0)
+        self.assertIsNone(self.cache.get("CCC", "model1"))
+
+        # Set and get
+        self.cache.set("CCC", "model1", "result1")
+        self.assertEqual(self.cache.get("CCC", "model1"), "result1")
+        self.assertEqual(self.cache.hits, 1)
+        self.assertEqual(self.cache.misses, 1)  # One miss from the first get
+
+    def test_cache_eviction(self) -> None:
+        """
+        Test LRU eviction when capacity is exceeded.
+        """
+        self.cache.set("a", "m", "v1")
+        self.cache.set("b", "m", "v2")
+        self.cache.set("c", "m", "v3")
+        self.assertEqual(len(self.cache), 3)
+        # Adding one more triggers eviction of the oldest entry
+        self.cache.set("d", "m", "v4")
+        self.assertEqual(len(self.cache), 3)
+        self.assertIsNone(self.cache.get("a", "m"))  # 'a' evicted
+        self.assertIsNotNone(self.cache.get("d", "m"))  # 'd' present
+
+    def test_batch_decorator_hits_and_misses(self) -> None:
+        """
+        Test decorator behavior on batch prediction:
+        - first call (all misses)
+        - second call (mixed hits and misses)
+        - third call (more hits and misses)
+        """
+        predictor = DummyPredictor("modelA")
+        predictor2 = DummyPredictor("modelB")
+
+        # Clear the global cache before starting the test
+        g_cache.clear()
+
+        smiles = ["AAA", "BBB", "CCC", "DDD", "EEE"]
+        # First call: all misses
+        results1 = predictor.predict(smiles)
+        results1_model2 = predictor2.predict(smiles)
+
+        # All predictions come from the actual prediction function, not from the cache
+        self.assertListEqual(
+            results1, ["modelA_P0", "modelA_P1", "modelA_P2", "modelA_P3", "modelA_P4"]
+        )
+        self.assertListEqual(
+            results1_model2,
+            ["modelB_P0", "modelB_P1", "modelB_P2", "modelB_P3", "modelB_P4"],
+        )
+        stats_after_first = g_cache.stats()
+        self.assertEqual(
+            stats_after_first["misses"], 10
+        )  # 5 for modelA and 5 for modelB
+        self.assertEqual(stats_after_first["hits"], 0)
+        self.assertEqual(len(g_cache), 10)  # 5 for each model
+
+        # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
+        #          ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4",
+        #          ("AAA", "modelB"): "modelB_P0", ("BBB", "modelB"): "modelB_P1", ("CCC", "modelB"): "modelB_P2",
+        #          ("DDD", "modelB"): "modelB_P3", ("EEE", "modelB"): "modelB_P4"}
+
+        # Second call with some hits and some misses
+        results2 = predictor.predict(["FFF", "DDD"])
+        # DDD comes from the cache. FFF is not in the cache, so it is predicted;
+        # it gets P0 because it is the only SMILES in the batch of misses, and the
+        # dummy predictor returns P{idx} based on position within that batch.
+        self.assertListEqual(results2, ["modelA_P0", "modelA_P3"])
+        stats_after_second = g_cache.stats()
+        self.assertEqual(stats_after_second["hits"], 1)  # additional 1 hit for DDD
+        self.assertEqual(stats_after_second["misses"], 11)  # additional 1 miss for FFF
+
+        # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
+        #          ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4", ("FFF", "modelA"): "modelA_P0", ...}
+
+        # Third call with some hits and some misses
+        results3 = predictor.predict(["EEE", "GGG", "DDD", "HHH", "BBB", "ZZZ"])
+        # Predictions for [EEE, DDD, BBB] are retrieved from the cache, while
+        # [GGG, HHH, ZZZ] are not in the cache and hence passed to the prediction function
+        self.assertListEqual(
+            results3,
+            [
+                "modelA_P4",  # EEE from cache
+                "modelA_P0",  # GGG not in cache; first in the batch of misses, so P0
+                "modelA_P3",  # DDD from cache
+                "modelA_P1",  # HHH not in cache; second in the batch of misses, so P1
+                "modelA_P1",  # BBB from cache
+                "modelA_P2",  # ZZZ not in cache; third in the batch of misses, so P2
+            ],
+        )
+        stats_after_third = g_cache.stats()
+        self.assertEqual(
+            stats_after_third["hits"], 4
+        )  # additional 3 hits for EEE, DDD, BBB
+        self.assertEqual(
+            stats_after_third["misses"], 14
+        )  # additional 3 misses for GGG, HHH, ZZZ
+
+    def test_persistence_save_and_load(self) -> None:
+        """
+        Test that the cache is properly saved to disk and reloaded.
+        """
+        # Set some values
+        self.cache.set("sm1", "modelX", "val1")
+        self.cache.set("sm2", "modelX", "val2")
+
+        # Save the cache to file
+        self.cache.save()
+
+        # Create a new cache instance that loads from the file
+        new_cache = PerSmilesPerModelLRUCache(
+            max_size=3, persist_path=self.temp_file.name
+        )
+        new_cache.load()
+
+        self.assertEqual(new_cache.get("sm1", "modelX"), "val1")
+        self.assertEqual(new_cache.get("sm2", "modelX"), "val2")
+
+    def test_clear_cache(self) -> None:
+        """
+        Test clearing the cache and removing the persisted file.
+        """
+        self.cache.set("x", "m", "v")
+        self.cache.save()
+        self.assertTrue(os.path.exists(self.temp_file.name))
+        self.cache.clear()
+        self.assertEqual(len(self.cache), 0)
+        self.assertFalse(os.path.exists(self.temp_file.name))
+
+
+if __name__ == "__main__":
+    unittest.main()
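
For reference, a minimal usage sketch of the shared cache introduced above. The `ToyPredictor` class, its model name, and its canned scores are hypothetical; real predictors get the same behaviour through `BasePredictor.predict_smiles_list`. The contract is: the instance exposes a `model_name` attribute, and the decorated method takes a list of SMILES and returns one result per input.

    from chebifier import modelwise_smiles_lru_cache


    class ToyPredictor:
        """Hypothetical predictor, used only to illustrate the caching contract."""

        def __init__(self, model_name: str):
            # The cache is keyed on (smiles, model_name), so every instance
            # must expose a model_name attribute.
            self.model_name = model_name

        @modelwise_smiles_lru_cache.batch_decorator
        def predict_smiles_list(self, smiles_list: list[str]) -> list[dict]:
            # Only cache misses reach this body; hits are served by the decorator.
            return [{"CHEBI:16236": 0.9} for _ in smiles_list]


    predictor = ToyPredictor("toy_model")
    predictor.predict_smiles_list(["CCO", "C1=CC=CC=C1"])  # two misses, both computed
    predictor.predict_smiles_list(["CCO"])  # served entirely from the cache
    print(modelwise_smiles_lru_cache.stats())  # {'hits': 1, 'misses': 2} in a fresh interpreter

Note that results are cached per individual SMILES rather than per batch, so overlapping batches across calls (and across models in an ensemble run) reuse earlier predictions, which the old per-tuple `lru_cache` could not do.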