
Commit 60263b6

Merge pull request #14 from ghanvert/full-rewrite
Full rewrite
2 parents cffe8e7 + 808209e commit 60263b6

19 files changed: +1400 −1067 lines

examples/train.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def validation_step(self, batch):
         batch_size=(2, 1, 1),
         # 'batch_size' can be an integer value or tuple, where elements are: (train_batch_size, val_batch_size).
         # Use an integer value to set the batch size equally for both sets.
-        optim=Optimizer.AdamW,
+        optimizer=Optimizer.AdamW,
         optim_kwargs={"lr": 0.001, "weight_decay": 0.01},  # Optimizer arguments as dictionary
         scheduler=Scheduler.LinearWithWarmup,
         scheduler_kwargs={"warmup_ratio": 0.03},  # Scheduler arguments as dictionary

examples/train_without_comments.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def validation_step(self, batch):
     hps_config=HyperParameters(
         epochs=2,
         batch_size=(2, 1, 1),
-        optim=Optimizer.AdamW,
+        optimizer=Optimizer.AdamW,
         optim_kwargs={"lr": 0.001, "weight_decay": 0.01},
         scheduler=Scheduler.LinearWithWarmup,
         scheduler_kwargs={"warmup_ratio": 0.03},
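
The user-facing change in both example scripts is the renamed keyword argument: `optim` becomes `optimizer`. A minimal sketch of the updated `HyperParameters` call (the surrounding `Trainer` setup is assumed and not shown in this diff):

```python
from accmt import HyperParameters, Optimizer, Scheduler

# 'optimizer' replaces the old 'optim' keyword; 'optim_kwargs' keeps its name.
hps = HyperParameters(
    epochs=2,
    batch_size=(2, 1, 1),
    optimizer=Optimizer.AdamW,
    optim_kwargs={"lr": 0.001, "weight_decay": 0.01},
    scheduler=Scheduler.LinearWithWarmup,
    scheduler_kwargs={"warmup_ratio": 0.03},
)
```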

setup.cfg

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [metadata]
 name = accmt
-version = 1.7.7
+version = 1.8.0
 author = ghanvert
 author_email = [email protected]
 description = Accelerator Module and Trainer based on Accelerate library for simple distributed train processes, inspired by PyTorch Lightning.

@@ -24,7 +24,7 @@ install_requires =
     accelerate
     numpy
     PyYAML
-    accmt-cli
+    accmt-cli>=1.4.5
    pympler
    numba

src/accmt/__init__.py

Lines changed: 3 additions & 7 deletions
@@ -23,14 +23,13 @@
 from .collate_fns import DataCollatorForLanguageModeling, DataCollatorForLongestSequence, DataCollatorForSeq2Seq
 from .dataloader_samplers import TemperatureSampler
 from .decorators import on_last_process, on_local_main_process, on_local_process, on_main_process, on_process
-from .handlers import Handler
 from .hyperparameters import HyperParameters, Optimizer, Scheduler
 from .modules import AcceleratorModule, ExtendedAcceleratorModule
 from .monitor import Monitor
 from .tracker import Aim, ClearML, CometML, DVCLive, MLFlow, TensorBoard, WandB
-from .trainer import Trainer, set_seed
-from .utility import prepare, prepare_array, prepare_dataframe
-from .utils import _precision_map
+from .trainer import Trainer
+from .utility import IS_CPU, IS_GPU, prepare, prepare_array, prepare_dataframe
+from .utils import _precision_map, get_seed, set_seed


 def allow_tf32(flag=True):

@@ -40,9 +39,6 @@ def allow_tf32(flag=True):

 allow_tf32()

-IS_CPU = bool(int(os.environ.get("ACCMT_CPU", 0)))
-IS_GPU = not IS_CPU
-
 _init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=86400))
 _dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
 accelerator = Accelerator(
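
The package root changes its re-exports: `Handler` is no longer imported, `set_seed` now comes from `accmt.utils` (together with the new `get_seed`), and the `IS_CPU`/`IS_GPU` flags move out of `__init__.py` into `accmt.utility` while remaining importable from the top level. A small sketch of the updated import surface; what `get_seed` returns is an assumption, not shown in this diff:

```python
from accmt import IS_CPU, IS_GPU, get_seed, set_seed

set_seed(42)           # seeding helpers now live in accmt.utils
current = get_seed()   # presumably reports the seed configured above
print(IS_CPU, IS_GPU)  # flags are now defined in accmt.utility
```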

src/accmt/callbacks.py

Lines changed: 90 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.

 from abc import ABC
+from dataclasses import dataclass

 import torch
 from torch.optim import Optimizer

@@ -23,6 +24,7 @@
 from .states import TrainingState


+@dataclass
 class Callback(ABC):
     """
     Callback module containing different callback functions for different

@@ -38,7 +40,7 @@ class Callback(ABC):
         trainer (`Trainer`):
             Defined `Trainer` class.
         state (`TrainingState`):
-            Module's `TrainingState` class.
+            Reference to `TrainingState` class.

     Methods:
         on_fit_start (*optional*):

@@ -274,3 +276,90 @@ def on_evaluation_start(self):
     @override
     def on_evaluation_end(self):
         """Callback when evaluation ends."""
+
+
+# TODO there is a better way to do this, using a decorator like @register_callback("on_fit_start"), but
+# we'll implement that (probably) before release of version 2.0.
+@dataclass
+class CallbackMaster:
+    children: list[Callback]
+
+    def on_fit_start(self):
+        for child in self.children:
+            child.on_fit_start()
+
+    def on_fit_end(self):
+        for child in self.children:
+            child.on_fit_end()
+
+    def on_before_backward(self, loss: torch.Tensor):
+        for child in self.children:
+            child.on_before_backward(loss)
+
+    def on_after_backward(self):
+        for child in self.children:
+            child.on_after_backward()
+
+    def on_before_optimizer_step(self, optimizer: Optimizer):
+        for child in self.children:
+            child.on_before_optimizer_step(optimizer)
+
+    def on_after_optimizer_step(self, optimizer: Optimizer):
+        for child in self.children:
+            child.on_after_optimizer_step(optimizer)
+
+    def on_before_scheduler_step(self, scheduler: LRScheduler):
+        for child in self.children:
+            child.on_before_scheduler_step(scheduler)
+
+    def on_after_scheduler_step(self, scheduler: LRScheduler):
+        for child in self.children:
+            child.on_after_scheduler_step(scheduler)
+
+    def on_before_zero_grad(self, optimizer: Optimizer):
+        for child in self.children:
+            child.on_before_zero_grad(optimizer)
+
+    def on_after_zero_grad(self, optimizer: Optimizer):
+        for child in self.children:
+            child.on_after_zero_grad(optimizer)
+
+    def on_resume(self):
+        for child in self.children:
+            child.on_resume()
+
+    def on_save_checkpoint(self):
+        for child in self.children:
+            child.on_save_checkpoint()
+
+    def on_before_training_step(self, batch: Any):
+        for child in self.children:
+            child.on_before_training_step(batch)
+
+    def on_after_training_step(self):
+        for child in self.children:
+            child.on_after_training_step()
+
+    def on_before_validation_step(self, batch: Any):
+        for child in self.children:
+            child.on_before_validation_step(batch)
+
+    def on_after_validation_step(self):
+        for child in self.children:
+            child.on_after_validation_step()
+
+    def on_epoch_start(self):
+        for child in self.children:
+            child.on_epoch_start()
+
+    def on_epoch_end(self):
+        for child in self.children:
+            child.on_epoch_end()
+
+    def on_evaluation_start(self):
+        for child in self.children:
+            child.on_evaluation_start()
+
+    def on_evaluation_end(self):
+        for child in self.children:
+            child.on_evaluation_end()
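
`CallbackMaster` is a thin fan-out container: it holds a list of `Callback` children and forwards every hook to each child in order. A minimal sketch of how it could be driven, assuming the base `Callback` hooks are no-ops by default and that its `trainer`/`state` attributes are optional (they are populated elsewhere, not in this diff):

```python
import torch

from accmt.callbacks import Callback, CallbackMaster


class LossLogger(Callback):
    def on_before_backward(self, loss: torch.Tensor):
        print(f"loss: {loss.item():.4f}")

    def on_epoch_end(self):
        print("epoch finished")


# The master simply dispatches each hook to all children, in order.
master = CallbackMaster(children=[LossLogger()])
master.on_fit_start()                            # no-op: inherited from the base Callback
master.on_before_backward(torch.tensor(0.1234))  # prints the loss
master.on_epoch_end()                            # prints "epoch finished"
```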

src/accmt/dist_utils.py

Lines changed: 11 additions & 1 deletion
@@ -18,7 +18,17 @@
 import torch.distributed as dist
 import torch.nn.functional as F

-from .utility import WORLD_SIZE
+from .utility import RANK, WORLD_SIZE
+from .utils import time_prefix
+
+
+def rprint(*args, rank: int = 0, add_time_prefix: bool = True, **kwargs):
+    """Print on a specific rank (default is main process)."""
+    if rank == RANK:
+        if add_time_prefix:
+            print("\n", f"{time_prefix()} ", *args, **kwargs, sep="")
+        else:
+            print("\n", *args, **kwargs, sep="")


 def pad_to(tensor: torch.Tensor, maximum: int) -> tuple[torch.Tensor, torch.Tensor]:
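
The new `rprint` helper only prints when the current process matches the requested rank, optionally prefixing the message with `time_prefix()`. A short usage sketch:

```python
from accmt.dist_utils import rprint

# Prints only on the main process (rank 0), with a time prefix by default.
rprint("starting evaluation")

# Print from rank 1 and skip the time prefix.
rprint("shard ready", rank=1, add_time_prefix=False)
```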

src/accmt/hyperparameters.py

Lines changed: 20 additions & 12 deletions
@@ -93,14 +93,14 @@ def __init__(
         self,
         epochs: int = 1,
         batch_size: Union[int, tuple[int]] = 1,
-        optim: Union[str, Optimizer] = "SGD",
+        optimizer: Union[str, Optimizer] = "SGD",
         optim_kwargs: Optional[dict] = None,
         scheduler: Optional[Union[str, Scheduler]] = None,
         scheduler_kwargs: Optional[dict] = None,
     ):
         self.epochs = epochs
         self.batch_size = batch_size
-        self.optim = getattr(Optimizer, optim) if isinstance(optim, str) else optim
+        self.optimizer = getattr(Optimizer, optimizer) if isinstance(optimizer, str) else optimizer
         self._fix_kwargs(optim_kwargs)
         self.optim_kwargs = optim_kwargs if optim_kwargs is not None else {}
         self.scheduler = getattr(Scheduler, scheduler) if isinstance(scheduler, str) else scheduler

@@ -114,10 +114,10 @@ def from_config(cls, config: Union[str, dict]):
         elif "hps" in config:
             config = config["hps"]

-        valid_keys = {"epochs", "batch_size", "optim", "scheduler"}
+        valid_keys = {"epochs", "batch_size", "optimizer", "scheduler"}
         assert all(k in valid_keys for k in config.keys()), "You do not have valid keys. Please check documentation."

-        optimizer = config["optim"]
+        optimizer = config["optimizer"]
         assert "type" in optimizer, "'type' key is required in optimizer."

         scheduler = config["scheduler"] if "scheduler" in config else None

@@ -127,17 +127,19 @@ def from_config(cls, config: Union[str, dict]):
         return HyperParameters(
             epochs=config["epochs"],
             batch_size=config["batch_size"],
-            optim=optimizer["type"],
+            optimizer=optimizer["type"],
             optim_kwargs={k: v for k, v in optimizer.items() if k != "type"} if len(optimizer) > 1 else None,
             scheduler=scheduler["type"] if scheduler is not None else None,
-            scheduler_kwargs={k: v for k, v in scheduler.items() if k != "type"}
-            if scheduler is not None and len(scheduler) > 1
-            else None,
+            scheduler_kwargs=(
+                {k: v for k, v in scheduler.items() if k != "type"}
+                if scheduler is not None and len(scheduler) > 1
+                else None
+            ),
         )

     def to_dict(self) -> dict:
-        optim = self.optim if not isinstance(self.optim, str) else getattr(Optimizer, self.optim, None)
-        assert optim is not None, f"{optim} is not a valid optimizer."
+        optimizer = self.optimizer if not isinstance(self.optimizer, str) else getattr(Optimizer, self.optimizer, None)
+        assert optimizer is not None, f"{optimizer} is not a valid optimizer."
         scheduler = (
             self.scheduler if not isinstance(self.scheduler, str) else getattr(Scheduler, self.scheduler, "INVALID")
         )

@@ -146,7 +148,13 @@ def to_dict(self) -> dict:
         optim_kwargs = self.optim_kwargs if self.optim_kwargs is not None else {}
         schlr_kwargs = self.scheduler_kwargs if self.scheduler_kwargs is not None else {}

-        d = {"hps": {"epochs": self.epochs, "batch_size": self.batch_size, "optim": {"type": optim, **optim_kwargs}}}
+        d = {
+            "hps": {
+                "epochs": self.epochs,
+                "batch_size": self.batch_size,
+                "optimizer": {"type": optimizer, **optim_kwargs},
+            }
+        }

         if self.scheduler is not None:
             d["hps"]["scheduler"] = {"type": scheduler, **schlr_kwargs}

@@ -155,7 +163,7 @@ def to_dict(self) -> dict:

     def get_config(self) -> dict:
         hps = self.to_dict()["hps"]
-        _hps = {"epochs": hps["epochs"], "batch_size": hps["batch_size"], **hps["optim"]}
+        _hps = {"epochs": hps["epochs"], "batch_size": hps["batch_size"], **hps["optimizer"]}
         if "type" in _hps:
             t = _hps["type"]
             _hps["optimizer"] = t if isinstance(t, str) else t.__name__

src/accmt/modules.py

Lines changed: 29 additions & 16 deletions
@@ -16,7 +16,10 @@
 from typing import Optional, Union

 import torch
+import torch.nn as nn
 from accelerate import Accelerator
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
 from typing_extensions import Any, override

 from .states import TrainingState

@@ -68,15 +71,16 @@ class AcceleratorModule(ABC):
         `torch.nn.Module`.
     """

+    accelerator: Accelerator = None
+    state: TrainingState = None
+    device: torch.device = None
     _implemented_collate_fn_train = False
     _implemented_collate_fn_val = False
-    _accelerator: Accelerator = None
-    _log_every: int = 1
     _extended = False
-    state: TrainingState = None
-    device: torch.device = None
-    status_dict: dict = None
-    batch_size: Union[int, tuple[int, int]] = None
+    model: nn.Module = None
+    teacher: Optional[nn.Module] = None
+    optimizer: Optimizer = None
+    scheduler: LRScheduler = None

     @override
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

@@ -87,14 +91,14 @@ def training_step(self, batch: Any) -> torch.Tensor:
         """Defines the training logic. Must return a loss tensor (scalar)."""

     @override
-    def validation_step(self, batch: Any) -> dict:
+    def validation_step(self, key: str, batch: Any) -> dict:
         """
         Defines the validation logic. Must return a dictionary containing
         each metric with predictions and targets, and also the loss value in the dictionary.

         Example:
         ```
-        # format is ==> "metric": (predictions, targets)
+        # format is ==> "metric": (predictions, targets, ...)
         return {
             "loss": validation_loss_tensor, # (scalar tensor)
             # with additional metrics:

@@ -129,18 +133,23 @@ def get_validation_dataloader(self, *args: Any, **kwargs: Any) -> Any:
         """Defines a custom PyTorch DataLoader class for validation."""

     def log(self, values: dict, log_kwargs: dict | None = {}):
-        if self._accelerator.is_main_process:
+        if self.accelerator.is_main_process:
             train_or_eval = "global_step" if self.model.training else "eval_global_step"
             if (self.status_dict[train_or_eval] + 1) % self._log_every == 0:
-                self._accelerator.log(values, step=self.status_dict[train_or_eval], log_kwargs=log_kwargs)
+                self.accelerator.log(values, step=self.status_dict[train_or_eval], log_kwargs=log_kwargs)

     def __init_subclass__(cls, **kwargs):
+        # check training step and validation_step functions
         if (
             cls.training_step == AcceleratorModule.training_step
             and cls.validation_step == AcceleratorModule.validation_step
         ):
-            raise TypeError("Subclasses of 'Trainer' must override 'training_step' and/or 'validation_step' methods.")
+            raise RuntimeError(
+                "Subclasses of 'Trainer' must override 'training_step' and 'validation_step' "
+                "(if evaluation is available)."
+            )

+        # check collate functions
         if cls.collate_fn_train != AcceleratorModule.collate_fn_train:
             cls._implemented_collate_fn_train = True

@@ -244,18 +253,18 @@ def backward(self, loss: torch.Tensor, **kwargs):
             `kwargs` (`Any`):
                 Extra arguments to be passed to 'accelerator.backward' function.
         """
-        self._accelerator.backward(loss, **kwargs)
+        self.accelerator.backward(loss, **kwargs)

     def step_optimizer(self):
-        self.state.optimizer.step()
+        self.optimizer.step()

     def step_scheduler(self):
-        self.state.scheduler.step()
+        self.scheduler.step()

     def step(self):
         """Step optimizer and scheduler (in that order). If there is no scheduler, it will be ignored."""
         self.step_optimizer()
-        if self.state.scheduler is not None:
+        if self.scheduler is not None:
             self.step_scheduler()

     def zero_grad(self, set_to_none: bool = True):

@@ -266,8 +275,12 @@ def zero_grad(self, set_to_none: bool = True):
             `set_to_none` (`bool`, *optional*, defaults to `True`):
                 Set gradients to `None` instead of `0`.
         """
-        self.state.optimizer.zero_grad(set_to_none=set_to_none)
+        self.optimizer.zero_grad(set_to_none=set_to_none)

     @override
     def training_step(self, batch: Any):
         pass
+
+    def __init_subclass__(cls, **kwargs):
+        # No call to super(), so it suppresses the behavior.
+        pass
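
Two things change for user modules: `validation_step` now receives a `key` argument in addition to the batch (the diff does not show what the key carries, so the sketch below just accepts it), and module code uses the public `self.accelerator`, `self.optimizer` and `self.scheduler` attributes instead of the old `self._accelerator`/`self.state.*` access. A minimal subclass, assuming the `Trainer` still picks up `self.model` as in the existing examples:

```python
import torch.nn as nn
import torch.nn.functional as F

from accmt import AcceleratorModule


class DummyModule(AcceleratorModule):
    def __init__(self):
        self.model = nn.Linear(10, 2)

    def training_step(self, batch):
        x, y = batch
        return F.cross_entropy(self.model(x), y)

    # New signature: a 'key' is passed along with the batch.
    def validation_step(self, key, batch):
        x, y = batch
        logits = self.model(x)
        # format is ==> "metric": (predictions, targets, ...)
        return {
            "loss": F.cross_entropy(logits, y),
            "accuracy": (logits.argmax(dim=-1), y),
        }
```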
