3 changes: 2 additions & 1 deletion chebai/loggers/custom.py
@@ -2,7 +2,6 @@
from datetime import datetime
from typing import List, Literal, Optional, Union

import wandb
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
@@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None:
Args:
fold (int): Cross-validation fold number.
"""
import wandb

if fold != self._fold:
self._fold = fold
# Start new experiment
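This hunk illustrates the pattern the whole PR follows: a heavyweight import moves from module level into the one method that uses it, so importing chebai.loggers.custom no longer pulls in wandb. A minimal sketch of the idiom (the Logger class here is hypothetical, and the wandb.finish() call is illustrative rather than the repo's actual restart logic):

class Logger:
    """Sketch: defer the wandb import until a fold is actually set."""

    def __init__(self) -> None:
        self._fold = None

    def set_fold(self, fold: int) -> None:
        # Deferred import: wandb is loaded on first use, not at module import time.
        import wandb

        if fold != self._fold:
            self._fold = fold
            wandb.finish()  # illustrative: end the current run before starting a new one

The effect can be checked with python -X importtime -c "import chebai.loggers.custom" before and after the change.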
3 changes: 2 additions & 1 deletion chebai/loss/bce_weighted.py
@@ -5,7 +5,6 @@

from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed


class BCEWeighted(torch.nn.BCEWithLogitsLoss):
@@ -27,6 +26,8 @@ def __init__(
data_extractor: Optional[XYBaseDataModule] = None,
**kwargs,
):
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed

self.beta = beta
if isinstance(data_extractor, LabeledUnlabeledMixed):
data_extractor = data_extractor.labeled
10 changes: 7 additions & 3 deletions chebai/loss/semantic.py
@@ -2,14 +2,16 @@
import math
import os
import pickle
from typing import List, Literal, Union
from typing import TYPE_CHECKING, List, Literal, Union

import torch

from chebai.loss.bce_weighted import BCEWeighted
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed

if TYPE_CHECKING:
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed


class ImplicationLoss(torch.nn.Module):
@@ -68,6 +70,8 @@ def __init__(
multiply_with_base_loss: bool = True,
no_grads: bool = False,
):
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed

super().__init__()
# automatically choose labeled subset for implication filter in case of mixed dataset
if isinstance(data_extractor, LabeledUnlabeledMixed):
@@ -338,7 +342,7 @@ class DisjointLoss(ImplicationLoss):
def __init__(
self,
path_to_disjointness: str,
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"],
base_loss: torch.nn.Module = None,
disjoint_loss_weight: float = 100,
**kwargs,
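The TYPE_CHECKING guard above is the companion idiom to these deferred imports: the guarded import runs only under a static type checker (typing.TYPE_CHECKING is True for mypy/pyright, False at runtime), and annotations that mention the class are quoted so they are never evaluated when the module loads. A minimal self-contained sketch, using networkx since it is one of the packages this PR defers:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    import networkx as nx


def node_count(g: "nx.Graph") -> int:
    # Runtime import: paid only when the function is actually called.
    import networkx as nx

    assert isinstance(g, nx.Graph)
    return g.number_of_nodes()

With from __future__ import annotations the quotes could be dropped, since annotations would then be evaluated lazily across the module.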
26 changes: 16 additions & 10 deletions chebai/preprocessing/datasets/base.py
@@ -1,24 +1,21 @@
import os
import random
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import lightning as pl
import networkx as nx
import pandas as pd
import torch
import tqdm
from iterstrat.ml_stratifiers import (
MultilabelStratifiedKFold,
MultilabelStratifiedShuffleSplit,
)
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning_utilities.core.rank_zero import rank_zero_info
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader

from chebai.preprocessing import reader as dr

if TYPE_CHECKING:
import networkx as nx


class XYBaseDataModule(LightningDataModule):
"""
@@ -818,7 +815,7 @@ def _download_required_data(self) -> str:
pass

@abstractmethod
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
"""
Extracts the class hierarchy from the data.
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -833,7 +830,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
pass

@abstractmethod
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
"""
Converts the graph to a raw dataset.
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -848,7 +845,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
pass

@abstractmethod
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
"""
Selects classes from the dataset based on a specified criteria.

@@ -1023,6 +1020,9 @@ def get_test_split(
Raises:
ValueError: If the DataFrame does not contain a column named "labels".
"""
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit

print("Get test data split")

labels_list = df["labels"].tolist()
@@ -1060,6 +1060,12 @@ def get_train_val_splits_given_test(
and validation DataFrames. The keys are the names of the train and validation sets, and the values
are the corresponding DataFrames.
"""
from iterstrat.ml_stratifiers import (
MultilabelStratifiedKFold,
MultilabelStratifiedShuffleSplit,
)
from sklearn.model_selection import StratifiedShuffleSplit

print("Split dataset into train / val with given test set")

test_ids = test_df["ident"].tolist()
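For context, the deferred split helpers are used along these lines; a toy sketch with made-up data, not the repo's actual splitting code:

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

X = np.arange(8).reshape(8, 1)  # toy features
y = np.array(
    [[0, 1], [1, 0], [1, 1], [0, 1], [1, 0], [1, 1], [0, 1], [1, 0]]
)  # toy multilabel targets

splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, test_idx = next(splitter.split(X, y))
print(train_idx, test_idx)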
34 changes: 24 additions & 10 deletions chebai/preprocessing/datasets/chebi.py
@@ -13,17 +13,18 @@
import pickle
from abc import ABC
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import fastobo
import networkx as nx
import pandas as pd
import requests
import torch

from chebai.preprocessing import reader as dr
from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset

if TYPE_CHECKING:
import fastobo
import networkx as nx

# exclude some entities from the dataset because they violate disjointness axioms
CHEBI_BLACKLIST = [
194026,
@@ -212,6 +213,8 @@ def _load_chebi(self, version: int) -> str:
Returns:
str: The file path of the loaded ChEBI ontology.
"""
import requests

chebi_name = self.raw_file_names_dict["chebi"]
chebi_path = os.path.join(self.raw_dir, chebi_name)
if not os.path.isfile(chebi_path):
@@ -223,7 +226,7 @@
open(chebi_path, "wb").write(r.content)
return chebi_path

def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
"""
Extracts the class hierarchy from the ChEBI ontology.
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -235,6 +238,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
Returns:
nx.DiGraph: The class hierarchy.
"""
import fastobo
import networkx as nx

with open(data_path, encoding="utf-8") as chebi:
chebi = "\n".join(line for line in chebi if not line.startswith("xref:"))

@@ -262,7 +268,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
print("Compute transitive closure")
return nx.transitive_closure_dag(g)

def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
"""
Converts the graph to a raw dataset.
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -274,6 +280,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
Returns:
pd.DataFrame: The raw dataset created from the graph.
"""
import networkx as nx

smiles = nx.get_node_attributes(g, "smiles")
names = nx.get_node_attributes(g, "name")

@@ -572,7 +580,7 @@ def _name(self) -> str:
"""
return f"ChEBI{self.THRESHOLD}"

def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
"""
Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.

@@ -597,6 +605,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
- The `THRESHOLD` attribute should be defined in the subclass of this class.
- Nodes without a 'smiles' attribute are ignored in the successor count.
"""
import networkx as nx

smiles = nx.get_node_attributes(g, "smiles")
nodes = list(
sorted(
@@ -735,7 +745,7 @@ def processed_dir_main(self) -> str:
"processed",
)

def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph":
"""
Extracts a subset of ChEBI based on subclasses of the top class ID.

@@ -773,8 +783,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
)
return g

def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
import networkx as nx

smiles = nx.get_node_attributes(g, "smiles")
nodes = list(
sorted(
@@ -834,7 +846,7 @@ def chebi_to_int(s: str) -> int:
return int(s[s.index(":") + 1 :])


def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]:
"""
Extracts information from a ChEBI term document.
This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
Expand All @@ -851,6 +863,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
- "name": The name of the ChEBI term.
- "smiles": The SMILES string associated with the ChEBI term, if available.
"""
import fastobo

parts = set()
parents = []
name = None
11 changes: 8 additions & 3 deletions chebai/preprocessing/reader.py
@@ -5,11 +5,8 @@
from itertools import islice
from typing import Any, Dict, List, Optional

import deepsmiles
import selfies as sf
from pysmiles.read_smiles import _tokenize
from rdkit import Chem
from transformers import RobertaTokenizerFast

from chebai.preprocessing.collate import DefaultCollator, RaggedCollator

@@ -220,6 +217,8 @@ class DeepChemDataReader(ChemDataReader):
"""

def __init__(self, *args, **kwargs):
import deepsmiles

super().__init__(*args, **kwargs)
self.converter = deepsmiles.Converter(rings=True, branches=True)
self.error_count = 0
@@ -294,6 +293,8 @@ def __init__(
vsize: int = 4000,
**kwargs,
):
from transformers import RobertaTokenizerFast

super().__init__(*args, **kwargs)
self.tokenizer = RobertaTokenizerFast.from_pretrained(
data_path, max_len=max_len
@@ -327,6 +328,8 @@ def __init__(
vsize: int = 4000,
**kwargs,
):
import selfies as sf

super().__init__(*args, **kwargs)
self.error_count = 0
sf.set_semantic_constraints("hypervalent")
@@ -338,6 +341,8 @@ def name(cls) -> str:

def _read_data(self, raw_data: str) -> Optional[List[int]]:
"""Read and tokenize raw data using SELFIES."""
import selfies as sf

try:
tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True))
tokenized = [self._get_token_index(v) for v in tokenized]
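The selfies calls that _read_data now imports lazily behave roughly as follows (a standalone sketch; "CCO", ethanol, is just an example input):

import selfies as sf

sf.set_semantic_constraints("hypervalent")  # same preset the reader selects

encoded = sf.encoder("CCO", strict=True)  # -> "[C][C][O]"
tokens = list(sf.split_selfies(encoded))  # -> ["[C]", "[C]", "[O]"]
print(tokens)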
10 changes: 7 additions & 3 deletions chebai/preprocessing/structures.py
@@ -1,8 +1,10 @@
from typing import Any, Tuple, Union
from typing import TYPE_CHECKING, Any, Tuple, Union

import networkx as nx
import torch

if TYPE_CHECKING:
import networkx as nx


class XYData(torch.utils.data.Dataset):
"""
@@ -119,7 +121,7 @@ class XYMolData(XYData):
kwargs: Additional fields to store in the dataset.
"""

def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]:
def to_x(self, device: torch.device) -> Tuple["nx.Graph", ...]:
"""
Moves the node attributes of the molecular graphs to the specified device.

@@ -129,6 +131,8 @@ def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]:
Returns:
A tuple of molecular graphs with node attributes on the specified device.
"""
import networkx as nx

l_ = []
for g in self.x:
graph = g.copy()
10 changes: 10 additions & 0 deletions pyproject.toml
Review comment (Collaborator):
If I understand the extras-system correctly, given the dependencies

dependencies = ["A"]
[project.optional-dependencies]
opt1 = ["B"]
opt2 = ["C"]

you will get the following install options:

pip install my_package -> installs A
pip install my_package[opt1] -> installs A and B
pip install my_package[all] -> installs A, B and C

Given this structure, I would set the dependencies you put under inference as dependencies and make the rest optional (please correct me if I got this wrong)
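If that suggestion were adopted, the relevant part of pyproject.toml might look roughly like this; a sketch, not part of the diff below. Note that pip does not synthesize an all extra automatically, so it would have to be declared explicitly:

[project]
dependencies = [  # minimal set needed for inference
    "numpy",
    "pandas",
    "torch",
    "transformers",
    "pysmiles==1.1.2",
    "rdkit",
    "lightning>=2.5",
]

[project.optional-dependencies]
dev = ["black", "isort", "pre-commit"]
plot = ["matplotlib", "seaborn"]
wandb = ["wandb"]
all = ["black", "isort", "pre-commit", "matplotlib", "seaborn", "wandb"]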

@@ -41,6 +41,16 @@ dev = ["black", "isort", "pre-commit"]
plot = ["matplotlib", "seaborn"]
wandb = ["wandb"]

inference = [
"numpy",
"pandas",
"torch",
"transformers",
"pysmiles==1.1.2",
"rdkit",
"lightning>=2.5",
]

[tool.setuptools]
include-package-data = true
license-files = ["LICEN[CS]E*"]