-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
44 lines (35 loc) · 1.31 KB
/
utils.py
File metadata and controls
44 lines (35 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
# NT-Xent Loss for SimCLR
class NTXentLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss for SimCLR.

    Expects the embeddings of two augmented views stacked along dim 0:
    rows i and i + batch_size form a positive pair.

    Args:
        batch_size: number of samples per view (so the input has
            2 * batch_size rows).
        temperature: softmax temperature tau; smaller values sharpen
            the distribution.
    """
    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        N = 2 * batch_size
        # Mask out self-similarities AND the positive pairs so the
        # "negative" columns of the logits contain only true negatives.
        # (Masking only the diagonal would count each positive twice in
        # the softmax denominator: once as the pos column, once as a
        # "negative".)
        mask = torch.eye(N, dtype=torch.bool)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = True
        mask[idx + batch_size, idx] = True
        self.register_buffer("mask", mask)
    def forward(self, z):
        """
        z: [2*batch_size, dim] — embeddings of both views; rows i and
           i + batch_size are the two views of the same sample.
        Returns a scalar loss tensor.
        """
        # NT-Xent is defined on cosine similarity; normalizing here is a
        # no-op when callers already pass unit-norm embeddings.
        z = nn.functional.normalize(z, dim=1)
        # Pairwise similarity matrix, temperature-scaled.
        sim = torch.matmul(z, z.T) / self.temperature
        # Positive pairs: (i, i + batch_size) and (i + batch_size, i)
        pos = torch.cat([
            torch.diag(sim, self.batch_size),
            torch.diag(sim, -self.batch_size)
        ]).view(2 * self.batch_size, 1)
        # Drive masked entries to ~-inf so they vanish from the softmax.
        sim = sim.masked_fill(self.mask, -9e15)
        # Column 0 holds the positive logit; the rest are negatives.
        logits = torch.cat([pos, sim], dim=1)
        labels = torch.zeros(2 * self.batch_size, dtype=torch.long, device=z.device)
        loss = nn.CrossEntropyLoss()(logits, labels)
        return loss
# Cosine similarity helper (if needed)
def cosine_similarity(a, b, eps=1e-8):
    """Pairwise cosine similarity between the rows of `a` and `b`.

    Each row is divided by its L2 norm (plus `eps` to avoid division by
    zero), then the similarity matrix is formed by matrix product.
    Returns a [a.shape[0], b.shape[0]] tensor.
    """
    scale_a = a.norm(dim=1, keepdim=True) + eps
    scale_b = b.norm(dim=1, keepdim=True) + eps
    return (a / scale_a) @ (b / scale_b).T