diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ad21157..6325b74d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,8 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade pip setuptools wheel python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - python -m pip install -e . + # wandb included to use custom trainer for cli test which needs wandb logger + python -m pip install -e .[wandb] - name: Display Python & Installed Packages run: | diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py index 70641615..194dffe9 100644 --- a/chebai/models/ffn.py +++ b/chebai/models/ffn.py @@ -11,7 +11,6 @@ class FFN(ChebaiBaseNet): def __init__( self, - input_size: int, hidden_layers: List[int] = [ 1024, ], @@ -20,7 +19,7 @@ def __init__( super().__init__(**kwargs) layers = [] - current_layer_input_size = input_size + current_layer_input_size = self.input_dim for hidden_dim in hidden_layers: layers.append(MLPBlock(current_layer_input_size, hidden_dim)) layers.append(Residual(MLPBlock(hidden_dim, hidden_dim))) diff --git a/chebai/preprocessing/datasets/mock_dm.py b/chebai/preprocessing/datasets/mock_dm.py new file mode 100644 index 00000000..25116e21 --- /dev/null +++ b/chebai/preprocessing/datasets/mock_dm.py @@ -0,0 +1,46 @@ +import torch +from lightning.pytorch.core.datamodule import LightningDataModule +from torch.utils.data import DataLoader + +from chebai.preprocessing.collate import RaggedCollator + + +class MyLightningDataModule(LightningDataModule): + def __init__(self): + super().__init__() + self._num_of_labels = None + self._feature_vector_size = None + self.collator = RaggedCollator() + + def prepare_data(self): + pass + + def setup(self, stage=None): + self._num_of_labels = 10 + self._feature_vector_size = 20 + print(f"Number of labels: {self._num_of_labels}") + print(f"Number of features: {self._feature_vector_size}") + + @property + def num_of_labels(self): + return self._num_of_labels + + @property + def feature_vector_size(self): + return self._feature_vector_size + + def train_dataloader(self): + assert self.feature_vector_size is not None, "feature_vector_size must be set" + # Dummy dataset for example purposes + + datalist = [ + { + "features": torch.randn(self._feature_vector_size), + "labels": torch.randint(0, 2, (self._num_of_labels,), dtype=torch.bool), + "ident": i, + "group": None, + } + for i in range(100) + ] + + return DataLoader(datalist, batch_size=32, collate_fn=self.collator) diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml index ba94a43e..28cfd096 100644 --- a/configs/model/ffn.yml +++ b/configs/model/ffn.yml @@ -2,4 +2,3 @@ class_path: chebai.models.ffn.FFN init_args: optimizer_kwargs: lr: 1e-3 - input_size: 2560 diff --git a/tests/unit/cli/__init__.py b/tests/unit/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/cli/mock_dm_config.yml b/tests/unit/cli/mock_dm_config.yml new file mode 100644 index 00000000..850304a2 --- /dev/null +++ b/tests/unit/cli/mock_dm_config.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.mock_dm.MyLightningDataModule diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py new file mode 100644 index 00000000..d76b5a33 --- /dev/null +++ b/tests/unit/cli/testCLI.py @@ -0,0 +1,35 @@ +import unittest + +from chebai.cli import ChebaiCLI + + +class TestChebaiCLI(unittest.TestCase): + def setUp(self): + self.cli_args = [ + "fit", + "--trainer=configs/training/default_trainer.yml", + "--model=configs/model/ffn.yml", + "--model.init_args.hidden_layers=[10]", + "--model.train_metrics=configs/metrics/micro-macro-f1.yml", + "--data=tests/unit/cli/mock_dm_config.yml", + "--model.pass_loss_kwargs=false", + "--trainer.min_epochs=1", + "--trainer.max_epochs=1", + "--model.criterion=configs/loss/bce.yml", + "--model.criterion.init_args.beta=0.99", + ] + + def test_mlp_on_chebai_cli(self): + # Instantiate ChebaiCLI and ensure no exceptions are raised + try: + ChebaiCLI( + args=self.cli_args, + save_config_kwargs={"config_filename": "lightning_config.yaml"}, + parser_kwargs={"parser_mode": "omegaconf"}, + ) + except Exception as e: + self.fail(f"ChebaiCLI raised an unexpected exception: {e}") + + +if __name__ == "__main__": + unittest.main()