Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .
# wandb is installed so the CLI test can use the custom trainer, which requires the wandb logger
python -m pip install -e .[wandb]

- name: Display Python & Installed Packages
run: |
Expand Down
3 changes: 1 addition & 2 deletions chebai/models/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class FFN(ChebaiBaseNet):

def __init__(
self,
input_size: int,
hidden_layers: List[int] = [
1024,
],
Expand All @@ -20,7 +19,7 @@ def __init__(
super().__init__(**kwargs)

layers = []
current_layer_input_size = input_size
current_layer_input_size = self.input_dim
for hidden_dim in hidden_layers:
layers.append(MLPBlock(current_layer_input_size, hidden_dim))
layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
Expand Down
46 changes: 46 additions & 0 deletions chebai/preprocessing/datasets/mock_dm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
from lightning.pytorch.core.datamodule import LightningDataModule
from torch.utils.data import DataLoader

from chebai.preprocessing.collate import RaggedCollator


class MyLightningDataModule(LightningDataModule):
    """Minimal in-memory LightningDataModule used as a mock for CLI tests.

    Produces a small random dataset (100 samples of random features with
    random boolean labels) so a full ``fit`` run can be exercised without
    touching real data. Dataset dimensions are fixed in :meth:`setup`.
    """

    def __init__(self):
        super().__init__()
        # Populated in setup(); None signals that setup() has not run yet.
        self._num_of_labels = None
        self._feature_vector_size = None
        # Collator that batches the ragged per-sample dicts for the DataLoader.
        self.collator = RaggedCollator()

    def prepare_data(self):
        # Nothing to download or preprocess for this mock module.
        pass

    def setup(self, stage=None):
        """Fix the mock dataset dimensions (called by Lightning per stage)."""
        self._num_of_labels = 10
        self._feature_vector_size = 20
        print(f"Number of labels: {self._num_of_labels}")
        print(f"Number of features: {self._feature_vector_size}")

    @property
    def num_of_labels(self):
        """Number of label columns per sample (``None`` until setup() runs)."""
        return self._num_of_labels

    @property
    def feature_vector_size(self):
        """Length of each feature vector (``None`` until setup() runs)."""
        return self._feature_vector_size

    def train_dataloader(self):
        """Return a DataLoader over 100 randomly generated samples.

        Raises:
            RuntimeError: if :meth:`setup` has not been called yet.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, so they must not carry runtime validation.
        if self._feature_vector_size is None or self._num_of_labels is None:
            raise RuntimeError(
                "setup() must be called before train_dataloader()"
            )

        # Dummy dataset for example purposes
        datalist = [
            {
                "features": torch.randn(self._feature_vector_size),
                "labels": torch.randint(0, 2, (self._num_of_labels,), dtype=torch.bool),
                "ident": i,
                "group": None,
            }
            for i in range(100)
        ]

        return DataLoader(datalist, batch_size=32, collate_fn=self.collator)
1 change: 0 additions & 1 deletion configs/model/ffn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ class_path: chebai.models.ffn.FFN
init_args:
optimizer_kwargs:
lr: 1e-3
input_size: 2560
Empty file added tests/unit/cli/__init__.py
Empty file.
1 change: 1 addition & 0 deletions tests/unit/cli/mock_dm_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.mock_dm.MyLightningDataModule
35 changes: 35 additions & 0 deletions tests/unit/cli/testCLI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import unittest

from chebai.cli import ChebaiCLI


class TestChebaiCLI(unittest.TestCase):
    """Smoke test: drive a one-epoch ``fit`` through the Chebai CLI.

    Uses the FFN model with a tiny hidden layer and the mock in-memory
    datamodule so the whole CLI wiring (parser, configs, trainer) is
    exercised quickly without real data.
    """

    def setUp(self):
        """Assemble CLI arguments for a minimal one-epoch training run."""
        self.cli_args = [
            "fit",
            "--trainer=configs/training/default_trainer.yml",
            "--model=configs/model/ffn.yml",
            "--model.init_args.hidden_layers=[10]",
            "--model.train_metrics=configs/metrics/micro-macro-f1.yml",
            "--data=tests/unit/cli/mock_dm_config.yml",
            "--model.pass_loss_kwargs=false",
            "--trainer.min_epochs=1",
            "--trainer.max_epochs=1",
            "--model.criterion=configs/loss/bce.yml",
            "--model.criterion.init_args.beta=0.99",
        ]

    def test_mlp_on_chebai_cli(self):
        """Instantiating ChebaiCLI runs ``fit``; it must complete without error."""
        # No try/except wrapper: an unexpected exception already fails the
        # test (as an error) with the full traceback, which is more
        # informative than collapsing it into a self.fail() message.
        ChebaiCLI(
            args=self.cli_args,
            save_config_kwargs={"config_filename": "lightning_config.yaml"},
            parser_kwargs={"parser_mode": "omegaconf"},
        )


# Allow running this test module directly: `python tests/unit/cli/testCLI.py`.
if __name__ == "__main__":
    unittest.main()
Loading