We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e200a3b commit 10ae3e2Copy full SHA for 10ae3e2
apps/train.py
@@ -1,15 +1,16 @@
1
import torch
2
3
+from flashmodels import Builder, Trainer, accelerate, arguments
4
+
5
6
def train():
7
torch.manual_seed(101)
8
9
# parse args
- from flashmodels import arguments
10
args = arguments.parse()
11
12
# build model, tokenizer, loader, optimizer and lr_scheduler
13
# and use accelerator to speed up training
- from flashmodels import Builder, Trainer, accelerate
14
builder = Builder(args)
15
model, loader, tokenizer = builder.build_model_dataloader()
16
model, loader = accelerate(model, loader, args)
0 commit comments