Skip to content

Commit e7ed0cd

Browse files
committed
fix beta type checks
1 parent 9844e68 commit e7ed0cd

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,7 @@ lightning_logs
176176
logs
177177
.isort.cfg
178178
/.vscode
179+
180+
*.out
181+
*.err
182+
*.sh

chebai/loss/bce_weighted.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
2121

2222
def __init__(
2323
self,
24-
beta: Optional[float] = None,
24+
beta: float = 0.99,
2525
data_extractor: Optional[XYBaseDataModule] = None,
2626
**kwargs,
2727
):
@@ -32,15 +32,20 @@ def __init__(
3232
data_extractor = data_extractor.labeled
3333
self.data_extractor = data_extractor
3434

35-
assert self.beta is not None and self.data_extractor is not None, (
36-
f"Beta parameter must be provided along with data_extractor, "
37-
f"if this loss class ({self.__class__.__name__}) is used."
35+
assert isinstance(beta, float) and beta > 0.0, (
36+
f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
37+
)
38+
39+
assert self.data_extractor is not None, (
40+
f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
3841
)
3942

4043
assert all(
4144
os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
4245
for file_name in self.data_extractor.processed_file_names
43-
), "Dataset files not found. Make sure the dataset is processed before using this loss."
46+
), (
47+
"Dataset files not found. Make sure the dataset is processed before using this loss."
48+
)
4449

4550
assert (
4651
isinstance(self.data_extractor, _ChEBIDataExtractor)

0 commit comments

Comments
 (0)