-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_moco.py
More file actions
79 lines (67 loc) · 2.34 KB
/
train_moco.py
File metadata and controls
79 lines (67 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from models import MoCo, ResNetEncoder
from transforms import ContrastiveTransform
def parse_args(argv=None):
    """Parse the command-line options of the MoCo training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which is exactly what the zero-argument call
            did before this parameter existed.

    Returns:
        An ``argparse.Namespace`` with the attributes ``batch_size``,
        ``epochs``, ``lr``, ``queue_size``, ``momentum`` and ``save_path``.
    """
    parser = argparse.ArgumentParser(
        description="MoCo pre-training script.",
        # Show each option's default in --help without repeating it by hand.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--batch_size", type=int, default=64,
                        help="mini-batch size used by the training DataLoader")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of passes over the training set")
    parser.add_argument("--lr", type=float, default=0.03,
                        help="learning rate of the SGD optimizer")
    parser.add_argument("--queue_size", type=int, default=4096,
                        help="value passed to MoCo as K")
    parser.add_argument("--momentum", type=float, default=0.99,
                        help="value passed to MoCo as m "
                             "(not the SGD momentum)")
    parser.add_argument("--save_path", type=str, default="moco.pth",
                        help="file the trained weights are written to")
    # Forwarding argv lets callers and tests supply their own argument list.
    return parser.parse_args(argv)
def main():
    """Train a MoCo model on the CIFAR-10 training split and save its weights.

    Hyper-parameters come from ``parse_args()``. The model is trained with SGD
    for ``args.epochs`` epochs, the mean batch loss is printed after every
    epoch, and once training finishes the ``state_dict`` of
    ``moco_model.encoder_q`` is written to ``args.save_path``.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1) DataLoader
    train_transform = ContrastiveTransform(base_size=32)
    train_dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        # Pass the callable object itself instead of wrapping it in a lambda.
        # With num_workers > 0 the dataset is pickled for each worker when the
        # multiprocessing start method is "spawn" (Windows, macOS), and a
        # lambda defined inside main() cannot be pickled.
        # NOTE(review): assumes ContrastiveTransform instances are picklable
        # (class defined at module level in transforms) - confirm.
        transform=train_transform,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )

    # model, optimizer, loss
    moco_model = MoCo(
        base_encoder=lambda out_dim: ResNetEncoder(base="resnet18", out_dim=out_dim),
        dim=128,
        K=args.queue_size,
        m=args.momentum,
        T=0.2,
    ).to(device)
    # Only encoder_q's parameters are handed to the optimizer.
    optimizer = optim.SGD(
        moco_model.encoder_q.parameters(),
        lr=args.lr,
        momentum=0.9,
        weight_decay=1e-4,
    )
    criterion = torch.nn.CrossEntropyLoss()

    # train
    for epoch in range(args.epochs):
        moco_model.train()
        total_loss = 0.0
        # Each batch unpacks into a pair of image tensors plus the CIFAR-10
        # labels; the labels are discarded.
        for (im_q, im_k), _ in train_loader:
            im_q = im_q.to(device)
            im_k = im_k.to(device)

            logits, labels = moco_model(im_q, im_k)  # [B, 1+K], [B]
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Mean of the per-batch losses over this epoch.
        avg_loss = total_loss / len(train_loader)
        print(f"[MoCo] Epoch [{epoch+1}/{args.epochs}] Loss: {avg_loss:.4f}")

    # Persist only the encoder_q weights, under the "encoder_q" key.
    state = {
        "encoder_q": moco_model.encoder_q.state_dict(),
    }
    torch.save(state, args.save_path)
    print(f"MoCo model saved to {args.save_path}")
# Run training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()