Skip to content

Simplify getting started #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,11 @@ python -m chebifier predict --help
You can also use the package programmatically:

```python
from chebifier.ensemble.base_ensemble import BaseEnsemble
import yaml
from chebifier import BaseEnsemble

# Load configuration from YAML file
with open('configs/example_config.yml', 'r') as f:
config = yaml.safe_load(f)

# Instantiate ensemble model
ensemble = BaseEnsemble(config)
# Instantiate ensemble model. If desired, can pass
# a path to a configuration, like 'configs/example_config.yml'
ensemble = BaseEnsemble()

# Make predictions
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
Expand Down
5 changes: 5 additions & 0 deletions chebifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,10 @@
# even if multiple subpackages are imported later.

from ._custom_cache import PerSmilesPerModelLRUCache
from chebifier.ensemble.base_ensemble import BaseEnsemble

__all__ = [
"BaseEnsemble",
]

modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
8 changes: 2 additions & 6 deletions chebifier/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import yaml

from chebifier.model_registry import ENSEMBLES
from chebifier.utils import get_default_configs


@click.group()
Expand Down Expand Up @@ -75,12 +76,7 @@ def predict(
# Load configuration from YAML file
if not ensemble_config:
print("Using default ensemble configuration")
with (
importlib.resources.files("chebifier")
.joinpath("ensemble.yml")
.open("r") as f
):
config = yaml.safe_load(f)
config = get_default_configs()
else:
print(f"Loading ensemble configuration from {ensemble_config}")
with open(ensemble_config, "r") as f:
Expand Down
14 changes: 12 additions & 2 deletions chebifier/ensemble/base_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
import os
import time
from pathlib import Path
from typing import Union

import torch
import tqdm
import yaml

from chebifier.check_env import check_package_installed
from chebifier.hugging_face import download_model_files
from chebifier.inconsistency_resolution import PredictionSmoother
from chebifier.prediction_models.base_predictor import BasePredictor
from chebifier.utils import get_disjoint_files, load_chebi_graph
from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs


class BaseEnsemble:
def __init__(
self,
model_configs: dict,
model_configs: Union[str, Path, dict, None] = None,
chebi_version: int = 241,
resolve_inconsistencies: bool = True,
):
if model_configs is None:
model_configs = get_default_configs()
elif isinstance(model_configs, (str, Path)):
# Load configuration from YAML file
with open(model_configs) as file:
model_configs = yaml.safe_load(file)

# Deferred Import: To avoid circular import error
from chebifier.model_registry import MODEL_TYPES

Expand Down
12 changes: 12 additions & 0 deletions chebifier/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import importlib.resources
import os

import networkx as nx
import requests
import fastobo
import yaml

from chebifier.hugging_face import download_model_files
import pickle

Expand Down Expand Up @@ -129,3 +132,12 @@ def get_disjoint_files():
# pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
chebi_graph = load_chebi_graph()
print(chebi_graph)


def get_default_configs():
with (
importlib.resources.files("chebifier")
.joinpath("ensemble.yml")
.open("r") as f
):
return yaml.safe_load(f)