11from neuralnetlib .metrics import Metric
2+ import numpy as np
23
34
45class Callback :
5- def on_train_begin (self , logs = None ):
6+ def on_train_begin (self , logs : dict | None = None ) -> None :
67 pass
78
8- def on_train_end (self , logs = None ):
9+ def on_train_end (self , logs : dict | None = None ) -> None :
910 pass
1011
11- def on_epoch_begin (self , epoch , logs = None ):
12+ def on_epoch_begin (self , epoch : int , logs : dict | None = None ) -> None :
1213 pass
1314
14- def on_epoch_end (self , epoch , logs = None ):
15+ def on_epoch_end (self , epoch : int , logs : dict | None = None ) -> None :
1516 pass
1617
17- def on_batch_begin (self , batch , logs = None ):
18+ def on_batch_begin (self , batch : int , logs : dict | None = None ) -> None :
1819 pass
1920
20- def on_batch_end (self , batch , logs = None ):
21+ def on_batch_end (self , batch : int , logs : dict | None = None ) -> None :
2122 pass
2223
2324
2425class EarlyStopping (Callback ):
25- def __init__ (self , patience : int = 5 , min_delta : float = 0.001 , restore_best_weights : bool = True ,
26- start_from_epoch : int = 0 , monitor : str = 'loss' , mode : str = 'auto' , baseline : float = None ):
26+ def __init__ (self , patience : int = 5 , min_delta : float = 0.001 , restore_best_weights : bool = True ,start_from_epoch : int = 0 , monitor : str = 'loss' , mode : str = 'auto' , baseline : float | None = None ) -> None :
2727 super ().__init__ ()
28- self .patience = patience
29- self .min_delta = min_delta
30- self .restore_best_weights = restore_best_weights
31- self .start_from_epoch = start_from_epoch
32- self .monitor = Metric (monitor ) if monitor != 'loss' else 'loss'
33- self .mode = mode
34- self .baseline = baseline
35- self .best_weights = None
36- self .best_metric = None
37- self .patience_counter = 0
38- self .stop_training = False
39-
40- def on_train_begin (self , logs = None ):
28+ self .patience : int = patience
29+ self .min_delta : float = min_delta
30+ self .restore_best_weights : bool = restore_best_weights
31+ self .start_from_epoch : int = start_from_epoch
32+ self .monitor : Metric | str = Metric (monitor ) if monitor != 'loss' else 'loss'
33+ self .mode : str = mode
34+ self .baseline : float | None = baseline
35+ self .best_weights : list | None = None
36+ self .best_metric : float | None = None
37+ self .patience_counter : int = 0
38+ self .stop_training : bool = False
39+
40+ def on_train_begin (self , logs : dict | None = None ) -> None :
4141 self .patience_counter = 0
4242 self .best_metric = None
4343 self .stop_training = False
4444
45- def on_epoch_end (self , epoch , logs = None ):
45+ def on_epoch_end (self , epoch : int , logs : dict | None = None ) -> bool :
4646 logs = logs or {}
4747 model = logs .get ('model' )
4848 if epoch < self .start_from_epoch or model is None :
@@ -72,14 +72,14 @@ def on_epoch_end(self, epoch, logs=None):
7272 self .stop_training = True
7373 if self .restore_best_weights and self .best_weights is not None :
7474 for layer , best_weights in zip ([layer for layer in model .layers if hasattr (layer , 'weights' )],
75- self .best_weights ):
75+ self .best_weights ):
7676 layer .weights = best_weights
7777 print (f"\n Early stopping triggered after epoch { epoch + 1 } " )
7878 return True
7979
8080 return False
8181
82- def _get_monitor_value (self , logs ) :
82+ def _get_monitor_value (self , logs : dict ) -> float :
8383 logs = logs or {}
8484 if isinstance (self .monitor , Metric ):
8585 monitor_value = logs .get (self .monitor .name )
@@ -93,5 +93,5 @@ def _get_monitor_value(self, logs):
9393 monitor_value = logs
9494 else :
9595 raise ValueError (f"Monitored metric '{ self .monitor } ' is not available. "
96- f"Available metrics are: { ',' .join (logs .keys ())} " )
97- return monitor_value
96+ f"Available metrics are: { ',' .join (logs .keys ())} " )
97+ return float ( monitor_value )
0 commit comments