Skip to content

Commit f2d62fd

Browse files
authored
Merge pull request #99 from ChEB-AI/feature/pyproject.toml
Upgrade to `pyproject.toml` file
2 parents 258e65a + 35444f3 commit f2d62fd

40 files changed

+484
-458
lines changed

.github/workflows/black.yml

Lines changed: 0 additions & 10 deletions
This file was deleted.

.github/workflows/lint.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Lint
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
lint:
7+
runs-on: ubuntu-latest
8+
9+
steps:
10+
- uses: actions/checkout@v2
11+
12+
- name: Set up Python
13+
uses: actions/setup-python@v4
14+
with:
15+
python-version: '3.10' # or any version your project uses
16+
17+
- name: Install dependencies
18+
run: |
19+
python -m pip install --upgrade pip
20+
pip install black==25.1.0 ruff==0.12.2
21+
22+
- name: Run Black
23+
run: black --check .
24+
25+
- name: Run Ruff (no formatting)
26+
run: ruff check . --no-fix

.pre-commit-config.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: "24.2.0"
3+
rev: "25.1.0"
44
hooks:
55
- id: black
66
- id: black-jupyter # for formatting jupyter-notebook
@@ -23,3 +23,9 @@ repos:
2323
- id: check-yaml
2424
- id: end-of-file-fixer
2525
- id: trailing-whitespace
26+
27+
- repo: https://github.com/astral-sh/ruff-pre-commit
28+
rev: v0.12.2
29+
hooks:
30+
- id: ruff
31+
args: [--fix]

README.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@ The library emphasizes the incorporation of the semantic qualities of the ontolo
55

66
## Installation
77

8-
To install ChEBai, follow these steps:
8+
You can install ChEBai via pip:
9+
```
10+
pip install chebai
11+
```
12+
13+
Alternatively, you can get the latest development version directly from GitHub:
914

