Skip to content
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
cc5bc08
move previous tests to integration dir
aditya0by0 Aug 29, 2024
5af0351
unit dir + test for ChemDataReader
aditya0by0 Aug 29, 2024
a0810a2
Test for DataReader
aditya0by0 Aug 29, 2024
1b3836d
tests for DeepChemReader
aditya0by0 Aug 29, 2024
aa467c6
Test for SelfiesReader
aditya0by0 Aug 29, 2024
b6f5e51
test for ProteinDataReader
aditya0by0 Aug 30, 2024
73f05c0
test for DefaultCollator
aditya0by0 Aug 30, 2024
8007f37
test for RaggedCollator
aditya0by0 Aug 31, 2024
248eaa7
modify tests to use `setUpClass` class method instead of `setUp` inst…
aditya0by0 Aug 31, 2024
3e57d78
bool labels instead of numeric, for realistic data
aditya0by0 Sep 1, 2024
f9ca653
test for XYBaseDataModule
aditya0by0 Sep 1, 2024
d8016aa
test for DynamicDataset
aditya0by0 Sep 1, 2024
0c7c5b8
add relevant msg to each assert statement
aditya0by0 Sep 1, 2024
c0aaeea
test data class for chebi ontology
aditya0by0 Sep 4, 2024
764216e
test for term callback + mock data changes
aditya0by0 Sep 4, 2024
1dd8428
test for chebidataextractor + changes in mock data
aditya0by0 Sep 5, 2024
f3519b5
mock reader for all + test_setup_pruned_test_set changes
aditya0by0 Sep 5, 2024
fc0fd47
fix for misalignment between x and y in RaggedCollator
aditya0by0 Sep 5, 2024
f7f1631
test for ChebiOverX
aditya0by0 Sep 6, 2024
bf45bb5
test for ChebiXOverPartial
aditya0by0 Sep 6, 2024
17bf584
Mock data for GOUniProt
aditya0by0 Sep 9, 2024
c6c5a59
test for GOUniProtDataExtractor
aditya0by0 Sep 9, 2024
78f5289
Merge branch 'protein_prediction' into additional_unit_tests
aditya0by0 Sep 9, 2024
427bc60
update test to new method name _extract_class_hierarchy
aditya0by0 Sep 9, 2024
c01ecde
test for GOUniProtOverX
aditya0by0 Sep 9, 2024
dfd084e
test for _load_data_from_file for Tox21MolNet
aditya0by0 Sep 10, 2024
77956d4
_load_data_from_file test case Tox21Challenge
aditya0by0 Sep 16, 2024
a3670b0
test for Tox21Chal
aditya0by0 Sep 17, 2024
ac3ac19
patch `os.makedirs` in tests to avoid creating directories
aditya0by0 Sep 17, 2024
44a1dfd
add test case for invalid token/input to read_data
aditya0by0 Sep 22, 2024
aab0fea
test case for `Tox21MolNet.setup_processed` simple split
aditya0by0 Sep 25, 2024
fc8182e
test case for `Tox21MolNet.setup_processed` group split
aditya0by0 Sep 25, 2024
e4caae8
add group key + convert generator to list
aditya0by0 Sep 25, 2024
43c2408
Merge branch 'refactor_tox21MolNet' into additional_unit_tests
aditya0by0 Sep 25, 2024
05f8f0c
Merge branch 'refactor_term_callback' into additional_unit_tests
aditya0by0 Sep 25, 2024
1d3ecbe
update chebi test as per modified term_callback
aditya0by0 Sep 25, 2024
d6726cc
Merge branch 'refactor_term_callback' into additional_unit_tests
aditya0by0 Sep 25, 2024
35a621c
group key not needed for Tox21Chal._load_dict
aditya0by0 Sep 25, 2024
c2e6897
Merge branch 'dev' into additional_unit_tests
aditya0by0 Oct 1, 2024
016134f
Obsolete terms being the parent of valid terms
aditya0by0 Oct 1, 2024
d873bd7
Merge branch 'refactor_term_callback' into additional_unit_tests
aditya0by0 Oct 1, 2024
553f083
Merge branch 'protein_prediction' into additional_unit_tests
aditya0by0 Oct 1, 2024
b479d5a
remove obsolete path for mocked open func
aditya0by0 Oct 5, 2024
adedc09
test single label split scenario implemented in #54
aditya0by0 Oct 5, 2024
65c2d9b
test output format for Tox21MolNet._load_data_from_file
aditya0by0 Oct 5, 2024
72dd50f
Merge branch 'dev' into additional_unit_tests
aditya0by0 Oct 5, 2024
a63c010
DynamicDataset: check split stratification
aditya0by0 Oct 5, 2024
309daed
Merge branch 'dev' into additional_unit_tests
aditya0by0 Oct 11, 2024
e38d1ab
Merge branch 'protein_prediction' into additional_unit_tests
aditya0by0 Oct 12, 2024
e3c4b6e
fix testcase for GO
aditya0by0 Oct 12, 2024
1470e93
Merge branch 'protein_prediction' into additional_unit_tests
aditya0by0 Oct 19, 2024
c1ddd17
update testcase as per transitive go ids
aditya0by0 Oct 20, 2024
bf6bc4a
remove test for tox21mol net
aditya0by0 Oct 20, 2024
b915b0d
Revert "add group key + convert generator to list"
aditya0by0 Oct 22, 2024
282bc09
Merge branch 'dev' into additional_unit_tests
aditya0by0 Nov 1, 2024
a71b199
update swiss data for pretraining test
aditya0by0 Nov 2, 2024
8abd14d
add test for protein pretraining class
aditya0by0 Nov 2, 2024
aae57d3
test : reformat with precommit
aditya0by0 Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8,000 changes: 8,000 additions & 0 deletions chebai/preprocessing/bin/protein_token_3_gram/tokens.txt

