-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_simclr.py
More file actions
72 lines (60 loc) · 2.16 KB
/
train_simclr.py
File metadata and controls
72 lines (60 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from models import ResNetEncoder
from transforms import ContrastiveTransform
from utils import NTXentLoss
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for SimCLR pre-training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, which is the original
            behaviour. Passing a list lets callers and tests supply arguments
            programmatically.

    Returns:
        An ``argparse.Namespace`` with these attributes: ``batch_size``
        (int), ``epochs`` (int), ``lr`` (float), ``temperature`` (float) and
        ``save_path`` (str).
    """
    parser = argparse.ArgumentParser(description="SimCLR contrastive pre-training.")
    parser.add_argument(
        "--batch_size", type=int, default=64,
        help="images per batch; each contributes two augmented views (default: %(default)s)",
    )
    parser.add_argument(
        "--epochs", type=int, default=50,
        help="number of passes over the training set (default: %(default)s)",
    )
    parser.add_argument(
        "--lr", type=float, default=1e-3,
        help="AdamW learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.5,
        help="NT-Xent loss temperature (default: %(default)s)",
    )
    parser.add_argument(
        "--save_path", type=str, default="simclr.pth",
        help="file the trained model state_dict is written to (default: %(default)s)",
    )
    return parser.parse_args(argv)
def main():
    """Pre-train a ResNet-18 SimCLR encoder on the CIFAR-10 training split.

    Hyperparameters come from the command line (``parse_args``). The model
    is optimised with AdamW on the NT-Xent contrastive loss. The mean loss
    is printed after every epoch, and the final ``state_dict`` is saved to
    ``args.save_path``.

    Raises:
        ValueError: if ``batch_size`` is larger than the training set. With
            ``drop_last=True`` no full batch could be formed, so training
            would otherwise fail with a division by zero.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"

    # --- data ---------------------------------------------------------------
    # ContrastiveTransform is unpacked below as a pair of views (x_i, x_j) per
    # image.
    train_transform = ContrastiveTransform(base_size=32)
    train_dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        # Pass the transform object itself. A lambda defined inside main()
        # cannot be pickled, which breaks DataLoader workers (num_workers > 0)
        # under the "spawn" start method (the default on Windows and macOS).
        transform=train_transform,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        # Page-locked host memory enables asynchronous host-to-GPU copies.
        pin_memory=use_cuda,
        # NTXentLoss receives batch_size at construction (presumably to build
        # its pair mask), so a smaller trailing batch is dropped.
        drop_last=True,
    )
    num_batches = len(train_loader)  # full batches only, because drop_last=True
    if num_batches == 0:
        raise ValueError(
            f"batch_size={args.batch_size} is larger than the training set "
            f"({len(train_dataset)} samples); no full batch can be formed."
        )

    # --- model, optimizer, loss -------------------------------------------------
    model = ResNetEncoder(base="resnet18", out_dim=128).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    criterion = NTXentLoss(batch_size=args.batch_size, temperature=args.temperature)

    # --- training loop ------------------------------------------------------
    for epoch in range(args.epochs):
        model.train()
        total_loss = 0.0
        for (x_i, x_j), _ in train_loader:  # class labels are unused
            x_i = x_i.to(device, non_blocking=True)
            x_j = x_j.to(device, non_blocking=True)
            # The encoder returns a pair; only the second element (the
            # projection) feeds the loss.
            _, z_i = model(x_i)  # [B, 128]
            _, z_j = model(x_j)  # [B, 128]
            # [2B, 128]: the two views of sample k sit at rows k and k + B.
            z = torch.cat([z_i, z_j], dim=0)
            loss = criterion(z)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"[SimCLR] Epoch [{epoch+1}/{args.epochs}] Loss: {avg_loss:.4f}")

    # --- save ---------------------------------------------------------------
    torch.save(model.state_dict(), args.save_path)
    print(f"SimCLR model saved to {args.save_path}")
# Entry-point guard: train only when this file is executed as a script, not
# when it is imported. DataLoader worker processes created with the "spawn"
# start method re-import the main module. This guard keeps them from
# recursively launching training.
if __name__ == "__main__":
    main()