1015
1. Clone the repository:
1116
```
@@ -16,8 +21,26 @@ git clone https://github.com/ChEB-AI/python-chebai.git
1621

1722
```
1823
cd python-chebai
19-
pip install .
24+
pip install -e .
25+
```
26+
27+
Some packages are not installed by default:
28+
```
29+
pip install chebai[dev]
30+
```
31+
installs additional packages useful to people who want to contribute to the library.
32+
```
33+
pip install chebai[plot]
34+
```
35+
installs additional packages useful for plotting and visualisation.
36+
```
37+
pip install chebai[wandb]
38+
```
39+
installs the [Weights & Biases](https://wandb.ai) integration for automated logging of training runs.
40+
```
41+
pip install chebai[all]
2042
```
43+
installs all optional dependencies.
2144

2245
## Usage
2346

chebai/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def write_on_epoch_end(
8080
else:
8181
labels = [None for _ in idents]
8282
output = torch.sigmoid(p["output"]["logits"]).tolist()
83-
for i, l, o in zip(idents, labels, output):
83+
for i, l, o in zip(idents, labels, output): # noqa: E741
8484
pred_list.append(dict(ident=i, labels=l, predictions=o))
8585
with open(os.path.join(self.output_dir, self.target_file), "wt") as fout:
8686
json.dump(pred_list, fout)

chebai/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH:
6868
Returns:
6969
_PATH: The resolved checkpoint directory path.
7070
"""
71-
rank_zero_info(f"Resolving checkpoint dir (custom)")
71+
rank_zero_info("Resolving checkpoint dir (custom)")
7272
if self.dirpath is not None:
7373
# short circuit if dirpath was passed to ModelCheckpoint
7474
return self.dirpath

chebai/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def cli():
9191
"""
9292
Main function to instantiate and run the ChebaiCLI.
9393
"""
94-
r = ChebaiCLI(
94+
ChebaiCLI(
9595
save_config_kwargs={"config_filename": "lightning_config.yaml"},
9696
parser_kwargs={"parser_mode": "omegaconf"},
9797
)

chebai/loss/bce_weighted.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
from typing import Optional
33

4-
import pandas as pd
54
import torch
65

76
from chebai.preprocessing.datasets.base import XYBaseDataModule

chebai/loss/semantic.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
229229
return total_loss.mean(), loss_components
230230

231231
def _calculate_implication_loss(
232-
self, l: torch.Tensor, r: torch.Tensor
232+
self, l_: torch.Tensor, r: torch.Tensor
233233
) -> torch.Tensor:
234234
"""
235235
Calculate implication loss based on T-norm and other parameters.
@@ -241,17 +241,17 @@ def _calculate_implication_loss(
241241
Returns:
242242
torch.Tensor: Calculated implication loss.
243243
"""
244-
assert not l.isnan().any(), (
245-
f"l contains NaN values - l.shape: {l.shape}, l.isnan().sum(): {l.isnan().sum()}, "
246-
f"l: {l}"
244+
assert not l_.isnan().any(), (
245+
f"l contains NaN values - l.shape: {l_.shape}, l.isnan().sum(): {l_.isnan().sum()}, "
246+
f"l: {l_}"
247247
)
248248
assert not r.isnan().any(), (
249249
f"r contains NaN values - r.shape: {r.shape}, r.isnan().sum(): {r.isnan().sum()}, "
250250
f"r: {r}"
251251
)
252252
if self.pos_scalar != 1:
253-
l = (
254-
torch.pow(l + self.eps, 1 / self.pos_scalar)
253+
l_ = (
254+
torch.pow(l_ + self.eps, 1 / self.pos_scalar)
255255
- math.pow(self.eps, 1 / self.pos_scalar)
256256
) / (
257257
math.pow(1 + self.eps, 1 / self.pos_scalar)
@@ -269,21 +269,21 @@ def _calculate_implication_loss(
269269
# for each implication I, calculate 1 - I(l, 1-one_min_r)
270270
# for S-implications, this is equivalent to the t-norm
271271
if self.fuzzy_implication in ["reichenbach", "rc"]:
272-
individual_loss = l * one_min_r
272+
individual_loss = l_ * one_min_r
273273
# xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach
274274
# implication
275275
elif self.fuzzy_implication == "xu19":
276-
individual_loss = -torch.log(1 - l * one_min_r)
276+
individual_loss = -torch.log(1 - l_ * one_min_r)
277277
elif self.fuzzy_implication in ["lukasiewicz", "lk"]:
278-
individual_loss = torch.relu(l + one_min_r - 1)
278+
individual_loss = torch.relu(l_ + one_min_r - 1)
279279
elif self.fuzzy_implication in ["kleene_dienes", "kd"]:
280-
individual_loss = torch.min(l, 1 - r)
280+
individual_loss = torch.min(l_, 1 - r)
281281
elif self.fuzzy_implication in ["goedel", "g"]:
282-
individual_loss = torch.where(l <= r, 0, one_min_r)
282+
individual_loss = torch.where(l_ <= r, 0, one_min_r)
283283
elif self.fuzzy_implication in ["reverse-goedel", "rg"]:
284-
individual_loss = torch.where(l <= r, 0, l)
284+
individual_loss = torch.where(l_ <= r, 0, l_)
285285
elif self.fuzzy_implication in ["binary", "b"]:
286-
individual_loss = torch.where(l <= r, 0, 1).to(dtype=l.dtype)
286+
individual_loss = torch.where(l_ <= r, 0, 1).to(dtype=l_.dtype)
287287
else:
288288
raise NotImplementedError(
289289
f"Unknown fuzzy implication {self.fuzzy_implication}"
@@ -453,8 +453,8 @@ def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tenso
453453

454454
def _build_dense_filter(sparse_filter: torch.Tensor, n_labels: int) -> torch.Tensor:
455455
res = torch.zeros((n_labels, n_labels), dtype=torch.bool)
456-
for l, r in sparse_filter:
457-
res[l, r] = True
456+
for l_, r in sparse_filter:
457+
res[l_, r] = True
458458
return res
459459

460460

@@ -511,8 +511,8 @@ def _build_disjointness_filter(
511511
random_labels = torch.randint(0, 2, (10, 997))
512512
for agg in ["sum", "max", "mean", "log-mean"]:
513513
loss.violations_per_cls_aggregator = agg
514-
l = loss(random_preds, random_labels)
515-
print(f"Loss with {agg} aggregation for random input:", l)
514+
l_ = loss(random_preds, random_labels)
515+
print(f"Loss with {agg} aggregation for random input:", l_)
516516

517517
# simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint
518518
loss.implication_filter_l = torch.tensor(
@@ -528,5 +528,5 @@ def _build_disjointness_filter(
528528
labels = [[0, 1, 1, 0], [0, 0, 1, 1]]
529529
for agg in ["sum", "max", "mean", "log-mean"]:
530530
loss.violations_per_cls_aggregator = agg
531-
l = loss(preds, torch.tensor(labels))
532-
print(f"Loss with {agg} aggregation for simple input:", l)
531+
l_ = loss(preds, torch.tensor(labels))
532+
print(f"Loss with {agg} aggregation for simple input:", l_)

chebai/models/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
from chebai.models.base import *
2-
from chebai.models.electra import *
1+
from chebai.models.base import ChebaiBaseNet
2+
from chebai.models.electra import Electra, ElectraPre
3+
4+
__all__ = ["ChebaiBaseNet", "Electra", "ElectraPre"]

0 commit comments

Comments
 (0)