Large diffs are not rendered by default.

32 changes: 19 additions & 13 deletions chebai/preprocessing/datasets/go_uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def __init__(self, **kwargs):
self.max_sequence_length >= 1
), "Max sequence length should be greater than or equal to 1."

if self.reader.n_gram is not None:
assert self.max_sequence_length >= self.reader.n_gram, (
f"max_sequence_length ({self.max_sequence_length}) must be greater than "
f"or equal to n_gram ({self.reader.n_gram})."
)

@classmethod
def _get_go_branch(cls, **kwargs) -> str:
"""
Expand Down Expand Up @@ -536,7 +542,8 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:

This method overrides the dataloader method from the superclass. After fetching the dataset from the
superclass, it truncates the 'features' of each data instance to a maximum length specified by
`self.max_sequence_length`.
`self.max_sequence_length`. The truncation is adjusted based on the value of `n_gram` to ensure that
the correct number of amino acids is preserved in the truncated sequences.

Args:
kind (str): The kind of data to load (e.g., 'train', 'val', 'test').
Expand All @@ -547,9 +554,18 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
"""
dataloader = super().dataloader(kind, **kwargs)

# Truncate the 'features' to max_sequence_length for each instance
if self.reader.n_gram is None:
# Truncate the 'features' to max_sequence_length for each instance
truncate_index = self.max_sequence_length
else:
# If n_gram is given, adjust truncation to ensure maximum sequence length refers to the maximum number of
# amino acids in sequence rather than number of n-grams. Eg, Sequence "ABCDEFGHIJ" can form 8 trigrams,
# if max length is 5, then only first 3 trigrams should be considered as they are formed by first 5 letters.
truncate_index = self.max_sequence_length - (self.reader.n_gram - 1)

for instance in dataloader.dataset:
instance["features"] = instance["features"][: self.max_sequence_length]
instance["features"] = instance["features"][:truncate_index]

return dataloader

# ------------------------------ Phase: Raw Properties -----------------------------------
Expand All @@ -563,16 +579,6 @@ def base_dir(self) -> str:
"""
return os.path.join("data", f"GO_UniProt")

@property
def identifier(self) -> tuple:
"""Identifier for the dataset."""
# overriding identifier instead of reader.name to keep same tokens.txt file, but different processed_dir folder
if not isinstance(self.reader, dr.ProteinDataReader):
raise ValueError("Need Protein DataReader for identifier")
if self.reader.n_gram is not None:
return (f"{self.reader.name()}_{self.reader.n_gram}_gram",)
return (self.reader.name(),)

@property
def raw_file_names_dict(self) -> dict:
"""
Expand Down
7 changes: 5 additions & 2 deletions chebai/preprocessing/datasets/tox21.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def download(self) -> None:
def setup_processed(self) -> None:
"""Processes and splits the dataset."""
print("Create splits")
data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))
data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv")))
groups = np.array([d["group"] for d in data])
if not all(g is None for g in groups):
split_size = int(len(set(groups)) * self.train_split)
Expand Down Expand Up @@ -145,7 +145,10 @@ def _load_data_from_file(self, input_file_path: str) -> List[Dict]:
labels = [
bool(int(l)) if l else None for l in (row[k] for k in self.HEADERS)
]
yield dict(features=smiles, labels=labels, ident=row["mol_id"])
group = row.get("group", None)
yield dict(
features=smiles, labels=labels, ident=row["mol_id"], group=group
)


