Skip to content

Commit 244a6f2

Browse files
committed
docs: add method typing
1 parent 844c107 commit 244a6f2

File tree

6 files changed

+118
-113
lines changed

6 files changed

+118
-113
lines changed

neuralnetlib/callbacks.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,48 @@
11
from neuralnetlib.metrics import Metric
2+
import numpy as np
23

34

45
class Callback:
5-
def on_train_begin(self, logs=None):
6+
def on_train_begin(self, logs: dict | None = None) -> None:
67
pass
78

8-
def on_train_end(self, logs=None):
9+
def on_train_end(self, logs: dict | None = None) -> None:
910
pass
1011

11-
def on_epoch_begin(self, epoch, logs=None):
12+
def on_epoch_begin(self, epoch: int, logs: dict | None = None) -> None:
1213
pass
1314

14-
def on_epoch_end(self, epoch, logs=None):
15+
def on_epoch_end(self, epoch: int, logs: dict | None = None) -> None:
1516
pass
1617

17-
def on_batch_begin(self, batch, logs=None):
18+
def on_batch_begin(self, batch: int, logs: dict | None = None) -> None:
1819
pass
1920

20-
def on_batch_end(self, batch, logs=None):
21+
def on_batch_end(self, batch: int, logs: dict | None = None) -> None:
2122
pass
2223

2324

2425
class EarlyStopping(Callback):
25-
def __init__(self, patience: int = 5, min_delta: float = 0.001, restore_best_weights: bool = True,
26-
start_from_epoch: int = 0, monitor: str = 'loss', mode: str = 'auto', baseline: float = None):
26+
def __init__(self, patience: int = 5, min_delta: float = 0.001, restore_best_weights: bool = True,start_from_epoch: int = 0, monitor: str = 'loss', mode: str = 'auto', baseline: float | None = None) -> None:
2727
super().__init__()
28-
self.patience = patience
29-
self.min_delta = min_delta
30-
self.restore_best_weights = restore_best_weights
31-
self.start_from_epoch = start_from_epoch
32-
self.monitor = Metric(monitor) if monitor != 'loss' else 'loss'
33-
self.mode = mode
34-
self.baseline = baseline
35-
self.best_weights = None
36-
self.best_metric = None
37-
self.patience_counter = 0
38-
self.stop_training = False
39-
40-
def on_train_begin(self, logs=None):
28+
self.patience: int = patience
29+
self.min_delta: float = min_delta
30+
self.restore_best_weights: bool = restore_best_weights
31+
self.start_from_epoch: int = start_from_epoch
32+
self.monitor: Metric | str = Metric(monitor) if monitor != 'loss' else 'loss'
33+
self.mode: str = mode
34+
self.baseline: float | None = baseline
35+
self.best_weights: list | None = None
36+
self.best_metric: float | None = None
37+
self.patience_counter: int = 0
38+
self.stop_training: bool = False
39+
40+
    def on_train_begin(self, logs: dict | None = None) -> None:
        """Reset early-stopping state so the instance can be reused across fits.

        Clears the patience counter, the best metric seen so far, and the
        stop flag; called by the training loop before the first epoch.
        """
        self.patience_counter = 0
        self.best_metric = None
        self.stop_training = False
4444

45-
def on_epoch_end(self, epoch, logs=None):
45+
def on_epoch_end(self, epoch: int, logs: dict | None = None) -> bool:
4646
logs = logs or {}
4747
model = logs.get('model')
4848
if epoch < self.start_from_epoch or model is None:
@@ -72,14 +72,14 @@ def on_epoch_end(self, epoch, logs=None):
7272
self.stop_training = True
7373
if self.restore_best_weights and self.best_weights is not None:
7474
for layer, best_weights in zip([layer for layer in model.layers if hasattr(layer, 'weights')],
75-
self.best_weights):
75+
self.best_weights):
7676
layer.weights = best_weights
7777
print(f"\nEarly stopping triggered after epoch {epoch + 1}")
7878
return True
7979

8080
return False
8181

82-
def _get_monitor_value(self, logs):
82+
def _get_monitor_value(self, logs: dict) -> float:
8383
logs = logs or {}
8484
if isinstance(self.monitor, Metric):
8585
monitor_value = logs.get(self.monitor.name)
@@ -93,5 +93,5 @@ def _get_monitor_value(self, logs):
9393
monitor_value = logs
9494
else:
9595
raise ValueError(f"Monitored metric '{self.monitor}' is not available. "
96-
f"Available metrics are: {','.join(logs.keys())}")
97-
return monitor_value
96+
f"Available metrics are: {','.join(logs.keys())}")
97+
return float(monitor_value)

0 commit comments

Comments
 (0)