@@ -21,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
2121
2222 def __init__ (
2323 self ,
24- beta : Optional [ float ] = None ,
24+ beta : float = 0.99 ,
2525 data_extractor : Optional [XYBaseDataModule ] = None ,
2626 ** kwargs ,
2727 ):
@@ -32,15 +32,20 @@ def __init__(
3232 data_extractor = data_extractor .labeled
3333 self .data_extractor = data_extractor
3434
35- assert self .beta is not None and self .data_extractor is not None , (
36- f"Beta parameter must be provided along with data_extractor, "
37- f"if this loss class ({ self .__class__ .__name__ } ) is used."
35+ assert isinstance (beta , float ) and beta > 0.0 , (
36+ f"Beta parameter must be a float with value greater than 0.0, for loss class { self .__class__ .__name__ } ."
37+ )
38+
39+ assert self .data_extractor is not None , (
40+ f"Data extractor must be provided if this loss class ({ self .__class__ .__name__ } ) is used."
3841 )
3942
4043 assert all (
4144 os .path .exists (os .path .join (self .data_extractor .processed_dir , file_name ))
4245 for file_name in self .data_extractor .processed_file_names
43- ), "Dataset files not found. Make sure the dataset is processed before using this loss."
46+ ), (
47+ "Dataset files not found. Make sure the dataset is processed before using this loss."
48+ )
4449
4550 assert (
4651 isinstance (self .data_extractor , _ChEBIDataExtractor )
0 commit comments