class Tox21Challenge(XYBaseDataModule):
Expand Down
6 changes: 4 additions & 2 deletions chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,16 @@ class ProteinDataReader(DataReader):
"V",
]

@classmethod
def name(cls) -> str:
def name(self) -> str:
"""
Returns the name of the data reader. This method identifies the specific type of data reader.

Returns:
str: The name of the data reader, which is "protein_token".
"""
if self.n_gram is not None:
return f"protein_token_{self.n_gram}_gram"

return "protein_token"

def __init__(self, *args, n_gram: Optional[int] = None, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
This directory contains integration tests that cover the overall behavior of the data preprocessing tool.
"""
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 4 additions & 0 deletions tests/unit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
This directory contains unit tests, which focus on individual functions and methods, ensuring they work as
expected in isolation.
"""
Empty file.
65 changes: 65 additions & 0 deletions tests/unit/collators/testDefaultCollator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
from typing import Dict, List

from chebai.preprocessing.collate import DefaultCollator
from chebai.preprocessing.structures import XYData


class TestDefaultCollator(unittest.TestCase):
    """Unit tests covering the collation behaviour of DefaultCollator."""

    @classmethod
    def setUpClass(cls) -> None:
        """Create one DefaultCollator instance shared across all test methods."""
        cls.collator = DefaultCollator()

    def test_call_with_valid_data(self) -> None:
        """Collating two well-formed samples should produce XYData whose
        x/y hold the features and labels in batch order."""
        batch: List[Dict] = [
            {"features": [1.0, 2.0], "labels": [True, False, True]},
            {"features": [3.0, 4.0], "labels": [False, False, True]},
        ]

        collated: XYData = self.collator(batch)

        self.assertIsInstance(
            collated, XYData, "The result should be an instance of XYData."
        )
        self.assertEqual(
            collated.x,
            ([1.0, 2.0], [3.0, 4.0]),
            "The feature data 'x' does not match the expected output.",
        )
        self.assertEqual(
            collated.y,
            ([True, False, True], [False, False, True]),
            "The label data 'y' does not match the expected output.",
        )

    def test_call_with_empty_data(self) -> None:
        """An empty batch has nothing to unpack, so the collator is expected
        to raise ValueError with the standard unpacking message."""
        empty_batch: List[Dict] = []

        with self.assertRaises(ValueError) as context:
            self.collator(empty_batch)

        self.assertEqual(
            str(context.exception),
            "not enough values to unpack (expected 2, got 0)",
            "The exception message for empty data is not as expected.",
        )


if __name__ == "__main__":
    unittest.main()
204 changes: 204 additions & 0 deletions tests/unit/collators/testRaggedCollator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import unittest
from typing import Dict, List, Tuple

import torch

from chebai.preprocessing.collate import RaggedCollator
from chebai.preprocessing.structures import XYData


class TestRaggedCollator(unittest.TestCase):
    """
    Unit tests for the RaggedCollator class, which pads ragged (variable-length)
    feature and label sequences into dense tensors with masks and lengths.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """
        Set up the test environment by initializing a RaggedCollator instance.
        """
        cls.collator = RaggedCollator()

    def test_call_with_valid_data(self) -> None:
        """
        Test the __call__ method with valid ragged data to ensure features, labels, and masks are correctly handled.
        """
        data: List[Dict] = [
            {"features": [1, 2], "labels": [True, False], "ident": "sample1"},
            {"features": [3, 4, 5], "labels": [False, True, True], "ident": "sample2"},
            {"features": [6], "labels": [True], "ident": "sample3"},
        ]

        result: XYData = self.collator(data)

        # Features are right-padded with 0 to the longest sequence (length 3).
        expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]])
        # Labels are right-padded with False to the longest label row.
        expected_y = torch.tensor(
            [[True, False, False], [False, True, True], [True, False, False]]
        )
        # Mask marks real (non-padding) feature positions.
        expected_mask_for_x = torch.tensor(
            [[True, True, False], [True, True, True], [True, False, False]]
        )
        expected_lens_for_x = torch.tensor([2, 3, 1])

        self.assertTrue(
            torch.equal(result.x, expected_x),
            "The feature tensor 'x' does not match the expected output.",
        )
        self.assertTrue(
            torch.equal(result.y, expected_y),
            "The label tensor 'y' does not match the expected output.",
        )
        self.assertTrue(
            torch.equal(
                result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x
            ),
            "The mask tensor does not match the expected output.",
        )
        self.assertTrue(
            torch.equal(
                result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x
            ),
            "The lens tensor does not match the expected output.",
        )
        self.assertEqual(
            result.additional_fields["idents"],
            ("sample1", "sample2", "sample3"),
            "The identifiers do not match the expected output.",
        )

    def test_call_with_missing_entire_labels(self) -> None:
        """
        Test the __call__ method with data where some samples are missing labels.

        Samples with `labels=None` keep their features (and masks/lens) but are
        dropped from the label tensor; their indices are reported via
        `loss_kwargs["non_null_labels"]`.
        """
        data: List[Dict] = [
            {"features": [1, 2], "labels": [True, False], "ident": "sample1"},
            {"features": [3, 4, 5], "labels": None, "ident": "sample2"},
            {"features": [6], "labels": [True], "ident": "sample3"},
        ]

        result: XYData = self.collator(data)

        # https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829
        expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]])
        expected_y = torch.tensor(
            [[True, False], [True, False]]
        )  # True -> 1, False -> 0
        expected_mask_for_x = torch.tensor(
            [[True, True, False], [True, True, True], [True, False, False]]
        )
        expected_lens_for_x = torch.tensor([2, 3, 1])

        self.assertTrue(
            torch.equal(result.x, expected_x),
            "The feature tensor 'x' does not match the expected output when labels are missing.",
        )
        self.assertTrue(
            torch.equal(result.y, expected_y),
            "The label tensor 'y' does not match the expected output when labels are missing.",
        )
        self.assertTrue(
            torch.equal(
                result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x
            ),
            "The mask tensor does not match the expected output when labels are missing.",
        )
        self.assertTrue(
            torch.equal(
                result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x
            ),
            "The lens tensor does not match the expected output when labels are missing.",
        )
        self.assertEqual(
            result.additional_fields["loss_kwargs"]["non_null_labels"],
            [0, 2],
            "The non-null labels list does not match the expected output.",
        )
        # Each labelled sample contributes one ROW to y, so the count of
        # non-null samples must equal the batch (row) dimension y.shape[0].
        # (Previously compared against y.shape[1], the padded label width,
        # which was only coincidentally equal to 2 for this fixture.)
        self.assertEqual(
            len(result.additional_fields["loss_kwargs"]["non_null_labels"]),
            result.y.shape[0],
            "The length of the non-null labels list must match the number of label rows in y",
        )
        self.assertEqual(
            result.additional_fields["idents"],
            ("sample1", "sample2", "sample3"),
            "The identifiers do not match the expected output when labels are missing.",
        )

    def test_call_with_none_in_labels(self) -> None:
        """
        Test the __call__ method with data where one of the elements in the labels is None.

        A None element inside a label row is coerced to False; the sample itself
        is kept in the label tensor.
        """
        data: List[Dict] = [
            {"features": [1, 2], "labels": [None, True], "ident": "sample1"},
            {"features": [3, 4, 5], "labels": [True, False], "ident": "sample2"},
            {"features": [6], "labels": [True], "ident": "sample3"},
        ]

        result: XYData = self.collator(data)

        expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]])
        expected_y = torch.tensor(
            [[False, True], [True, False], [True, False]]
        )  # None -> False
        expected_mask_for_x = torch.tensor(
            [[True, True, False], [True, True, True], [True, False, False]]
        )
        expected_lens_for_x = torch.tensor([2, 3, 1])

        self.assertTrue(
            torch.equal(result.x, expected_x),
            "The feature tensor 'x' does not match the expected output when labels contain None.",
        )
        self.assertTrue(
            torch.equal(result.y, expected_y),
            "The label tensor 'y' does not match the expected output when labels contain None.",
        )
        self.assertTrue(
            torch.equal(
                result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x
            ),
            "The mask tensor does not match the expected output when labels contain None.",
        )
        self.assertTrue(
            torch.equal(
                result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x
            ),
            "The lens tensor does not match the expected output when labels contain None.",
        )
        self.assertEqual(
            result.additional_fields["idents"],
            ("sample1", "sample2", "sample3"),
            "The identifiers do not match the expected output when labels contain None.",
        )

    def test_call_with_empty_data(self) -> None:
        """
        Test the __call__ method with an empty list to ensure it raises an error.
        """
        data: List[Dict] = []

        with self.assertRaises(
            Exception, msg="Expected an Error when no data is provided"
        ):
            self.collator(data)

    def test_process_label_rows(self) -> None:
        """
        Test the process_label_rows method to ensure it pads label sequences correctly.
        """
        labels: Tuple = ([True, False], [False, True, True], [True])

        result: torch.Tensor = self.collator.process_label_rows(labels)

        expected_output = torch.tensor(
            [[True, False, False], [False, True, True], [True, False, False]]
        )

        self.assertTrue(
            torch.equal(result, expected_output),
            "The processed label rows tensor does not match the expected output.",
        )


if __name__ == "__main__":
    unittest.main()
Empty file.
Loading
Loading