 from typing import Optional, Union

 import torch
+import torch.nn as nn
 from accelerate import Accelerator
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
 from typing_extensions import Any, override

 from .states import TrainingState
@@ -68,15 +71,16 @@ class AcceleratorModule(ABC):
     `torch.nn.Module`.
     """

+    accelerator: Accelerator = None
+    state: TrainingState = None
+    device: torch.device = None
     _implemented_collate_fn_train = False
     _implemented_collate_fn_val = False
-    _accelerator: Accelerator = None
-    _log_every: int = 1
     _extended = False
-    state: TrainingState = None
-    device: torch.device = None
-    status_dict: dict = None
-    batch_size: Union[int, tuple[int, int]] = None
+    model: nn.Module = None
+    teacher: Optional[nn.Module] = None
+    optimizer: Optimizer = None
+    scheduler: LRScheduler = None

     @override
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
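The hunk above promotes the trainer-wired objects (`accelerator`, `state`, `device`, `model`, `optimizer`, `scheduler`, ...) from private or `state`-nested fields to public class attributes. A minimal sketch of what this enables in a subclass, assuming the trainer assigns these attributes before training starts (the `MyModule` class and its body are illustrative, not part of the commit):

```
import torch.nn.functional as F

class MyModule(AcceleratorModule):
    def training_step(self, batch):
        x, y = batch
        logits = self.model(x)             # public `model` attribute
        loss = F.cross_entropy(logits, y)
        # public `scheduler` attribute; get_last_lr() is standard LRScheduler API
        self.log({"lr": self.scheduler.get_last_lr()[0]})
        return loss
```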
@@ -87,14 +91,14 @@ def training_step(self, batch: Any) -> torch.Tensor:
8791 """Defines the training logic. Must return a loss tensor (scalar)."""
8892
8993 @override
90- def validation_step (self , batch : Any ) -> dict :
94+ def validation_step (self , key : str , batch : Any ) -> dict :
9195 """
9296 Defines the validation logic. Must return a dictionary containing
9397 each metric with predictions and targets, and also the loss value in the dictionary.
9498
9599 Example:
96100 ```
97- # format is ==> "metric": (predictions, targets)
101+ # format is ==> "metric": (predictions, targets, ... )
98102 return {
99103 "loss": validation_loss_tensor, # (scalar tensor)
100104 # with additional metrics:
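A hedged sketch of the new `validation_step` signature. The purpose of `key` is not spelled out in this hunk; it presumably names the validation dataloader being evaluated, which is an assumption here:

```
import torch.nn.functional as F
from typing_extensions import Any

# Hypothetical override inside an AcceleratorModule subclass.
def validation_step(self, key: str, batch: Any) -> dict:
    x, y = batch
    logits = self.model(x)
    preds = logits.argmax(dim=-1)
    return {
        "loss": F.cross_entropy(logits, y),  # scalar tensor
        "accuracy": (preds, y),              # "metric": (predictions, targets, ...)
    }
```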
@@ -129,18 +133,23 @@ def get_validation_dataloader(self, *args: Any, **kwargs: Any) -> Any:
129133 """Defines a custom PyTorch DataLoader class for validation."""
130134
131135 def log (self , values : dict , log_kwargs : dict | None = {}):
132- if self ._accelerator .is_main_process :
136+ if self .accelerator .is_main_process :
133137 train_or_eval = "global_step" if self .model .training else "eval_global_step"
134138 if (self .status_dict [train_or_eval ] + 1 ) % self ._log_every == 0 :
135- self ._accelerator .log (values , step = self .status_dict [train_or_eval ], log_kwargs = log_kwargs )
139+ self .accelerator .log (values , step = self .status_dict [train_or_eval ], log_kwargs = log_kwargs )
136140
137141 def __init_subclass__ (cls , ** kwargs ):
142+ # check training step and validation_step functions
138143 if (
139144 cls .training_step == AcceleratorModule .training_step
140145 and cls .validation_step == AcceleratorModule .validation_step
141146 ):
142- raise TypeError ("Subclasses of 'Trainer' must override 'training_step' and/or 'validation_step' methods." )
147+ raise RuntimeError (
148+ "Subclasses of 'Trainer' must override 'training_step' and 'validation_step' "
149+ "(if evaluation is available)."
150+ )
143151
152+ # check collate functions
144153 if cls .collate_fn_train != AcceleratorModule .collate_fn_train :
145154 cls ._implemented_collate_fn_train = True
146155
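Because `__init_subclass__` runs when a subclass is defined, the check above fires at class-definition time, not at instantiation. An illustrative demonstration (the `Empty` class is hypothetical):

```
try:
    class Empty(AcceleratorModule):  # overrides neither hook
        pass
except RuntimeError as err:
    print(err)  # "Subclasses of 'Trainer' must override 'training_step' ..."
```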
@@ -244,18 +253,18 @@ def backward(self, loss: torch.Tensor, **kwargs):
         `kwargs` (`Any`):
             Extra arguments to be passed to 'accelerator.backward' function.
         """
-        self._accelerator.backward(loss, **kwargs)
+        self.accelerator.backward(loss, **kwargs)

     def step_optimizer(self):
-        self.state.optimizer.step()
+        self.optimizer.step()

     def step_scheduler(self):
-        self.state.scheduler.step()
+        self.scheduler.step()

     def step(self):
         """Step optimizer and scheduler (in that order). If there is no scheduler, it will be ignored."""
         self.step_optimizer()
-        if self.state.scheduler is not None:
+        if self.scheduler is not None:
             self.step_scheduler()

     def zero_grad(self, set_to_none: bool = True):
@@ -266,8 +275,12 @@ def zero_grad(self, set_to_none: bool = True):
         `set_to_none` (`bool`, *optional*, defaults to `True`):
             Set gradients to `None` instead of `0`.
         """
-        self.state.optimizer.zero_grad(set_to_none=set_to_none)
+        self.optimizer.zero_grad(set_to_none=set_to_none)

     @override
     def training_step(self, batch: Any):
         pass
+
+    def __init_subclass__(cls, **kwargs):
+        # No call to super(), so the base-class override checks are suppressed.
+        pass
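Taken together, `backward`, `step`, and `zero_grad` now resolve through the public attributes rather than `self.state`. A minimal sketch of a manual optimization step built on these wrappers, assuming the trainer leaves gradient handling to the module in this mode (illustrative, not from the commit):

```
import torch.nn.functional as F

def training_step(self, batch):
    x, y = batch
    loss = F.mse_loss(self.model(x), y)
    self.backward(loss)  # delegates to accelerator.backward(loss)
    self.step()          # optimizer.step(), then scheduler.step() if one is set
    self.zero_grad()     # optimizer.zero_grad(set_to_none=True)
    return loss
```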