diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..caa8759f --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,3 @@ +""" +This directory contains integration tests that cover the overall behavior of the data preprocessing tool. +""" diff --git a/tests/testChebiData.py b/tests/integration/testChebiData.py similarity index 100% rename from tests/testChebiData.py rename to tests/integration/testChebiData.py diff --git a/tests/testChebiDynamicDataSplits.py b/tests/integration/testChebiDynamicDataSplits.py similarity index 100% rename from tests/testChebiDynamicDataSplits.py rename to tests/integration/testChebiDynamicDataSplits.py diff --git a/tests/testCustomBalancedAccuracyMetric.py b/tests/integration/testCustomBalancedAccuracyMetric.py similarity index 100% rename from tests/testCustomBalancedAccuracyMetric.py rename to tests/integration/testCustomBalancedAccuracyMetric.py diff --git a/tests/testCustomMacroF1Metric.py b/tests/integration/testCustomMacroF1Metric.py similarity index 100% rename from tests/testCustomMacroF1Metric.py rename to tests/integration/testCustomMacroF1Metric.py diff --git a/tests/testPubChemData.py b/tests/integration/testPubChemData.py similarity index 100% rename from tests/testPubChemData.py rename to tests/integration/testPubChemData.py diff --git a/tests/testTox21MolNetData.py b/tests/integration/testTox21MolNetData.py similarity index 100% rename from tests/testTox21MolNetData.py rename to tests/integration/testTox21MolNetData.py diff --git a/tests/test_data/ChEBIOver100_test/labels000.pt b/tests/integration/test_data/ChEBIOver100_test/labels000.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels000.pt rename to tests/integration/test_data/ChEBIOver100_test/labels000.pt diff --git a/tests/test_data/ChEBIOver100_test/labels001.pt b/tests/integration/test_data/ChEBIOver100_test/labels001.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels001.pt rename to tests/integration/test_data/ChEBIOver100_test/labels001.pt diff --git a/tests/test_data/ChEBIOver100_test/labels002.pt b/tests/integration/test_data/ChEBIOver100_test/labels002.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels002.pt rename to tests/integration/test_data/ChEBIOver100_test/labels002.pt diff --git a/tests/test_data/ChEBIOver100_test/labels003.pt b/tests/integration/test_data/ChEBIOver100_test/labels003.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels003.pt rename to tests/integration/test_data/ChEBIOver100_test/labels003.pt diff --git a/tests/test_data/ChEBIOver100_test/labels004.pt b/tests/integration/test_data/ChEBIOver100_test/labels004.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels004.pt rename to tests/integration/test_data/ChEBIOver100_test/labels004.pt diff --git a/tests/test_data/ChEBIOver100_test/labels005.pt b/tests/integration/test_data/ChEBIOver100_test/labels005.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels005.pt rename to tests/integration/test_data/ChEBIOver100_test/labels005.pt diff --git a/tests/test_data/ChEBIOver100_test/labels006.pt b/tests/integration/test_data/ChEBIOver100_test/labels006.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels006.pt rename to tests/integration/test_data/ChEBIOver100_test/labels006.pt diff --git a/tests/test_data/ChEBIOver100_test/labels007.pt b/tests/integration/test_data/ChEBIOver100_test/labels007.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels007.pt rename to tests/integration/test_data/ChEBIOver100_test/labels007.pt diff --git a/tests/test_data/ChEBIOver100_test/labels008.pt b/tests/integration/test_data/ChEBIOver100_test/labels008.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels008.pt rename to tests/integration/test_data/ChEBIOver100_test/labels008.pt diff --git a/tests/test_data/ChEBIOver100_test/labels009.pt b/tests/integration/test_data/ChEBIOver100_test/labels009.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels009.pt rename to tests/integration/test_data/ChEBIOver100_test/labels009.pt diff --git a/tests/test_data/ChEBIOver100_test/labels010.pt b/tests/integration/test_data/ChEBIOver100_test/labels010.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels010.pt rename to tests/integration/test_data/ChEBIOver100_test/labels010.pt diff --git a/tests/test_data/ChEBIOver100_test/labels011.pt b/tests/integration/test_data/ChEBIOver100_test/labels011.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels011.pt rename to tests/integration/test_data/ChEBIOver100_test/labels011.pt diff --git a/tests/test_data/ChEBIOver100_test/labels012.pt b/tests/integration/test_data/ChEBIOver100_test/labels012.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels012.pt rename to tests/integration/test_data/ChEBIOver100_test/labels012.pt diff --git a/tests/test_data/ChEBIOver100_test/labels013.pt b/tests/integration/test_data/ChEBIOver100_test/labels013.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels013.pt rename to tests/integration/test_data/ChEBIOver100_test/labels013.pt diff --git a/tests/test_data/ChEBIOver100_test/labels014.pt b/tests/integration/test_data/ChEBIOver100_test/labels014.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels014.pt rename to tests/integration/test_data/ChEBIOver100_test/labels014.pt diff --git a/tests/test_data/ChEBIOver100_test/labels015.pt b/tests/integration/test_data/ChEBIOver100_test/labels015.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels015.pt rename to tests/integration/test_data/ChEBIOver100_test/labels015.pt diff --git a/tests/test_data/ChEBIOver100_test/labels016.pt b/tests/integration/test_data/ChEBIOver100_test/labels016.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels016.pt rename to tests/integration/test_data/ChEBIOver100_test/labels016.pt diff --git a/tests/test_data/ChEBIOver100_test/labels017.pt b/tests/integration/test_data/ChEBIOver100_test/labels017.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels017.pt rename to tests/integration/test_data/ChEBIOver100_test/labels017.pt diff --git a/tests/test_data/ChEBIOver100_test/labels018.pt b/tests/integration/test_data/ChEBIOver100_test/labels018.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels018.pt rename to tests/integration/test_data/ChEBIOver100_test/labels018.pt diff --git a/tests/test_data/ChEBIOver100_test/labels019.pt b/tests/integration/test_data/ChEBIOver100_test/labels019.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/labels019.pt rename to tests/integration/test_data/ChEBIOver100_test/labels019.pt diff --git a/tests/test_data/ChEBIOver100_test/preds000.pt b/tests/integration/test_data/ChEBIOver100_test/preds000.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds000.pt rename to tests/integration/test_data/ChEBIOver100_test/preds000.pt diff --git a/tests/test_data/ChEBIOver100_test/preds001.pt b/tests/integration/test_data/ChEBIOver100_test/preds001.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds001.pt rename to tests/integration/test_data/ChEBIOver100_test/preds001.pt diff --git a/tests/test_data/ChEBIOver100_test/preds002.pt b/tests/integration/test_data/ChEBIOver100_test/preds002.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds002.pt rename to tests/integration/test_data/ChEBIOver100_test/preds002.pt diff --git a/tests/test_data/ChEBIOver100_test/preds003.pt b/tests/integration/test_data/ChEBIOver100_test/preds003.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds003.pt rename to tests/integration/test_data/ChEBIOver100_test/preds003.pt diff --git a/tests/test_data/ChEBIOver100_test/preds004.pt b/tests/integration/test_data/ChEBIOver100_test/preds004.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds004.pt rename to tests/integration/test_data/ChEBIOver100_test/preds004.pt diff --git a/tests/test_data/ChEBIOver100_test/preds005.pt b/tests/integration/test_data/ChEBIOver100_test/preds005.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds005.pt rename to tests/integration/test_data/ChEBIOver100_test/preds005.pt diff --git a/tests/test_data/ChEBIOver100_test/preds006.pt b/tests/integration/test_data/ChEBIOver100_test/preds006.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds006.pt rename to tests/integration/test_data/ChEBIOver100_test/preds006.pt diff --git a/tests/test_data/ChEBIOver100_test/preds007.pt b/tests/integration/test_data/ChEBIOver100_test/preds007.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds007.pt rename to tests/integration/test_data/ChEBIOver100_test/preds007.pt diff --git a/tests/test_data/ChEBIOver100_test/preds008.pt b/tests/integration/test_data/ChEBIOver100_test/preds008.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds008.pt rename to tests/integration/test_data/ChEBIOver100_test/preds008.pt diff --git a/tests/test_data/ChEBIOver100_test/preds009.pt b/tests/integration/test_data/ChEBIOver100_test/preds009.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds009.pt rename to tests/integration/test_data/ChEBIOver100_test/preds009.pt diff --git a/tests/test_data/ChEBIOver100_test/preds010.pt b/tests/integration/test_data/ChEBIOver100_test/preds010.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds010.pt rename to tests/integration/test_data/ChEBIOver100_test/preds010.pt diff --git a/tests/test_data/ChEBIOver100_test/preds011.pt b/tests/integration/test_data/ChEBIOver100_test/preds011.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds011.pt rename to tests/integration/test_data/ChEBIOver100_test/preds011.pt diff --git a/tests/test_data/ChEBIOver100_test/preds012.pt b/tests/integration/test_data/ChEBIOver100_test/preds012.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds012.pt rename to tests/integration/test_data/ChEBIOver100_test/preds012.pt diff --git a/tests/test_data/ChEBIOver100_test/preds013.pt b/tests/integration/test_data/ChEBIOver100_test/preds013.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds013.pt rename to tests/integration/test_data/ChEBIOver100_test/preds013.pt diff --git a/tests/test_data/ChEBIOver100_test/preds014.pt b/tests/integration/test_data/ChEBIOver100_test/preds014.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds014.pt rename to tests/integration/test_data/ChEBIOver100_test/preds014.pt diff --git a/tests/test_data/ChEBIOver100_test/preds015.pt b/tests/integration/test_data/ChEBIOver100_test/preds015.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds015.pt rename to tests/integration/test_data/ChEBIOver100_test/preds015.pt diff --git a/tests/test_data/ChEBIOver100_test/preds016.pt b/tests/integration/test_data/ChEBIOver100_test/preds016.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds016.pt rename to tests/integration/test_data/ChEBIOver100_test/preds016.pt diff --git a/tests/test_data/ChEBIOver100_test/preds017.pt b/tests/integration/test_data/ChEBIOver100_test/preds017.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds017.pt rename to tests/integration/test_data/ChEBIOver100_test/preds017.pt diff --git a/tests/test_data/ChEBIOver100_test/preds018.pt b/tests/integration/test_data/ChEBIOver100_test/preds018.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds018.pt rename to tests/integration/test_data/ChEBIOver100_test/preds018.pt diff --git a/tests/test_data/ChEBIOver100_test/preds019.pt b/tests/integration/test_data/ChEBIOver100_test/preds019.pt similarity index 100% rename from tests/test_data/ChEBIOver100_test/preds019.pt rename to tests/integration/test_data/ChEBIOver100_test/preds019.pt diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..6640a696 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,4 @@ +""" +This directory contains unit tests, which focus on individual functions and methods, ensuring they work as +expected in isolation. +""" diff --git a/tests/unit/collators/__init__.py b/tests/unit/collators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py new file mode 100644 index 00000000..73f09c75 --- /dev/null +++ b/tests/unit/collators/testDefaultCollator.py @@ -0,0 +1,65 @@ +import unittest +from typing import Dict, List + +from chebai.preprocessing.collate import DefaultCollator +from chebai.preprocessing.structures import XYData + + +class TestDefaultCollator(unittest.TestCase): + """ + Unit tests for the DefaultCollator class. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test environment by initializing a DefaultCollator instance. + """ + cls.collator = DefaultCollator() + + def test_call_with_valid_data(self) -> None: + """ + Test the __call__ method with valid data to ensure features and labels are correctly extracted. + """ + data: List[Dict] = [ + {"features": [1.0, 2.0], "labels": [True, False, True]}, + {"features": [3.0, 4.0], "labels": [False, False, True]}, + ] + + result: XYData = self.collator(data) + self.assertIsInstance( + result, XYData, "The result should be an instance of XYData." + ) + + expected_x = ([1.0, 2.0], [3.0, 4.0]) + expected_y = ([True, False, True], [False, False, True]) + + self.assertEqual( + result.x, + expected_x, + "The feature data 'x' does not match the expected output.", + ) + self.assertEqual( + result.y, + expected_y, + "The label data 'y' does not match the expected output.", + ) + + def test_call_with_empty_data(self) -> None: + """ + Test the __call__ method with an empty list to ensure it handles the edge case correctly. + """ + data: List[Dict] = [] + + with self.assertRaises(ValueError) as context: + self.collator(data) + + self.assertEqual( + str(context.exception), + "not enough values to unpack (expected 2, got 0)", + "The exception message for empty data is not as expected.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py new file mode 100644 index 00000000..d9ab2b1d --- /dev/null +++ b/tests/unit/collators/testRaggedCollator.py @@ -0,0 +1,204 @@ +import unittest +from typing import Dict, List, Tuple + +import torch + +from chebai.preprocessing.collate import RaggedCollator +from chebai.preprocessing.structures import XYData + + +class TestRaggedCollator(unittest.TestCase): + """ + Unit tests for the RaggedCollator class. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test environment by initializing a RaggedCollator instance. + """ + cls.collator = RaggedCollator() + + def test_call_with_valid_data(self) -> None: + """ + Test the __call__ method with valid ragged data to ensure features, labels, and masks are correctly handled. + """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [False, True, True], "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor( + [[True, False, False], [False, True, True], [True, False, False]] + ) + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ), + "The mask tensor does not match the expected output.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ), + "The lens tensor does not match the expected output.", + ) + self.assertEqual( + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output.", + ) + + def test_call_with_missing_entire_labels(self) -> None: + """ + Test the __call__ method with data where some samples are missing labels. + """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": None, "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + # https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor( + [[True, False], [True, False]] + ) # True -> 1, False -> 0 + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ), + "The mask tensor does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ), + "The lens tensor does not match the expected output when labels are missing.", + ) + self.assertEqual( + result.additional_fields["loss_kwargs"]["non_null_labels"], + [0, 2], + "The non-null labels list does not match the expected output.", + ) + self.assertEqual( + len(result.additional_fields["loss_kwargs"]["non_null_labels"]), + result.y.shape[1], + "The length of non null labels list must match with target label variable size", + ) + self.assertEqual( + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output when labels are missing.", + ) + + def test_call_with_none_in_labels(self) -> None: + """ + Test the __call__ method with data where one of the elements in the labels is None. + """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [None, True], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [True, False], "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor( + [[False, True], [True, False], [True, False]] + ) # None -> False + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ), + "The mask tensor does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ), + "The lens tensor does not match the expected output when labels contain None.", + ) + self.assertEqual( + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output when labels contain None.", + ) + + def test_call_with_empty_data(self) -> None: + """ + Test the __call__ method with an empty list to ensure it raises an error. + """ + data: List[Dict] = [] + + with self.assertRaises( + Exception, msg="Expected an Error when no data is provided" + ): + self.collator(data) + + def test_process_label_rows(self) -> None: + """ + Test the process_label_rows method to ensure it pads label sequences correctly. + """ + labels: Tuple = ([True, False], [False, True, True], [True]) + + result: torch.Tensor = self.collator.process_label_rows(labels) + + expected_output = torch.tensor( + [[True, False, False], [False, True, True], [True, False, False]] + ) + + self.assertTrue( + torch.equal(result, expected_output), + "The processed label rows tensor does not match the expected output.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/__init__.py b/tests/unit/dataset_classes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/dataset_classes/testChEBIOverX.py b/tests/unit/dataset_classes/testChEBIOverX.py new file mode 100644 index 00000000..270b868c --- /dev/null +++ b/tests/unit/dataset_classes/testChEBIOverX.py @@ -0,0 +1,125 @@ +import unittest +from unittest.mock import PropertyMock, mock_open, patch + +from chebai.preprocessing.datasets.chebi import ChEBIOverX +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChEBIOverX(unittest.TestCase): + @classmethod + @patch.multiple(ChEBIOverX, __abstractmethods__=frozenset()) + @patch.object(ChEBIOverX, "processed_dir_main", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs, mock_processed_dir_main: PropertyMock) -> None: + """ + Set up the ChEBIOverX instance with a mock processed directory path and a test graph. + + Args: + mock_makedirs: This patches os.makedirs to do nothing + mock_processed_dir_main (PropertyMock): Mocked property for the processed directory path. + """ + mock_processed_dir_main.return_value = "/mock/processed_dir" + cls.chebi_extractor = ChEBIOverX(chebi_version=231) + cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() + + @patch("builtins.open", new_callable=mock_open) + def test_select_classes(self, mock_open_file: mock_open) -> None: + """ + Test the select_classes method to ensure it correctly selects nodes based on the threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.chebi_extractor.THRESHOLD = 3 + selected_classes = self.chebi_extractor.select_classes(self.test_graph) + + # Check if the returned selected classes match the expected list + expected_classes = sorted([11111, 22222, 67890]) + self.assertListEqual( + selected_classes, + expected_classes, + "The selected classes do not match the expected output for the given threshold of 3.", + ) + + # Expected data as string + expected_lines = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + "The written lines do not match the expected lines for the given threshold of 3.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_no_classes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the select_classes method when no nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.chebi_extractor.THRESHOLD = 5 + selected_classes = self.chebi_extractor.select_classes(self.test_graph) + + # Expected empty result + self.assertEqual( + selected_classes, + [], + "The selected classes list should be empty when no nodes meet the threshold of 5.", + ) + + # Expected data as string + expected_lines = "" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + "The written lines do not match the expected lines when no nodes meet the threshold of 5.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_all_nodes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the select_classes method when all nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.chebi_extractor.THRESHOLD = 0 + selected_classes = self.chebi_extractor.select_classes(self.test_graph) + + expected_classes = sorted(ChebiMockOntology.get_nodes()) + # Check if the returned selected classes match the expected list + self.assertListEqual( + selected_classes, + expected_classes, + "The selected classes do not match the expected output when all nodes meet the threshold of 0.", + ) + + # Expected data as string + expected_lines = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + "The written lines do not match the expected lines when all nodes meet the threshold of 0.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testChebiDataExtractor.py b/tests/unit/dataset_classes/testChebiDataExtractor.py new file mode 100644 index 00000000..8da900da --- /dev/null +++ b/tests/unit/dataset_classes/testChebiDataExtractor.py @@ -0,0 +1,228 @@ +import unittest +from unittest.mock import MagicMock, PropertyMock, mock_open, patch + +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChEBIDataExtractor(unittest.TestCase): + + @classmethod + @patch.multiple(_ChEBIDataExtractor, __abstractmethods__=frozenset()) + @patch.object(_ChEBIDataExtractor, "base_dir", new_callable=PropertyMock) + @patch.object(_ChEBIDataExtractor, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, + ) -> None: + """ + Set up a base instance of _ChEBIDataExtractor for testing with mocked properties. + """ + # Mocking properties + mock_base_dir_property.return_value = "MockedBaseDirPropertyChebiDataExtractor" + mock_name_property.return_value = "MockedNamePropertyChebiDataExtractor" + + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + _ChEBIDataExtractor.READER = ReaderMock + + # Create an instance of the dataset + cls.extractor: _ChEBIDataExtractor = _ChEBIDataExtractor( + chebi_version=231, chebi_version_train=200 + ) + + # Mock instance for _chebi_version_train_obj + mock_train_obj = MagicMock() + mock_train_obj.processed_dir_main = "/mock/path/to/train" + cls.extractor._chebi_version_train_obj = mock_train_obj + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: + """ + Test the extraction of class hierarchy and validate the structure of the resulting graph. + """ + # Mock the output of fastobo.loads + graph = self.extractor._extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." + ) + + # Check nodes + actual_nodes = set(graph.nodes) + self.assertEqual( + set(ChebiMockOntology.get_nodes()), + actual_nodes, + "The graph nodes do not match the expected nodes.", + ) + + # Check edges + actual_edges = set(graph.edges) + self.assertEqual( + ChebiMockOntology.get_edges_of_transitive_closure_graph(), + actual_edges, + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + ChebiMockOntology.get_number_of_nodes(), + len(actual_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + ChebiMockOntology.get_number_of_transitive_edges(), + len(actual_edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + @patch.object( + _ChEBIDataExtractor, + "select_classes", + return_value=ChebiMockOntology.get_nodes(), + ) + def test_graph_to_raw_dataset( + self, mock_select_classes: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test conversion of a graph to a raw dataset and compare it with the expected DataFrame. + """ + graph = self.extractor._extract_class_hierarchy("fake_path") + data_df = self.extractor._graph_to_raw_dataset(graph) + + pd.testing.assert_frame_equal( + data_df, + ChebiMockOntology.get_data_in_dataframe(), + obj="The DataFrame should match the expected structure.", + ) + + @patch( + "builtins.open", new_callable=mock_open, read_data=b"Mocktestdata" + ) # Mocking open as a binary file + @patch("pandas.read_pickle") + def test_load_dict( + self, mock_read_pickle: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test loading data from a pickled file and verify the generator output. + """ + # Mock the DataFrame returned by read_pickle + mock_df = pd.DataFrame( + { + "id": [12345, 67890, 11111, 54321], # Corrected ID + "name": ["A", "B", "C", "D"], + "SMILES": ["C1CCCCC1", "O=C=O", "C1CC=CC1", "C[Mg+]"], + 12345: [True, False, False, True], + 67890: [False, True, True, False], + 11111: [True, False, True, False], + } + ) + mock_read_pickle.return_value = mock_df + + generator = self.extractor._load_dict("data/tests") + result = list(generator) + + # Convert NumPy arrays to lists for comparison + for item in result: + item["labels"] = list(item["labels"]) + + # Expected output for comparison + expected_result = [ + {"features": "C1CCCCC1", "labels": [True, False, True], "ident": 12345}, + {"features": "O=C=O", "labels": [False, True, False], "ident": 67890}, + {"features": "C1CC=CC1", "labels": [False, True, True], "ident": 11111}, + {"features": "C[Mg+]", "labels": [True, False, False], "ident": 54321}, + ] + + # Assert if the result matches the expected output + self.assertEqual( + result, + expected_result, + "The loaded dictionary should match the expected structure.", + ) + + @patch("builtins.open", new_callable=mock_open) + @patch.object(_ChEBIDataExtractor, "processed_dir_main", new_callable=PropertyMock) + def test_setup_pruned_test_set( + self, + mock_processed_dir_main: PropertyMock, + mock_open_file: mock_open, + ) -> None: + """ + Test the pruning of the test set to match classes in the training set. + """ + # Mock the content for the two open calls (original classes and new classes) + mock_orig_classes = "12345\n67890\n88888\n54321\n77777\n" + mock_new_classes = "12345\n67890\n99999\n77777\n" + + # Use side_effect to simulate the two different file reads + mock_open_file.side_effect = [ + mock_open( + read_data=mock_orig_classes + ).return_value, # First open() for orig_classes + mock_open( + read_data=mock_new_classes + ).return_value, # Second open() for new_classes + ] + + # Mock the attributes used in the method + mock_processed_dir_main.return_value = "/mock/path/to/current" + + # Mock DataFrame to simulate the test dataset + mock_df = pd.DataFrame( + { + "labels": [ + [ + True, + False, + True, + False, + True, + ], # First test instance labels (match orig_classes) + [False, True, False, True, False], + ] # Second test instance labels + } + ) + + # Call the method under test + pruned_df = self.extractor._setup_pruned_test_set(mock_df) + + # Expected DataFrame labels after pruning (only "12345", "67890", "77777", and "99999" remain) + expected_labels = [[True, False, False, True], [False, True, False, False]] + + # Check if the pruned DataFrame still has the same number of rows + self.assertEqual( + len(pruned_df), + len(mock_df), + "The pruned DataFrame should have the same number of rows.", + ) + + # Check that the labels are correctly pruned + for i in range(len(pruned_df)): + self.assertEqual( + pruned_df.iloc[i]["labels"], + expected_labels[i], + f"Row {i}'s labels should be pruned correctly.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py new file mode 100644 index 00000000..76584ebf --- /dev/null +++ b/tests/unit/dataset_classes/testChebiOverXPartial.py @@ -0,0 +1,175 @@ +import unittest +from unittest.mock import mock_open, patch + +import networkx as nx + +from chebai.preprocessing.datasets.chebi import ChEBIOverXPartial +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChEBIOverX(unittest.TestCase): + + @classmethod + @patch.multiple(ChEBIOverXPartial, __abstractmethods__=frozenset()) + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: + """ + Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph. + """ + cls.chebi_extractor = ChEBIOverXPartial(top_class_id=11111, chebi_version=231) + cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: + """ + Test the extraction of class hierarchy and validate the structure of the resulting graph. + """ + # Mock the output of fastobo.loads + self.chebi_extractor.top_class_id = 11111 + graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." + ) + + # Check nodes + expected_nodes = {11111, 54321, 12345, 99999} + expected_edges = { + (54321, 12345), + (54321, 99999), + (11111, 54321), + (11111, 12345), + (11111, 99999), + (12345, 99999), + } + self.assertEqual( + set(graph.nodes), + expected_nodes, + f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + ) + + # Check edges + self.assertEqual( + expected_edges, + set(graph.edges), + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + len(graph.nodes), + len(expected_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + len(expected_edges), + len(graph.edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + self.chebi_extractor.top_class_id = 22222 + graph = self.chebi_extractor._extract_class_hierarchy("fake_path") + + # Check nodes with top class as 22222 + self.assertEqual( + set(graph.nodes), + {67890, 88888, 12345, 99999, 22222}, + f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_extract_class_hierarchy_with_bottom_cls( + self, mock_open: mock_open + ) -> None: + """ + Test the extraction of class hierarchy and validate the structure of the resulting graph. + """ + self.chebi_extractor.top_class_id = 88888 + graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") + + # Check nodes with top class as 88888 + self.assertEqual( + set(graph.nodes), + {self.chebi_extractor.top_class_id}, + f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + ) + + @patch("pandas.DataFrame.to_csv") + @patch("pandas.read_pickle") + @patch.object(ChEBIOverXPartial, "_get_data_size", return_value=4.0) + @patch("torch.load") + @patch( + "builtins.open", + new_callable=mock_open, + read_data=ChebiMockOntology.get_raw_data(), + ) + def test_single_label_data_split( + self, mock_open, mock_load, mock_get_data_size, mock_read_pickle, mock_to_csv + ) -> None: + """ + Test the single-label data splitting functionality of the ChebiExtractor class. + + This test mocks several key methods (file operations, torch loading, and pandas functions) + to ensure that the class hierarchy is properly extracted, data is processed into a raw dataset, + and the data splitting logic works as intended without actual file I/O. + + It also verifies that there is no overlap between training, validation, and test sets. + """ + self.chebi_extractor.top_class_id = 11111 + self.chebi_extractor.THRESHOLD = 3 + self.chebi_extractor.chebi_version_train = None + + graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") + data_df = self.chebi_extractor._graph_to_raw_dataset(graph) + + mock_read_pickle.return_value = data_df + data_pt = self.chebi_extractor._load_data_from_file("fake/path") + + # Verify that the data contains only 1 label + self.assertEqual(len(data_pt[0]["labels"]), 1) + + mock_load.return_value = data_pt + + # Retrieve the data splits (train, validation, and test) + train_split = self.chebi_extractor.dynamic_split_dfs["train"] + validation_split = self.chebi_extractor.dynamic_split_dfs["validation"] + test_split = self.chebi_extractor.dynamic_split_dfs["test"] + + train_idents = set(train_split["ident"]) + val_idents = set(validation_split["ident"]) + test_idents = set(test_split["ident"]) + + # Ensure there is no overlap between train and test sets + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + # Ensure there is no overlap between validation and test sets + self.assertEqual( + len(val_idents.intersection(test_idents)), + 0, + "Validation and test sets should not overlap.", + ) + + # Ensure there is no overlap between train and validation sets + self.assertEqual( + len(train_idents.intersection(val_idents)), + 0, + "Train and validation sets should not overlap.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testChebiTermCallback.py b/tests/unit/dataset_classes/testChebiTermCallback.py new file mode 100644 index 00000000..8680760e --- /dev/null +++ b/tests/unit/dataset_classes/testChebiTermCallback.py @@ -0,0 +1,69 @@ +import unittest +from typing import Any, Dict + +import fastobo +from fastobo.term import TermFrame + +from chebai.preprocessing.datasets.chebi import term_callback +from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology + + +class TestChebiTermCallback(unittest.TestCase): + """ + Unit tests for the `term_callback` function used in processing ChEBI ontology terms. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test class by loading ChEBI term data and storing it in a dictionary + where keys are the term IDs and values are TermFrame instances. + """ + cls.callback_input_data: Dict[int, TermFrame] = { + int(term_doc.id.local): term_doc + for term_doc in fastobo.loads(ChebiMockOntology.get_raw_data()) + if term_doc and ":" in str(term_doc.id) + } + + def test_process_valid_terms(self) -> None: + """ + Test that `term_callback` correctly processes valid ChEBI terms. + """ + + expected_result: Dict[str, Any] = { + "id": 12345, + "parents": [54321, 67890], + "has_part": set(), + "name": "Compound A", + "smiles": "C1=CC=CC=C1", + } + + actual_dict: Dict[str, Any] = term_callback( + self.callback_input_data.get(expected_result["id"]) + ) + self.assertEqual( + expected_result, + actual_dict, + msg="term_callback should correctly extract information from valid ChEBI terms.", + ) + + def test_skip_obsolete_terms(self) -> None: + """ + Test that `term_callback` correctly skips obsolete ChEBI terms. + """ + term_callback_output = [] + for ident in ChebiMockOntology.get_obsolete_nodes_ids(): + raw_term = self.callback_input_data.get(ident) + term_dict = term_callback(raw_term) + if term_dict: + term_callback_output.append(term_dict) + + self.assertEqual( + term_callback_output, + [], + msg="The term_callback function should skip obsolete terms and return an empty list.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py new file mode 100644 index 00000000..c8846273 --- /dev/null +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -0,0 +1,372 @@ +import unittest +from typing import Tuple +from unittest.mock import MagicMock, PropertyMock, patch + +import pandas as pd + +from chebai.preprocessing.datasets.base import _DynamicDataset + + +class TestDynamicDataset(unittest.TestCase): + """ + Test case for _DynamicDataset functionality, ensuring correct data splits and integrity + of train, validation, and test datasets. + """ + + @classmethod + @patch.multiple(_DynamicDataset, __abstractmethods__=frozenset()) + @patch.object(_DynamicDataset, "base_dir", new_callable=PropertyMock) + @patch.object(_DynamicDataset, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_base_dir_property: PropertyMock, + mock_name_property: PropertyMock, + ) -> None: + """ + Set up a base instance of _DynamicDataset for testing with mocked properties. + """ + + # Mocking properties + mock_base_dir_property.return_value = "MockedBaseDirPropertyDynamicDataset" + mock_name_property.return_value = "MockedNamePropertyDynamicDataset" + + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + _DynamicDataset.READER = ReaderMock + + # Creating an instance of the dataset + cls.dataset: _DynamicDataset = _DynamicDataset() + + # Dataset with a balanced distribution of labels + X = [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + [13, 14], + [15, 16], + [17, 18], + [19, 20], + [21, 22], + [23, 24], + [25, 26], + [27, 28], + [29, 30], + [31, 32], + ] + y = [ + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + ] + cls.data_df = pd.DataFrame( + {"ident": [f"id{i + 1}" for i in range(len(X))], "features": X, "labels": y} + ) + + def test_get_test_split_valid(self) -> None: + """ + Test splitting the dataset into train and test sets and verify balance and non-overlap. + """ + self.dataset.train_split = 0.5 + # Test size will be 0.25 * 16 = 4 + train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) + + # Assert the correct number of rows in train and test sets + self.assertEqual(len(train_df), 12, "Train set should contain 12 samples.") + self.assertEqual(len(test_df), 4, "Test set should contain 4 samples.") + + # Check positive and negative label counts in train and test sets + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( + test_df + ) + + # Ensure that the train and test sets have balanced positives and negatives + self.assertEqual( + train_pos_count, train_neg_count, "Train set labels should be balanced." + ) + self.assertEqual( + test_pos_count, test_neg_count, "Test set labels should be balanced." + ) + + # Assert there is no overlap between train and test sets + train_idents = set(train_df["ident"]) + test_idents = set(test_df["ident"]) + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + def test_get_test_split_missing_labels(self) -> None: + """ + Test the behavior when the 'labels' column is missing in the dataset. + """ + df_missing_labels = pd.DataFrame({"ident": ["id1", "id2"]}) + with self.assertRaises( + KeyError, msg="Expected KeyError when 'labels' column is missing." + ): + self.dataset.get_test_split(df_missing_labels) + + def test_get_test_split_seed_consistency(self) -> None: + """ + Test that splitting the dataset with the same seed produces consistent results. + """ + train_df1, test_df1 = self.dataset.get_test_split(self.data_df, seed=42) + train_df2, test_df2 = self.dataset.get_test_split(self.data_df, seed=42) + + pd.testing.assert_frame_equal( + train_df1, + train_df2, + obj="Train sets should be identical for the same seed.", + ) + pd.testing.assert_frame_equal( + test_df1, test_df2, obj="Test sets should be identical for the same seed." + ) + + def test_get_train_val_splits_given_test(self) -> None: + """ + Test splitting the dataset into train and validation sets and verify balance and non-overlap. + """ + self.dataset.use_inner_cross_validation = False + self.dataset.train_split = 0.5 + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) + train_df, val_df = self.dataset.get_train_val_splits_given_test( + df_train_main, test_df, seed=42 + ) + + # Ensure there is no overlap between train and test sets + train_idents = set(train_df["ident"]) + test_idents = set(test_df["ident"]) + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + # Ensure there is no overlap between validation and test sets + val_idents = set(val_df["ident"]) + self.assertEqual( + len(val_idents.intersection(test_idents)), + 0, + "Validation and test sets should not overlap.", + ) + + # Ensure there is no overlap between train and validation sets + self.assertEqual( + len(train_idents.intersection(val_idents)), + 0, + "Train and validation sets should not overlap.", + ) + + # Check positive and negative label counts in train and validation sets + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) + + # Ensure that the train and validation sets have balanced positives and negatives + self.assertEqual( + train_pos_count, train_neg_count, "Train set labels should be balanced." + ) + self.assertEqual( + val_pos_count, val_neg_count, "Validation set labels should be balanced." + ) + + def test_get_train_val_splits_given_test_consistency(self) -> None: + """ + Test that splitting the dataset into train and validation sets with the same seed produces consistent results. + """ + test_df = self.data_df.iloc[12:] # Assume rows 12 onward are for testing + train_df1, val_df1 = self.dataset.get_train_val_splits_given_test( + self.data_df, test_df, seed=42 + ) + train_df2, val_df2 = self.dataset.get_train_val_splits_given_test( + self.data_df, test_df, seed=42 + ) + + pd.testing.assert_frame_equal( + train_df1, + train_df2, + obj="Train sets should be identical for the same seed.", + ) + pd.testing.assert_frame_equal( + val_df1, + val_df2, + obj="Validation sets should be identical for the same seed.", + ) + + def test_get_test_split_stratification(self) -> None: + """ + Test that the split into train and test sets maintains the stratification of labels. + """ + self.dataset.train_split = 0.5 + train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) + + number_of_labels = len(self.data_df["labels"][0]) + + # Check the label distribution in the original dataset + original_pos_count, original_neg_count = ( + self.get_positive_negative_labels_counts(self.data_df) + ) + total_count = len(self.data_df) * number_of_labels + + # Calculate the expected proportions + original_pos_proportion = original_pos_count / total_count + original_neg_proportion = original_neg_count / total_count + + # Check the label distribution in the train set + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + train_total_count = len(train_df) * number_of_labels + + # Calculate the train set proportions + train_pos_proportion = train_pos_count / train_total_count + train_neg_proportion = train_neg_count / train_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + train_pos_proportion, + original_pos_proportion, + places=1, + msg="Train set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + train_neg_proportion, + original_neg_proportion, + places=1, + msg="Train set labels should maintain original negative label proportion.", + ) + + # Check the label distribution in the test set + test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( + test_df + ) + test_total_count = len(test_df) * number_of_labels + + # Calculate the test set proportions + test_pos_proportion = test_pos_count / test_total_count + test_neg_proportion = test_neg_count / test_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + test_pos_proportion, + original_pos_proportion, + places=1, + msg="Test set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + test_neg_proportion, + original_neg_proportion, + places=1, + msg="Test set labels should maintain original negative label proportion.", + ) + + def test_get_train_val_splits_given_test_stratification(self) -> None: + """ + Test that the split into train and validation sets maintains the stratification of labels. + """ + self.dataset.use_inner_cross_validation = False + self.dataset.train_split = 0.5 + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) + train_df, val_df = self.dataset.get_train_val_splits_given_test( + df_train_main, test_df, seed=42 + ) + + number_of_labels = len(self.data_df["labels"][0]) + + # Check the label distribution in the original dataset + original_pos_count, original_neg_count = ( + self.get_positive_negative_labels_counts(self.data_df) + ) + total_count = len(self.data_df) * number_of_labels + + # Calculate the expected proportions + original_pos_proportion = original_pos_count / total_count + original_neg_proportion = original_neg_count / total_count + + # Check the label distribution in the train set + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + train_total_count = len(train_df) * number_of_labels + + # Calculate the train set proportions + train_pos_proportion = train_pos_count / train_total_count + train_neg_proportion = train_neg_count / train_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + train_pos_proportion, + original_pos_proportion, + places=1, + msg="Train set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + train_neg_proportion, + original_neg_proportion, + places=1, + msg="Train set labels should maintain original negative label proportion.", + ) + + # Check the label distribution in the validation set + val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) + val_total_count = len(val_df) * number_of_labels + + # Calculate the validation set proportions + val_pos_proportion = val_pos_count / val_total_count + val_neg_proportion = val_neg_count / val_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + val_pos_proportion, + original_pos_proportion, + places=1, + msg="Validation set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + val_neg_proportion, + original_neg_proportion, + places=1, + msg="Validation set labels should maintain original negative label proportion.", + ) + + @staticmethod + def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]: + """ + Count the number of True and False values within the labels column. + + Args: + df (pd.DataFrame): The DataFrame containing the 'labels' column. + + Returns: + Tuple[int, int]: A tuple containing the counts of True and False values, respectively. + """ + true_count = sum(sum(label) for label in df["labels"]) + false_count = sum(len(label) - sum(label) for label in df["labels"]) + return true_count, false_count + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py new file mode 100644 index 00000000..9da48bee --- /dev/null +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -0,0 +1,229 @@ +import unittest +from collections import OrderedDict +from unittest.mock import PropertyMock, mock_open, patch + +import fastobo +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor +from chebai.preprocessing.reader import ProteinDataReader +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestGOUniProtDataExtractor(unittest.TestCase): + """ + Unit tests for the _GOUniProtDataExtractor class. + """ + + @classmethod + @patch.multiple(_GOUniProtDataExtractor, __abstractmethods__=frozenset()) + @patch.object(_GOUniProtDataExtractor, "base_dir", new_callable=PropertyMock) + @patch.object(_GOUniProtDataExtractor, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, + ) -> None: + """ + Class setup for mocking abstract properties of _GOUniProtDataExtractor. + """ + mock_base_dir_property.return_value = "MockedBaseDirPropGOUniProtDataExtractor" + mock_name_property.return_value = "MockedNamePropGOUniProtDataExtractor" + + _GOUniProtDataExtractor.READER = ProteinDataReader + + cls.extractor = _GOUniProtDataExtractor() + + def test_term_callback(self) -> None: + """ + Test the term_callback method for correct parsing and filtering of GO terms. + """ + self.extractor.go_branch = "all" + term_mapping = {} + for term in fastobo.loads(GOUniProtMockData.get_GO_raw_data()): + if isinstance(term, fastobo.typedef.TypedefFrame): + continue + term_mapping[self.extractor._parse_go_id(term.id)] = term + + # Test individual term callback + term_dict = self.extractor.term_callback(term_mapping[4]) + expected_dict = {"go_id": 4, "parents": [3, 2], "name": "GO_4"} + self.assertEqual( + term_dict, + expected_dict, + "The term_callback did not return the expected dictionary.", + ) + + # Test filtering valid terms + valid_terms_docs = set() + for term_id, term_doc in term_mapping.items(): + if self.extractor.term_callback(term_doc): + valid_terms_docs.add(term_id) + + self.assertEqual( + valid_terms_docs, + set(GOUniProtMockData.get_nodes()), + "The valid terms do not match expected nodes.", + ) + + # Test that obsolete terms are filtered out + self.assertFalse( + any( + self.extractor.term_callback(term_mapping[obs_id]) + for obs_id in GOUniProtMockData.get_obsolete_nodes_ids() + ), + "Obsolete terms should not be present.", + ) + + # Test filtering by GO branch (e.g., BP) + self.extractor.go_branch = "BP" + BP_terms = { + term_id + for term_id, term in term_mapping.items() + if self.extractor.term_callback(term) + } + self.assertEqual( + BP_terms, {2, 4}, "The BP terms do not match the expected set." + ) + + @patch( + "fastobo.load", return_value=fastobo.loads(GOUniProtMockData.get_GO_raw_data()) + ) + def test_extract_class_hierarchy(self, mock_load) -> None: + """ + Test the extraction of the class hierarchy from the ontology. + """ + graph = self.extractor._extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." + ) + + # Check nodes + actual_nodes = set(graph.nodes) + self.assertEqual( + set(GOUniProtMockData.get_nodes()), + actual_nodes, + "The graph nodes do not match the expected nodes.", + ) + + # Check edges + actual_edges = set(graph.edges) + self.assertEqual( + GOUniProtMockData.get_edges_of_transitive_closure_graph(), + actual_edges, + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + GOUniProtMockData.get_number_of_nodes(), + len(actual_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + GOUniProtMockData.get_number_of_transitive_edges(), + len(actual_edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + def test_get_swiss_to_go_mapping(self, mock_open) -> None: + """ + Test the extraction of SwissProt to GO term mapping. + """ + mapping_df = self.extractor._get_swiss_to_go_mapping() + expected_df = pd.DataFrame( + OrderedDict( + swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], + accession=["Q6GZX4", "DCGZX4"], + go_ids=[[2, 3, 5], [2, 5]], + sequence=list(GOUniProtMockData.protein_sequences().values()), + ) + ) + + pd.testing.assert_frame_equal( + mapping_df, + expected_df, + obj="The SwissProt to GO mapping DataFrame does not match the expected DataFrame.", + ) + + @patch( + "fastobo.load", return_value=fastobo.loads(GOUniProtMockData.get_GO_raw_data()) + ) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + @patch.object( + _GOUniProtDataExtractor, + "select_classes", + return_value=GOUniProtMockData.get_nodes(), + ) + def test_graph_to_raw_dataset( + self, mock_select_classes, mock_open, mock_load + ) -> None: + """ + Test the conversion of the class hierarchy graph to a raw dataset. + """ + graph = self.extractor._extract_class_hierarchy("fake_path") + actual_df = self.extractor._graph_to_raw_dataset(graph) + expected_df = GOUniProtMockData.get_data_in_dataframe() + + pd.testing.assert_frame_equal( + actual_df, + expected_df, + obj="The raw dataset DataFrame does not match the expected DataFrame.", + ) + + @patch("builtins.open", new_callable=mock_open, read_data=b"Mocktestdata") + @patch("pandas.read_pickle") + def test_load_dict( + self, mock_read_pickle: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test the loading of the dictionary from a DataFrame. + """ + mock_df = GOUniProtMockData.get_data_in_dataframe() + mock_read_pickle.return_value = mock_df + + generator = self.extractor._load_dict("data/tests") + result = list(generator) + + # Convert NumPy arrays to lists for comparison + for item in result: + item["labels"] = list(item["labels"]) + + # Expected output for comparison + expected_result = [ + { + "features": mock_df["sequence"][0], + "labels": mock_df.iloc[0, 4:].to_list(), + "ident": mock_df["swiss_id"][0], + }, + { + "features": mock_df["sequence"][1], + "labels": mock_df.iloc[1, 4:].to_list(), + "ident": mock_df["swiss_id"][1], + }, + ] + + self.assertEqual( + result, + expected_result, + "The loaded dictionary does not match the expected structure.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py new file mode 100644 index 00000000..d4157770 --- /dev/null +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -0,0 +1,140 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestGOUniProtOverX(unittest.TestCase): + @classmethod + @patch.multiple(_GOUniProtOverX, __abstractmethods__=frozenset()) + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: + """ + Set up the class for tests by initializing the extractor, graph, and input DataFrame. + """ + cls.extractor = _GOUniProtOverX() + cls.test_graph: nx.DiGraph = GOUniProtMockData.get_transitively_closed_graph() + cls.input_df: pd.DataFrame = GOUniProtMockData.get_data_in_dataframe().iloc[ + :, :4 + ] + + @patch("builtins.open", new_callable=mock_open) + def test_select_classes(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method to ensure it selects classes based on the threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + # Set threshold for testing + self.extractor.THRESHOLD = 2 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: GO terms 1, 2, and 5 should be selected based on the threshold + expected_selected_classes: List[int] = sorted([1, 2, 5]) + + # Check if the selected classes are as expected + self.assertEqual( + selected_classes, + expected_selected_classes, + msg="The selected classes do not match the expected output for threshold 2.", + ) + + # Expected data as string + expected_lines: str = "\n".join(map(str, expected_selected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines for the given threshold of 2.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_no_classes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method when no nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.extractor.THRESHOLD = 5 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: No classes should meet the threshold of 5 + expected_selected_classes: List[int] = [] + + # Check if the selected classes are as expected + self.assertEqual( + selected_classes, + expected_selected_classes, + msg="The selected classes list should be empty when no nodes meet the threshold of 5.", + ) + + # Expected data as string + expected_lines: str = "" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines when no nodes meet the threshold of 5.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_all_nodes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method when all nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.extractor.THRESHOLD = 0 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: All nodes except those not referenced by any protein (4 and 6) should be selected + expected_classes: List[int] = sorted([1, 2, 3, 5]) + + # Check if the returned selected classes match the expected list + self.assertListEqual( + selected_classes, + expected_classes, + msg="The selected classes do not match the expected output when all nodes meet the threshold of 0.", + ) + + # Expected data as string + expected_lines: str = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines when all nodes meet the threshold of 0.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py new file mode 100644 index 00000000..cb6b0688 --- /dev/null +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -0,0 +1,74 @@ +import unittest +from unittest.mock import PropertyMock, mock_open, patch + +from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData +from chebai.preprocessing.reader import ProteinDataReader +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestProteinPretrainingData(unittest.TestCase): + """ + Unit tests for the _ProteinPretrainingData class. + Tests focus on data parsing and validation checks for protein pretraining. + """ + + @classmethod + @patch.multiple(_ProteinPretrainingData, __abstractmethods__=frozenset()) + @patch.object(_ProteinPretrainingData, "base_dir", new_callable=PropertyMock) + @patch.object(_ProteinPretrainingData, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, + ) -> None: + """ + Class setup for mocking abstract properties of _ProteinPretrainingData. + + Mocks the required abstract properties and sets up the data extractor. + """ + mock_base_dir_property.return_value = "MockedBaseDirPropProteinPretrainingData" + mock_name_property.return_value = "MockedNameProp_ProteinPretrainingData" + + # Set the READER class for the pretraining data + _ProteinPretrainingData.READER = ProteinDataReader + + # Initialize the extractor instance + cls.extractor = _ProteinPretrainingData() + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + def test_parse_protein_data_for_pretraining( + self, mock_open_file: mock_open + ) -> None: + """ + Tests the _parse_protein_data_for_pretraining method. + + Verifies that: + - The parsed DataFrame contains the expected protein IDs. + - The protein sequences are not empty. + """ + # Parse the pretraining data + pretrain_df = self.extractor._parse_protein_data_for_pretraining() + list_of_pretrain_swiss_ids = GOUniProtMockData.proteins_for_pretraining() + + # Assert that all expected Swiss-Prot IDs are present in the DataFrame + self.assertEqual( + set(pretrain_df["swiss_id"]), + set(list_of_pretrain_swiss_ids), + msg="The parsed DataFrame does not contain the expected Swiss-Prot IDs for pretraining.", + ) + + # Assert that all sequences are not empty + self.assertTrue( + pretrain_df["sequence"].str.len().gt(0).all(), + msg="Some protein sequences in the pretraining DataFrame are empty.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testTox21Challenge.py b/tests/unit/dataset_classes/testTox21Challenge.py new file mode 100644 index 00000000..9ad2af21 --- /dev/null +++ b/tests/unit/dataset_classes/testTox21Challenge.py @@ -0,0 +1,128 @@ +import unittest +from unittest.mock import mock_open, patch + +from rdkit import Chem + +from chebai.preprocessing.datasets.tox21 import Tox21Challenge +from chebai.preprocessing.reader import ChemDataReader +from tests.unit.mock_data.tox_mock_data import ( + Tox21ChallengeMockData, + Tox21MolNetMockData, +) + + +class TestTox21Challenge(unittest.TestCase): + """ + Unit tests for the Tox21Challenge class. + """ + + @classmethod + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: + """ + Set up the Tox21Challenge instance and mock data for testing. + This is run once for the test class. + """ + Tox21Challenge.READER = ChemDataReader + cls.tox21 = Tox21Challenge() + + @patch("rdkit.Chem.SDMolSupplier") + def test_load_data_from_file(self, mock_sdmol_supplier: patch) -> None: + """ + Test the `_load_data_from_file` method to ensure it correctly loads data from an SDF file. + + Args: + mock_sdmol_supplier (patch): A mock of the RDKit SDMolSupplier. + """ + # Use ForwardSDMolSupplier to read the mock data from the binary string + mock_file = mock_open(read_data=Tox21ChallengeMockData.get_raw_train_data()) + with patch("builtins.open", mock_file): + with open( + r"fake/path", + "rb", + ) as f: + suppl = Chem.ForwardSDMolSupplier(f) + + mock_sdmol_supplier.return_value = suppl + + actual_data = self.tox21._load_data_from_file("fake/path") + expected_data = Tox21ChallengeMockData.data_in_dict_format() + + self.assertEqual( + actual_data, + expected_data, + "The loaded data from file does not match the expected data.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=Tox21MolNetMockData.get_raw_data(), + ) + def test_load_dict(self, mock_open_file: mock_open) -> None: + """ + Test the `_load_dict` method to ensure correct CSV parsing. + + Args: + mock_open_file (mock_open): Mocked open function to simulate file reading. + """ + expected_data = Tox21MolNetMockData.get_processed_data() + for item in expected_data: + item.pop("group", None) + + actual_data = self.tox21._load_dict("fake/file/path.csv") + + self.assertEqual( + list(actual_data), + expected_data, + "The loaded data from CSV does not match the expected processed data.", + ) + + @patch.object(Tox21Challenge, "_load_data_from_file", return_value="test") + @patch("builtins.open", new_callable=mock_open) + @patch("torch.save") + @patch("os.path.join") + def test_setup_processed( + self, + mock_join: patch, + mock_torch_save: patch, + mock_open_file: mock_open, + mock_load_file: patch, + ) -> None: + """ + Test the `setup_processed` method to ensure it processes and saves data correctly. + + Args: + mock_join (patch): Mock of os.path.join to simulate file path joining. + mock_torch_save (patch): Mock of torch.save to simulate saving processed data. + mock_open_file (mock_open): Mocked open function to simulate file reading. + mock_load_file (patch): Mocked data loading method. + """ + # Simulated raw and processed directories + path_str = "fake/test/path" + mock_join.return_value = path_str + + # Mock the file content for test.smiles and score.txt + mock_open_file.side_effect = [ + mock_open( + read_data=Tox21ChallengeMockData.get_raw_smiles_data() + ).return_value, + mock_open( + read_data=Tox21ChallengeMockData.get_raw_score_txt_data() + ).return_value, + ] + + # Call setup_processed to simulate the data processing workflow + self.tox21.setup_processed() + + # Assert that torch.save was called with the correct processed data + expected_test_data = Tox21ChallengeMockData.get_setup_processed_output_data() + mock_torch_save.assert_called_with(expected_test_data, path_str) + + self.assertTrue( + mock_torch_save.called, "The processed data was not saved as expected." + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py new file mode 100644 index 00000000..64dfbe40 --- /dev/null +++ b/tests/unit/dataset_classes/testXYBaseDataModule.py @@ -0,0 +1,92 @@ +import unittest +from unittest.mock import MagicMock, PropertyMock, patch + +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class TestXYBaseDataModule(unittest.TestCase): + """ + Unit tests for the methods of the XYBaseDataModule class. + """ + + @classmethod + @patch.object(XYBaseDataModule, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs, mock_name_property: PropertyMock) -> None: + """ + Set up a base instance of XYBaseDataModule for testing. + """ + + # Mock the _name property of XYBaseDataModule + mock_name_property.return_value = "MockedNamePropXYBaseDataModule" + + # Assign a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + XYBaseDataModule.READER = ReaderMock + + # Initialize the module with a label_filter + cls.module = XYBaseDataModule( + label_filter=1, # Provide a label_filter + balance_after_filter=1.0, # Balance ratio + ) + + def test_filter_labels_valid_index(self) -> None: + """ + Test the _filter_labels method with a valid label_filter index. + """ + self.module.label_filter = 1 + row = { + "features": ["feature1", "feature2"], + "labels": [0, 3, 1, 2], # List of labels + } + filtered_row = self.module._filter_labels(row) + expected_labels = [3] # Only the label at index 1 should be kept + + self.assertEqual( + filtered_row["labels"], + expected_labels, + "The filtered labels do not match the expected labels.", + ) + + row = { + "features": ["feature1", "feature2"], + "labels": [True, False, True, True], + } + self.assertEqual( + self.module._filter_labels(row)["labels"], + [False], + "The filtered labels for the boolean case do not match the expected labels.", + ) + + def test_filter_labels_no_filter(self) -> None: + """ + Test the _filter_labels method with no label_filter index. + """ + # Update the module to have no label filter + self.module.label_filter = None + row = {"features": ["feature1", "feature2"], "labels": [False, True]} + # Handle the case where the index is out of bounds + with self.assertRaises( + TypeError, msg="Expected a TypeError when no label filter is provided." + ): + self.module._filter_labels(row) + + def test_filter_labels_invalid_index(self) -> None: + """ + Test the _filter_labels method with an invalid label_filter index. + """ + # Set an invalid label filter index (e.g., greater than the number of labels) + self.module.label_filter = 10 + row = {"features": ["feature1", "feature2"], "labels": [False, True]} + # Handle the case where the index is out of bounds + with self.assertRaises( + IndexError, + msg="Expected an IndexError when the label filter index is out of bounds.", + ): + self.module._filter_labels(row) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/__init__.py b/tests/unit/mock_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py new file mode 100644 index 00000000..a05b89f1 --- /dev/null +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -0,0 +1,812 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List, Set, Tuple + +import networkx as nx +import pandas as pd + + +class MockOntologyGraphData(ABC): + """ + Abstract base class for mocking ontology graph data. + + This class provides a set of static methods that must be implemented by subclasses + to return various elements of an ontology graph such as nodes, edges, and dataframes. + """ + + @staticmethod + @abstractmethod + def get_nodes() -> List[int]: + """ + Get a list of node IDs in the ontology graph. + + Returns: + List[int]: A list of node IDs. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_nodes() -> int: + """ + Get the number of nodes in the ontology graph. + + Returns: + int: The total number of nodes. + """ + pass + + @staticmethod + @abstractmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples where each tuple represents an edge between two nodes. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_edges() -> int: + """ + Get the number of edges in the ontology graph. + + Returns: + int: The total number of edges. + """ + pass + + @staticmethod + @abstractmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the transitive closure of the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples representing the transitive closure edges. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_transitive_edges() -> int: + """ + Get the number of edges in the transitive closure of the ontology graph. + + Returns: + int: The total number of transitive edges. + """ + pass + + @staticmethod + @abstractmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the ontology graph. + + Returns: + Set[int]: A set of obsolete node IDs. + """ + pass + + @staticmethod + @abstractmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Get the transitive closure of the ontology graph. + + Returns: + nx.DiGraph: A directed graph representing the transitive closure of the ontology graph. + """ + pass + + @staticmethod + @abstractmethod + def get_data_in_dataframe() -> pd.DataFrame: + """ + Get the ontology data as a Pandas DataFrame. + + Returns: + pd.DataFrame: A DataFrame containing ontology data. + """ + pass + + +class ChebiMockOntology(MockOntologyGraphData): + """ + A mock ontology representing a simplified ChEBI (Chemical Entities of Biological Interest) structure. + This class is used for testing purposes and includes nodes and edges representing chemical compounds + and their relationships in a graph structure. + + Nodes: + - CHEBI:12345 (Compound A) + - CHEBI:54321 (Compound B) + - CHEBI:67890 (Compound C) + - CHEBI:11111 (Compound D) + - CHEBI:22222 (Compound E) + - CHEBI:99999 (Compound F) + - CHEBI:77533 (Compound G, Obsolete node) + - CHEBI:77564 (Compound H, Obsolete node) + - CHEBI:88888 (Compound I) + + Valid Edges: + - CHEBI:54321 -> CHEBI:12345 + - CHEBI:67890 -> CHEBI:12345 + - CHEBI:67890 -> CHEBI:88888 + - CHEBI:11111 -> CHEBI:54321 + - CHEBI:22222 -> CHEBI:67890 + - CHEBI:12345 -> CHEBI:99999 + + The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. + + Visual Representation Graph with Valid Nodes and Edges: + + 22222 + / + 11111 67890 + \\ / \ + 54321 / 88888 + \\ / + 12345 + \ + 99999 + """ + + @staticmethod + def get_nodes() -> List[int]: + """ + Get the set of valid node IDs in the mock ontology. + + Returns: + - Set[int]: A set of integers representing the valid ChEBI node IDs. + """ + return [11111, 12345, 22222, 54321, 67890, 88888, 99999] + + @staticmethod + def get_number_of_nodes() -> int: + """ + Get the number of valid nodes in the mock ontology. + + Returns: + - int: The number of valid nodes. + """ + return len(ChebiMockOntology.get_nodes()) + + @staticmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of valid edges in the mock ontology. + + Returns: + - Set[Tuple[int, int]]: A set of tuples representing the directed edges + between ChEBI nodes. + """ + return { + (54321, 12345), + (67890, 12345), + (67890, 88888), + (11111, 54321), + (22222, 67890), + (12345, 99999), + } + + @staticmethod + def get_number_of_edges() -> int: + """ + Get the number of valid edges in the mock ontology. + + Returns: + - int: The number of valid edges. + """ + return len(ChebiMockOntology.get_edges()) + + @staticmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges derived from the transitive closure of the mock ontology graph. + + Returns: + - Set[Tuple[int, int]]: A set of tuples representing the directed edges + in the transitive closure of the ChEBI graph. + """ + return { + (54321, 12345), + (54321, 99999), + (67890, 12345), + (67890, 99999), + (67890, 88888), + (11111, 54321), + (11111, 12345), + (11111, 99999), + (22222, 67890), + (22222, 12345), + (22222, 99999), + (22222, 88888), + (12345, 99999), + } + + @staticmethod + def get_number_of_transitive_edges() -> int: + """ + Get the number of edges in the transitive closure of the mock ontology graph. + + Returns: + - int: The number of edges in the transitive closure graph. + """ + return len(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + + @staticmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the mock ontology. + + Returns: + - Set[int]: A set of integers representing the obsolete ChEBI node IDs. + """ + return {77533, 77564} + + @staticmethod + def get_raw_data() -> str: + """ + Get the raw data representing the mock ontology in OBO format. + + Returns: + - str: A string containing the raw OBO data for the mock ChEBI terms. + """ + return """ + [Term] + id: CHEBI:12345 + name: Compound A + subset: 2_STAR + property_value: http://purl.obolibrary.org/obo/chebi/formula "C26H35ClN4O6S" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/charge "0" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/monoisotopicmass "566.19658" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/mass "567.099" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/inchikey "ROXPMFGZZQEKHB-IUKKYPGJSA-N" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/inchi "InChI=1S/C26H35ClN4O6S/c1-16(2)28-26(34)30(5)14-23-17(3)13-31(18(4)15-32)25(33)21-7-6-8-22(24(21)37-23)29-38(35,36)20-11-9-19(27)10-12-20/h6-12,16-18,23,29,32H,13-15H2,1-5H3,(H,28,34)/t17-,18-,23+/m0/s1" xsd:string + xref: LINCS:LSM-20139 + is_a: CHEBI:54321 + is_a: CHEBI:67890 + + [Term] + id: CHEBI:54321 + name: Compound B + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1O" xsd:string + is_a: CHEBI:11111 + is_a: CHEBI:77564 + + [Term] + id: CHEBI:67890 + name: Compound C + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1N" xsd:string + is_a: CHEBI:22222 + + [Term] + id: CHEBI:11111 + name: Compound D + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1F" xsd:string + + [Term] + id: CHEBI:22222 + name: Compound E + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1Cl" xsd:string + + [Term] + id: CHEBI:99999 + name: Compound F + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1Br" xsd:string + is_a: CHEBI:12345 + + [Term] + id: CHEBI:77533 + name: Compound G + is_a: CHEBI:99999 + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=C1Br" xsd:string + is_obsolete: true + + [Term] + id: CHEBI:77564 + name: Compound H + property_value: http://purl.obolibrary.org/obo/chebi/smiles "CC=C1Br" xsd:string + is_obsolete: true + + [Typedef] + id: has_major_microspecies_at_pH_7_3 + name: has major microspecies at pH 7.3 + is_cyclic: true + is_transitive: false + + [Term] + id: CHEBI:88888 + name: Compound I + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1[Mg+]" xsd:string + is_a: CHEBI:67890 + """ + + @staticmethod + def get_data_in_dataframe() -> pd.DataFrame: + data = OrderedDict( + id=[ + 12345, + 54321, + 67890, + 11111, + 22222, + 99999, + 88888, + ], + name=[ + "Compound A", + "Compound B", + "Compound C", + "Compound D", + "Compound E", + "Compound F", + "Compound I", + ], + SMILES=[ + "C1=CC=CC=C1", + "C1=CC=CC=C1O", + "C1=CC=CC=C1N", + "C1=CC=CC=C1F", + "C1=CC=CC=C1Cl", + "C1=CC=CC=C1Br", + "C1=CC=CC=C1[Mg+]", + ], + **{ + # -row- [12345, 54321, 67890, 11111, 22222, 99999, 88888] + 11111: [True, True, False, True, False, True, False], + 12345: [True, False, False, False, False, True, False], + 22222: [True, False, True, False, True, True, True], + 54321: [True, True, False, False, False, True, False], + 67890: [True, False, True, False, False, True, True], + 88888: [False, False, False, False, False, False, True], + 99999: [False, False, False, False, False, True, False], + }, + ) + + data_df = pd.DataFrame(data) + + # ------------- Code Approach ------- + # ancestors_of_nodes = {} + # for parent, child in ChebiMockOntology.get_edges_of_transitive_closure_graph(): + # if child not in ancestors_of_nodes: + # ancestors_of_nodes[child] = set() + # if parent not in ancestors_of_nodes: + # ancestors_of_nodes[parent] = set() + # ancestors_of_nodes[child].add(parent) + # ancestors_of_nodes[child].add(child) + # + # # For each node in the ontology, create a column to check if it's an ancestor of any other node or itself + # for node in ChebiMockOntology.get_nodes(): + # data_df[node] = data_df['id'].apply( + # lambda x: (x == node) or (node in ancestors_of_nodes[x]) + # ) + + return data_df + + @staticmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Create a directed graph, compute its transitive closure, and return it. + + Returns: + g (nx.DiGraph): A transitively closed directed graph. + """ + g = nx.DiGraph() + + for node in ChebiMockOntology.get_nodes(): + g.add_node(node, **{"smiles": "test_smiles_placeholder"}) + + g.add_edges_from(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + + return g + + +class GOUniProtMockData(MockOntologyGraphData): + """ + A mock ontology representing a simplified version of the Gene Ontology (GO) structure with nodes and edges + representing GO terms and their relationships in a directed acyclic graph (DAG). + + Nodes: + - GO_1 + - GO_2 + - GO_3 + - GO_4 + - GO_5 + - GO_6 + + Edges (Parent-Child Relationships): + - GO_1 -> GO_2 + - GO_1 -> GO_3 + - GO_2 -> GO_4 + - GO_2 -> GO_5 + - GO_3 -> GO_4 + - GO_4 -> GO_6 + + This mock ontology structure is useful for testing methods related to GO hierarchy, graph extraction, and transitive + closure operations. + + The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. + + Visual Representation Graph with Valid Nodes and Edges: + + GO_1 + / \ + GO_2 GO_3 + / \ / + GO_5 GO_4 + \ + GO_6 + + Valid Swiss Proteins with mapping to valid GO ids + Swiss_Prot_1 -> GO_2, GO_3, GO_5 + Swiss_Prot_2 -> GO_2, GO_5 + """ + + @staticmethod + def get_nodes() -> List[int]: + """ + Get a sorted list of node IDs. + + Returns: + List[int]: A sorted list of node IDs in the ontology graph. + """ + return sorted([1, 2, 3, 4, 5, 6]) + + @staticmethod + def get_number_of_nodes() -> int: + """ + Get the total number of nodes in the ontology graph. + + Returns: + int: The number of nodes. + """ + return len(GOUniProtMockData.get_nodes()) + + @staticmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples where each tuple represents an edge between two nodes. + """ + return {(1, 2), (1, 3), (2, 4), (2, 5), (3, 4), (4, 6)} + + @staticmethod + def get_number_of_edges() -> int: + """ + Get the total number of edges in the ontology graph. + + Returns: + int: The number of edges. + """ + return len(GOUniProtMockData.get_edges()) + + @staticmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the transitive closure of the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples representing edges in the transitive closure graph. + """ + return { + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 6), + (2, 4), + (2, 5), + (2, 6), + (3, 4), + (3, 6), + (4, 6), + } + + @staticmethod + def get_number_of_transitive_edges() -> int: + """ + Get the total number of edges in the transitive closure graph. + + Returns: + int: The number of transitive edges. + """ + return len(GOUniProtMockData.get_edges_of_transitive_closure_graph()) + + @staticmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the ontology graph. + + Returns: + Set[int]: A set of node IDs representing obsolete nodes. + """ + return {7, 8} + + @staticmethod + def get_GO_raw_data() -> str: + """ + Get raw data in string format for a basic Gene Ontology (GO) structure. + + This data simulates a basic GO ontology format typically used for testing purposes. + The data will include valid and obsolete GO terms with various relationships between them. + + Scenarios covered: + - Obsolete terms being the parent of valid terms. + - Valid terms being the parent of obsolete terms. + - Both direct and indirect hierarchical relationships between terms. + + The data is designed to help test the proper handling of obsolete and valid GO terms, + ensuring that the ontology parser can correctly manage both cases. + + Returns: + str: The raw GO data in string format, structured as test input. + """ + return """ + [Term] + id: GO:0000001 + name: GO_1 + namespace: molecular_function + def: "OBSOLETE. Assists in the correct assembly of ribosomes or ribosomal subunits in vivo, but is not a component of the assembled ribosome when performing its normal biological function." [GOC:jl, PMID:12150913] + comment: This term was made obsolete because it refers to a class of gene products and a biological process rather than a molecular function. + synonym: "ribosomal chaperone activity" EXACT [] + xref: MetaCyc:BETAGALACTOSID-RXN + xref: Reactome:R-HSA-189062 "lactose + H2O => D-glucose + D-galactose" + xref: Reactome:R-HSA-5658001 "Defective LCT does not hydrolyze Lac" + xref: RHEA:10076 + + [Term] + id: GO:0000002 + name: GO_2 + namespace: biological_process + is_a: GO:0000001 ! hydrolase activity, hydrolyzing O-glycosyl compounds + is_a: GO:0000008 ! hydrolase activity, hydrolyzing O-glycosyl compounds + + [Term] + id: GO:0000003 + name: GO_3 + namespace: cellular_component + is_a: GO:0000001 ! regulation of DNA recombination + + [Term] + id: GO:0000004 + name: GO_4 + namespace: biological_process + is_a: GO:0000003 ! regulation of DNA recombination + is_a: GO:0000002 ! hydrolase activity, hydrolyzing O-glycosyl compounds + + [Term] + id: GO:0000005 + name: GO_5 + namespace: molecular_function + is_a: GO:0000002 ! regulation of DNA recombination + + [Term] + id: GO:0000006 + name: GO_6 + namespace: cellular_component + is_a: GO:0000004 ! glucoside transport + + [Term] + id: GO:0000007 + name: GO_7 + namespace: biological_process + is_a: GO:0000003 ! glucoside transport + is_obsolete: true + + [Term] + id: GO:0000008 + name: GO_8 + namespace: molecular_function + is_obsolete: true + + [Typedef] + id: term_tracker_item + name: term tracker item + namespace: external + xref: IAO:0000233 + is_metadata_tag: true + is_class_level: true + """ + + @staticmethod + def protein_sequences() -> Dict[str, str]: + """ + Get the protein sequences for Swiss-Prot proteins. + + Returns: + Dict[str, str]: A dictionary where keys are Swiss-Prot IDs and values are their respective sequences. + """ + return { + "Swiss_Prot_1": "MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK".replace( + " ", "" + ), + "Swiss_Prot_2": "EKGLIVGHFS GIKYKGEKAQ ASEVDVNKMC CWVSKFKDAM RRYQGIQTCK".replace( + " ", "" + ), + } + + @staticmethod + def proteins_for_pretraining() -> List[str]: + """ + Returns a list of protein IDs which will be used for pretraining based on mock UniProt data. + + Proteins include those with: + - No GO classes or invalid GO classes (missing required evidence codes). + + Returns: + List[str]: A list of protein IDs that do not meet validation criteria. + """ + return [ + "Swiss_Prot_5", # No GO classes associated + "Swiss_Prot_6", # GO class with no evidence code + "Swiss_Prot_7", # GO class with invalid evidence code + ] + + @staticmethod + def get_UniProt_raw_data() -> str: + """ + Get raw data in string format for UniProt proteins. + + This mock data contains eleven Swiss-Prot proteins with different properties: + - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. + - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. + - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'X'. + - **Swiss_Prot_5**: Has a sequence but no GO classes associated. + - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. + - **Swiss_Prot_7**: Has a GO class with an invalid evidence code. + - **Swiss_Prot_8**: Has a sequence length > 1002 and has only invalid GO class. + - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'X', in its sequence. + - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. + - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. + + Note: + A valid GO label is the one which has one of the following evidence code + (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + + Returns: + str: The raw UniProt data in string format. + """ + protein_sq_1 = GOUniProtMockData.protein_sequences()["Swiss_Prot_1"] + protein_sq_2 = GOUniProtMockData.protein_sequences()["Swiss_Prot_2"] + raw_str = ( + # Below protein with 3 valid associated GO class and one invalid GO class + f"ID Swiss_Prot_1 Reviewed; {len(protein_sq_1)} AA. \n" + "AC Q6GZX4;\n" + "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW.\n" + "DR GO; GO:0000003; C:membrane; IDA:UniProtKB-KW.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IPI:InterPro.\n" + "DR GO; GO:0000004; P:regulation of viral transcription; IEA:SGD.\n" + f"SQ SEQUENCE {len(protein_sq_1)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_1}\n" + "//\n" + # Below protein with 2 valid associated GO class and one invalid GO class + f"ID Swiss_Prot_2 Reviewed; {len(protein_sq_2)} AA.\n" + "AC DCGZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; IMP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IGI:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; IEA:PomBase.\n" + f"SQ SEQUENCE {len(protein_sq_2)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_2}\n" + "//\n" + # Below protein with all valid associated GO class but sequence length greater than 1002 + f"ID Swiss_Prot_3 Reviewed; {len(protein_sq_1 * 25)} AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; IEP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; TAS:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + f"SQ SEQUENCE {len(protein_sq_1 * 25)} AA; 129118 MW; FE2984658CED53A8 CRC64;\n" + f" {protein_sq_1 * 25}\n" + "//\n" + # Below protein has valid go class association but invalid amino acid `X` in its sequence + "ID Swiss_Prot_4 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; EXP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence string but has no GO class + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence string and with NO `valid` associated GO class (no evidence code) + "ID Swiss_Prot_6 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000023; P:regulation of viral transcription;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence string and with NO `valid` associated GO class (invalid evidence code) + "ID Swiss_Prot_7 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000024; P:regulation of viral transcription; IEA:SGD.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence length greater than 1002 but with `Invalid` associated GO class + f"ID Swiss_Prot_8 Reviewed; {len(protein_sq_2 * 25)} AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000025; P:regulation of viral transcription; IC:Inferred.\n" + f"SQ SEQUENCE {len(protein_sq_2 * 25)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_2 * 25}\n" + "//\n" + # Below protein with sequence string but invalid amino acid `X` in its sequence + "ID Swiss_Prot_9 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with a `valid` associated GO class but without sequence string + "ID Swiss_Prot_10 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000027; P:regulation of viral transcription; EXP:InterPro.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " \n" + "//\n" + # Below protein with a `Invalid` associated GO class but without sequence string + "ID Swiss_Prot_11 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000028; P:regulation of viral transcription; ND:NoData.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " \n" + "//\n" + ) + + return raw_str + + @staticmethod + def get_data_in_dataframe() -> pd.DataFrame: + """ + Get a mock DataFrame representing UniProt data. + + The DataFrame contains Swiss-Prot protein data, including identifiers, accessions, GO terms, sequences, + and binary label columns representing whether each protein is associated with certain GO classes. + + Returns: + pd.DataFrame: A DataFrame containing mock UniProt data with columns for 'swiss_id', 'accession', 'go_ids', 'sequence', + and binary labels for GO classes. + """ + expected_data = OrderedDict( + swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], + accession=["Q6GZX4", "DCGZX4"], + go_ids=[[1, 2, 3, 5], [1, 2, 5]], + sequence=list(GOUniProtMockData.protein_sequences().values()), + **{ + # SP_1, SP_2 + 1: [True, True], + 2: [True, True], + 3: [True, False], + 4: [False, False], + 5: [True, True], + 6: [False, False], + }, + ) + return pd.DataFrame(expected_data) + + @staticmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Get the transitive closure of the ontology graph. + + Returns: + nx.DiGraph: A directed graph representing the transitive closure of the ontology graph. + """ + g = nx.DiGraph() + g.add_nodes_from(node for node in ChebiMockOntology.get_nodes()) + g.add_edges_from(GOUniProtMockData.get_edges_of_transitive_closure_graph()) + return g diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py new file mode 100644 index 00000000..b5f85bda --- /dev/null +++ b/tests/unit/mock_data/tox_mock_data.py @@ -0,0 +1,510 @@ +from typing import Dict, List + + +class Tox21MolNetMockData: + """ + A utility class providing mock data for testing the Tox21MolNet dataset. + + This class includes static methods that return mock data in various formats, simulating + the raw and processed data of the Tox21MolNet dataset. The mock data is used for unit tests + to verify the functionality of methods within the Tox21MolNet class without relying on actual + data files. + """ + + @staticmethod + def get_raw_data() -> str: + """ + Returns a raw CSV string that simulates the raw data of the Tox21MolNet dataset. + """ + return ( + "NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53," + "mol_id,smiles\n" + "0,0,1,0,1,1,0,1,0,,1,0,TOX958,Nc1ccc([N+](=O)[O-])cc1N\n" + ",,,,,,,,,1,,,TOX31681,Nc1cc(C(F)(F)F)ccc1S\n" + "0,0,0,0,0,0,0,,0,0,0,0,TOX5110,CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C\n" + "0,0,0,0,0,0,0,0,0,0,0,0,TOX6619,O=S(=O)(Cl)c1ccccc1\n" + "0,0,0,,0,0,,,0,,1,,TOX27679,CCCCCc1ccco1\n" + "0,,1,,,,0,,1,1,1,1,TOX2801,Oc1c(Cl)cc(Cl)c2cccnc12\n" + "0,0,0,0,,0,,,0,0,,1,TOX2808,CN(C)CCCN1c2ccccc2Sc2ccc(Cl)cc21\n" + "0,,0,1,,,,1,0,,1,,TOX29085,CCCCCCCCCCCCCCn1cc[n+](C)c1\n" + ) + + @staticmethod + def get_processed_data() -> List[Dict]: + """ + Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. + Each dictionary contains 'ident', 'features', and 'labels'. + """ + data_list = [ + { + "ident": "TOX958", + "features": "Nc1ccc([N+](=O)[O-])cc1N", + "labels": [ + False, + False, + True, + False, + True, + True, + False, + True, + False, + None, + True, + False, + ], + }, + { + "ident": "TOX31681", + "features": "Nc1cc(C(F)(F)F)ccc1S", + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + True, + None, + None, + ], + }, + { + "ident": "TOX5110", + "features": "CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C", + "labels": [ + False, + False, + False, + False, + False, + False, + False, + None, + False, + False, + False, + False, + ], + }, + { + "ident": "TOX6619", + "features": "O=S(=O)(Cl)c1ccccc1", + "labels": [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + }, + { + "ident": "TOX27679", + "features": "CCCCCc1ccco1", + "labels": [ + False, + False, + False, + None, + False, + False, + None, + None, + False, + None, + True, + None, + ], + }, + { + "ident": "TOX2801", + "features": "Oc1c(Cl)cc(Cl)c2cccnc12", + "labels": [ + False, + None, + True, + None, + None, + None, + False, + None, + True, + True, + True, + True, + ], + }, + { + "ident": "TOX2808", + "features": "CN(C)CCCN1c2ccccc2Sc2ccc(Cl)cc21", + "labels": [ + False, + False, + False, + False, + None, + False, + None, + None, + False, + False, + None, + True, + ], + }, + { + "ident": "TOX29085", + "features": "CCCCCCCCCCCCCCn1cc[n+](C)c1", + "labels": [ + False, + None, + False, + True, + None, + None, + None, + True, + False, + None, + True, + None, + ], + }, + ] + + data_with_group = [{**data, "group": None} for data in data_list] + return data_with_group + + @staticmethod + def get_processed_grouped_data() -> List[Dict]: + """ + Returns a list of dictionaries simulating the processed data for the Tox21MolNet dataset. + Each dictionary contains 'ident', 'features', and 'labels'. + """ + processed_data = Tox21MolNetMockData.get_processed_data() + groups = ["A", "A", "B", "B", "C", "C", "C", "C"] + + assert len(processed_data) == len( + groups + ), "The number of processed data entries does not match the number of groups." + + # Combine processed data with their corresponding groups + grouped_data = [ + {**data, "group": group, "original": True} + for data, group in zip(processed_data, groups) + ] + + return grouped_data + + +class Tox21ChallengeMockData: + + MOL_BINARY_STR = ( + b"cyclobutane\n" + b" RDKit 2D\n\n" + b" 4 4 0 0 0 0 0 0 0 0999 V2000\n" + b" 1.0607 -0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" -0.0000 -1.0607 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" -1.0607 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" 0.0000 1.0607 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n" + b" 1 2 1 0\n" + b" 2 3 1 0\n" + b" 3 4 1 0\n" + b" 4 1 1 0\n" + b"M END\n\n" + ) + + SMILES_OF_MOL = "C1CCC1" + # Feature encoding of SMILES as per chebai/preprocessing/bin/smiles_token/tokens.txt + FEATURE_OF_SMILES = [19, 42, 19, 19, 19, 42] + + @staticmethod + def get_raw_train_data() -> bytes: + raw_str = ( + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"25848\n\n" + b"> \n" + b"0\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"2384\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"0\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"27102\n\n" + b"> \n" + b"0\n\n" + b"> \n" + b"0\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"26792\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"26401\n\n" + b"> \n" + b"1\n\n" + b"> \n" + b"1\n\n" + b"$$$$\n" + Tox21ChallengeMockData.MOL_BINARY_STR + b"> \n" + b"25973\n\n" + b"$$$$\n" + ) + return raw_str + + @staticmethod + def data_in_dict_format() -> List[Dict]: + data_list = [ + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + 0, + None, + None, + ], + "ident": "25848", + }, + { + "labels": [ + 0, + None, + None, + 1, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "2384", + }, + { + "labels": [ + 0, + None, + 0, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "27102", + }, + { + "labels": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + "ident": "26792", + }, + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + 1, + None, + 1, + None, + None, + ], + "ident": "26401", + }, + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "25973", + }, + ] + + for dict_ in data_list: + dict_["features"] = Tox21ChallengeMockData.FEATURE_OF_SMILES + dict_["group"] = None + + return data_list + + @staticmethod + def get_raw_smiles_data() -> str: + """ + Returns mock SMILES data in a tab-delimited format (mocks test.smiles file). + + The data represents molecules and their associated sample IDs. + + Returns: + str: A string containing SMILES representations and corresponding sample IDs. + """ + return ( + "#SMILES\tSample ID\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00260869-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261776-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261380-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261842-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261662-01\n" + f"{Tox21ChallengeMockData.SMILES_OF_MOL}\tNCGC00261190-01\n" + ) + + @staticmethod + def get_raw_score_txt_data() -> str: + """ + Returns mock score data in a tab-delimited format (mocks test_results.txt file). + + The data represents toxicity test results for different molecular samples, including several toxicity endpoints. + + Returns: + str: A string containing toxicity scores for each molecular sample and corresponding toxicity endpoints. + """ + return ( + "Sample ID\tNR-AhR\tNR-AR\tNR-AR-LBD\tNR-Aromatase\tNR-ER\tNR-ER-LBD\tNR-PPAR-gamma\t" + "SR-ARE\tSR-ATAD5\tSR-HSE\tSR-MMP\tSR-p53\n" + "NCGC00260869-01\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n" + "NCGC00261776-01\t1\t1\t1\t1\t1\t1\t1\t1\t1\t1\t1\t1\n" + "NCGC00261380-01\tx\tx\tx\tx\tx\tx\tx\tx\tx\tx\tx\tx\n" + "NCGC00261842-01\t0\t0\t0\tx\t0\t0\t0\t0\t0\t0\tx\t1\n" + "NCGC00261662-01\t1\t0\t0\tx\t1\t1\t1\tx\t1\t1\tx\t1\n" + "NCGC00261190-01\tx\t0\t0\tx\t1\t0\t0\t1\t0\t0\t1\t1\n" + ) + + @staticmethod + def get_setup_processed_output_data() -> List[Dict]: + """ + Returns mock processed data used for testing the `setup_processed` method. + + The data contains molecule identifiers and their corresponding toxicity labels for multiple endpoints. + Each dictionary in the list represents a molecule with its associated labels, features, and group information. + + Returns: + List[Dict]: A list of dictionaries where each dictionary contains: + - "features": The SMILES features of the molecule. + - "labels": A list of toxicity endpoint labels (0, 1, or None). + - "ident": The sample identifier. + - "group": None (default value for the group key). + """ + + # "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", + # "SR-HSE", "SR-MMP", "SR-p53", + data_list = [ + { + "labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "ident": "NCGC00260869-01", + }, + { + "labels": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "ident": "NCGC00261776-01", + }, + { + "labels": [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ], + "ident": "NCGC00261380-01", + }, + { + "labels": [0, 0, 0, None, 0, 0, 0, 0, 0, 0, None, 1], + "ident": "NCGC00261842-01", + }, + { + "labels": [0, 0, 1, None, 1, 1, 1, None, 1, 1, None, 1], + "ident": "NCGC00261662-01", + }, + { + "labels": [0, 0, None, None, 1, 0, 0, 1, 0, 0, 1, 1], + "ident": "NCGC00261190-01", + }, + ] + + complete_list = [] + for dict_ in data_list: + complete_list.append( + { + "features": Tox21ChallengeMockData.FEATURE_OF_SMILES, + **dict_, + "group": None, + } + ) + + return complete_list diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py new file mode 100644 index 00000000..0c1c4d6f --- /dev/null +++ b/tests/unit/readers/testChemDataReader.py @@ -0,0 +1,107 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader + + +class TestChemDataReader(unittest.TestCase): + """ + Unit tests for the ChemDataReader class. + + Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. + """ + + @classmethod + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="C\nO\nN\n=\n1\n(", + ) + def setUpClass(cls, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a ChemDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + cls.reader = ChemDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['C', 'O', 'N', '=', '1', '('] + assert cls.reader.cache == [ + "C", + "O", + "N", + "=", + "1", + "(", + ], "Initial cache does not match expected values." + + def test_read_data(self) -> None: + """ + Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string. + """ + raw_data = "CC(=O)NC1[Mg-2]" + # Expected output as per the tokens already in the cache, and ")" getting added to it. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 5, # = + EMBEDDING_OFFSET + 3, # O + EMBEDDING_OFFSET + 1, # N + EMBEDDING_OFFSET + len(self.reader.cache), # ( + EMBEDDING_OFFSET + 2, # C + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 4, # 1 + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] + ] + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The output of _read_data does not match the expected tokenized values.", + ) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a SMILES string that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "[H-]" + + # Determine the index for the new token based on the current size of the cache. + index_for_last_token = len(self.reader.cache) + expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The output for new token '[H-]' does not match the expected values.", + ) + + # Verify that '[H-]' was added to the cache + self.assertIn( + "[H-]", + self.reader.cache, + "The new token '[H-]' was not added to the cache.", + ) + # Ensure it's at the correct index + self.assertEqual( + self.reader.cache.index("[H-]"), + index_for_last_token, + "The new token '[H-]' was not added at the correct index in the cache.", + ) + + def test_read_data_with_invalid_input(self) -> None: + """ + Test the _read_data method with an invalid input. + The invalid token should raise an error or be handled appropriately. + """ + raw_data = "%INVALID%" + + with self.assertRaises(ValueError): + self.reader._read_data(raw_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/readers/testDataReader.py b/tests/unit/readers/testDataReader.py new file mode 100644 index 00000000..745c0ace --- /dev/null +++ b/tests/unit/readers/testDataReader.py @@ -0,0 +1,56 @@ +import unittest +from typing import Any, Dict, List + +from chebai.preprocessing.reader import DataReader + + +class TestDataReader(unittest.TestCase): + """ + Unit tests for the DataReader class. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test environment by initializing a DataReader instance. + """ + cls.reader = DataReader() + + def test_to_data(self) -> None: + """ + Test the to_data method to ensure it correctly processes the input row + and formats it according to the expected output. + + This method tests the conversion of raw data into a processed format, + including extracting features, labels, ident, group, and additional + keyword arguments. + """ + features_list: List[int] = [10, 20, 30] + labels_list: List[bool] = [True, False, True] + ident_no: int = 123 + + row: Dict[str, Any] = { + "features": features_list, + "labels": labels_list, + "ident": ident_no, + "group": "group_data", + "additional_kwargs": {"extra_key": "extra_value"}, + } + + expected: Dict[str, Any] = { + "features": features_list, + "labels": labels_list, + "ident": ident_no, + "group": "group_data", + "extra_key": "extra_value", + } + + self.assertEqual( + self.reader.to_data(row), + expected, + "The to_data method did not process the input row as expected.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/readers/testDeepChemDataReader.py b/tests/unit/readers/testDeepChemDataReader.py new file mode 100644 index 00000000..dc29c9a6 --- /dev/null +++ b/tests/unit/readers/testDeepChemDataReader.py @@ -0,0 +1,115 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, DeepChemDataReader + + +class TestDeepChemDataReader(unittest.TestCase): + """ + Unit tests for the DeepChemDataReader class. + + Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. + """ + + @classmethod + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="C\nO\nc\n)", + ) + def setUpClass(cls, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a DeepChemDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + cls.reader = DeepChemDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['C', 'O', 'c', ')'] + assert cls.reader.cache == [ + "C", + "O", + "c", + ")", + ], "Cache initialization did not match expected tokens." + + def test_read_data(self) -> None: + """ + Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string. + """ + raw_data = "c1ccccc1C(Br)(OC)I[Ni-2]" + + # benzene is c1ccccc1 in SMILES but cccccc6 in DeepSMILES + # SMILES C(Br)(OC)I can be converted to the DeepSMILES CBr)OC))I. + # Resultant String: "cccccc6CBr)OC))I[Ni-2]" + # Expected output as per the tokens already in the cache, and new tokens getting added to it. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + 2, # c + EMBEDDING_OFFSET + len(self.reader.cache), # 6 (new token) + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # Br (new token) + EMBEDDING_OFFSET + 3, # ) + EMBEDDING_OFFSET + 1, # O + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 3, # ) + EMBEDDING_OFFSET + 3, # ) + EMBEDDING_OFFSET + len(self.reader.cache) + 2, # I (new token) + EMBEDDING_OFFSET + len(self.reader.cache) + 3, # [Ni-2] (new token) + ] + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output for the SMILES string.", + ) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a SMILES string that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "[H-]" + + # Determine the index for the new token based on the current size of the cache. + index_for_last_token = len(self.reader.cache) + expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected output for a SMILES string with a new token.", + ) + + # Verify that '[H-]' was added to the cache + self.assertIn( + "[H-]", + self.reader.cache, + "The new token '[H-]' was not added to the cache as expected.", + ) + # Ensure it's at the correct index + self.assertEqual( + self.reader.cache.index("[H-]"), + index_for_last_token, + "The new token '[H-]' was not added to the correct index in the cache.", + ) + + def test_read_data_with_invalid_input(self) -> None: + """ + Test the _read_data method with an invalid input string. + The invalid token should raise an error or be handled appropriately. + """ + raw_data = "CBr))(OCI" + + with self.assertRaises(Exception): + self.reader._read_data(raw_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py new file mode 100644 index 00000000..c5bc5e9a --- /dev/null +++ b/tests/unit/readers/testProteinDataReader.py @@ -0,0 +1,139 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader + + +class TestProteinDataReader(unittest.TestCase): + """ + Unit tests for the ProteinDataReader class. + """ + + @classmethod + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="M\nK\nT\nF\nR\nN", + ) + def setUpClass(cls, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a ProteinDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + cls.reader = ProteinDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] + assert cls.reader.cache == [ + "M", + "K", + "T", + "F", + "R", + "N", + ], "Cache initialization did not match expected tokens." + + def test_read_data(self) -> None: + """ + Test the _read_data method with a protein sequence to ensure it correctly tokenizes the sequence. + """ + raw_data = "MKTFFRN" + + # Expected output based on the cached tokens + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # M + EMBEDDING_OFFSET + 1, # K + EMBEDDING_OFFSET + 2, # T + EMBEDDING_OFFSET + 3, # F + EMBEDDING_OFFSET + 3, # F (repeated token) + EMBEDDING_OFFSET + 4, # R + EMBEDDING_OFFSET + 5, # N + ] + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output.", + ) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a protein sequence that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "MKTFY" + + # 'Y' is not in the initial cache and should be added. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # M + EMBEDDING_OFFSET + 1, # K + EMBEDDING_OFFSET + 2, # T + EMBEDDING_OFFSET + 3, # F + EMBEDDING_OFFSET + len(self.reader.cache), # Y (new token) + ] + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle a new token.", + ) + + # Verify that 'Y' was added to the cache + self.assertIn( + "Y", self.reader.cache, "The new token 'Y' was not added to the cache." + ) + # Ensure it's at the correct index + self.assertEqual( + self.reader.cache.index("Y"), + len(self.reader.cache) - 1, + "The new token 'Y' was not added at the correct index in the cache.", + ) + + def test_read_data_with_invalid_token(self) -> None: + """ + Test the _read_data method with an invalid amino acid token to ensure it raises a KeyError. + """ + raw_data = "MKTFZ" # 'Z' is not a valid amino acid token + + with self.assertRaises(KeyError) as context: + self.reader._read_data(raw_data) + + self.assertIn( + "Invalid token 'Z' encountered", + str(context.exception), + "The KeyError did not contain the expected message for an invalid token.", + ) + + def test_read_data_with_empty_sequence(self) -> None: + """ + Test the _read_data method with an empty protein sequence to ensure it returns an empty list. + """ + raw_data = "" + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + [], + "The _read_data method did not return an empty list for an empty input sequence.", + ) + + def test_read_data_with_repeated_tokens(self) -> None: + """ + Test the _read_data method with repeated amino acid tokens to ensure it handles them correctly. + """ + raw_data = "MMMMM" + + expected_output: List[int] = [EMBEDDING_OFFSET + 0] * 5 # All tokens are 'M' + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle repeated tokens.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/readers/testSelfiesReader.py b/tests/unit/readers/testSelfiesReader.py new file mode 100644 index 00000000..411fc63b --- /dev/null +++ b/tests/unit/readers/testSelfiesReader.py @@ -0,0 +1,127 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, SelfiesReader + + +class TestSelfiesReader(unittest.TestCase): + """ + Unit tests for the SelfiesReader class. + + Note: Test methods within a TestCase class are not guaranteed to be executed in any specific order. + """ + + @classmethod + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="[C]\n[O]\n[=C]", + ) + def setUpClass(cls, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a SelfiesReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + cls.reader = SelfiesReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['[C]', '[O]', '[=C]'] + assert cls.reader.cache == [ + "[C]", + "[O]", + "[=C]", + ], "Cache initialization did not match expected tokens." + + def test_read_data(self) -> None: + """ + Test the _read_data method with a SELFIES string to ensure it correctly tokenizes the string. + """ + raw_data = "c1ccccc1C(Br)(OC)I[Ni-2]" + + # benzene is "c1ccccc1" in SMILES is translated to "[C][=C][C][=C][C][=C][Ring1][=Branch1]" in SELFIES + # SELFIES translation of SMILES "c1ccccc1C(Br)(OC)I[Ni-2]": + # "[C][=C][C][=C][C][=C][Ring1][=Branch1][C][Branch1][C][Br][Branch1][Ring1][O][C][I][Ni-2]" + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + 2, # [=C] (already in cache) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + 2, # [=C] (already in cache) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + 2, # [=C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache), # [Ring1] (new token) + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [=Branch1] (new token) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache) + 2, # [Branch1] (new token) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache) + 3, # [Br] (new token) + EMBEDDING_OFFSET + + len(self.reader.cache) + + 2, # [Branch1] (reused new token) + EMBEDDING_OFFSET + len(self.reader.cache), # [Ring1] (reused new token) + EMBEDDING_OFFSET + 1, # [O] (already in cache) + EMBEDDING_OFFSET + 0, # [C] (already in cache) + EMBEDDING_OFFSET + len(self.reader.cache) + 4, # [I] (new token) + EMBEDDING_OFFSET + len(self.reader.cache) + 5, # [Ni-2] (new token) + ] + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output.", + ) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a SELFIES string that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "[H-]" + + # Determine the index for the new token based on the current size of the cache. + index_for_last_token = len(self.reader.cache) + expected_output: List[int] = [EMBEDDING_OFFSET + index_for_last_token] + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle a new token.", + ) + + # Verify that '[H-1]' was added to the cache, "[H-]" translated to "[H-1]" in SELFIES + self.assertIn( + "[H-1]", + self.reader.cache, + "The new token '[H-1]' was not added to the cache.", + ) + # Ensure it's at the correct index + self.assertEqual( + self.reader.cache.index("[H-1]"), + index_for_last_token, + "The new token '[H-1]' was not added at the correct index in the cache.", + ) + + def test_read_data_with_invalid_selfies(self) -> None: + """ + Test the _read_data method with an invalid SELFIES string to ensure error handling works. + """ + raw_data = "[C][O][INVALID][N]" + + result = self.reader._read_data(raw_data) + self.assertIsNone( + result, + "The _read_data method did not return None for an invalid SELFIES string.", + ) + + # Verify that the error count was incremented + self.assertEqual( + self.reader.error_count, + 1, + "The error count was not incremented for an invalid SELFIES string.", + ) + + +if __name__ == "__main__": + unittest.main()