From 187572bb093dd1195a4f4a7200f02230c1355017 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 19 Aug 2025 11:19:16 +0200 Subject: [PATCH] Simplify getting started This PR updates the BaseEnsemble constructor to allow the following: 1. Passing a string or path to the configuration 2. Not passing a configuration at all, which will automatically load the default configuration. This is now the default, since most users won't want to have to configure it (it should have reasonable defaults) --- README.md | 12 ++++-------- chebifier/__init__.py | 5 +++++ chebifier/cli.py | 8 ++------ chebifier/ensemble/base_ensemble.py | 14 ++++++++++++-- chebifier/utils.py | 12 ++++++++++++ 5 files changed, 35 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index fcbbb37..8d59280 100644 --- a/README.md +++ b/README.md @@ -75,15 +75,11 @@ python -m chebifier predict --help You can also use the package programmatically: ```python -from chebifier.ensemble.base_ensemble import BaseEnsemble -import yaml +from chebifier import BaseEnsemble -# Load configuration from YAML file -with open('configs/example_config.yml', 'r') as f: - config = yaml.safe_load(f) - -# Instantiate ensemble model -ensemble = BaseEnsemble(config) +# Instantiate ensemble model. If desired, can pass +# a path to a configuration, like 'configs/example_config.yml' +ensemble = BaseEnsemble() # Make predictions smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"] diff --git a/chebifier/__init__.py b/chebifier/__init__.py index aa1e6ec..3ddcfe7 100644 --- a/chebifier/__init__.py +++ b/chebifier/__init__.py @@ -2,5 +2,10 @@ # even if multiple subpackages are imported later. from ._custom_cache import PerSmilesPerModelLRUCache +from chebifier.ensemble.base_ensemble import BaseEnsemble + +__all__ = [ + "BaseEnsemble", +] modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100) diff --git a/chebifier/cli.py b/chebifier/cli.py index c201187..1267c69 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -4,6 +4,7 @@ import yaml from chebifier.model_registry import ENSEMBLES +from chebifier.utils import get_default_configs @click.group() @@ -75,12 +76,7 @@ def predict( # Load configuration from YAML file if not ensemble_config: print("Using default ensemble configuration") - with ( - importlib.resources.files("chebifier") - .joinpath("ensemble.yml") - .open("r") as f - ): - config = yaml.safe_load(f) + config = get_default_configs() else: print(f"Loading ensemble configuration from {ensemble_config}") with open(ensemble_config, "r") as f: diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 6a3acef..e434281 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,23 +1,33 @@ import os import time +from pathlib import Path +from typing import Union import torch import tqdm +import yaml from chebifier.check_env import check_package_installed from chebifier.hugging_face import download_model_files from chebifier.inconsistency_resolution import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.utils import get_disjoint_files, load_chebi_graph +from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs class BaseEnsemble: def __init__( self, - model_configs: dict, + model_configs: Union[str, Path, dict, None] = None, chebi_version: int = 241, resolve_inconsistencies: bool = True, ): + if model_configs is None: + model_configs = get_default_configs() + elif isinstance(model_configs, (str, Path)): + # Load configuration from YAML file + with open(model_configs) as file: + model_configs = yaml.safe_load(file) + # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES diff --git a/chebifier/utils.py b/chebifier/utils.py index e6fefae..7a2e021 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -1,8 +1,11 @@ +import importlib.resources import os import networkx as nx import requests import fastobo +import yaml + from chebifier.hugging_face import download_model_files import pickle @@ -129,3 +132,12 @@ def get_disjoint_files(): # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb")) chebi_graph = load_chebi_graph() print(chebi_graph) + + +def get_default_configs(): + with ( + importlib.resources.files("chebifier") + .joinpath("ensemble.yml") + .open("r") as f + ): + return yaml.safe_load(f)