-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathutils.py
More file actions
91 lines (70 loc) · 2.73 KB
/
utils.py
File metadata and controls
91 lines (70 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.autograd as autograd
def save_checkpoint(state, filename='checkpoint.pth', cpu=False):
    """Serialize a model state dict to *filename* via ``torch.save``.

    Args:
        state: mapping of parameter names to tensors (a ``state_dict``).
        filename: destination path for the checkpoint.
        cpu: when True, every tensor is moved to CPU first and any
            ``'module.'`` prefix (left over from ``nn.DataParallel``
            training) is stripped from the keys, so the checkpoint loads
            directly into an unwrapped, CPU-resident model.
    """
    if cpu:
        stripped = OrderedDict(
            (key.replace('module.', ''), tensor.cpu())
            for key, tensor in state.items()
        )
        state = stripped
    torch.save(state, filename)
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``append`` records values; ``val`` is the most recent value, ``avg``
    the mean over all recorded values, and ``last_avg`` the mean over the
    values recorded since the previous ``last_avg`` access (cached and
    returned unchanged until new values arrive).
    """
    def __init__(self):
        # BUGFIX: latest_avg was previously only assigned inside last_avg's
        # else-branch, so reading last_avg before any append raised
        # AttributeError. Initialize it here (not in reset) so it both has
        # a defined starting value and still survives reset() calls.
        self.latest_avg = 0.0
        self.reset()

    def reset(self):
        """Discard all recorded values and restart the last_avg window."""
        self.values = []
        self.counter = 0

    def append(self, val):
        """Record a new value."""
        self.values.append(val)
        self.counter += 1

    @property
    def val(self):
        # Most recent value; raises IndexError if nothing was appended yet.
        return self.values[-1]

    @property
    def avg(self):
        # Mean over every value ever appended (since the last reset).
        return sum(self.values) / len(self.values)

    @property
    def last_avg(self):
        # Mean of the values appended since the previous last_avg access.
        # Reading it consumes the window: counter resets to 0 and the
        # result is cached until new values arrive.
        if self.counter == 0:
            return self.latest_avg
        else:
            self.latest_avg = sum(self.values[-self.counter:]) / self.counter
            self.counter = 0
            return self.latest_avg
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) tensor of scores/logits.
        target: (batch,) tensor of ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        List of floats in [0, 1], one per k, in the same order as topk.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # pred: (batch, maxk) indices of the top-maxk classes per sample.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # correct[i, j] is True iff the (i+1)-th ranked prediction for
        # sample j equals the target.
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # BUGFIX: use reshape instead of view — the sliced tensor is
            # not guaranteed contiguous, and .view raises
            # "view size is not compatible..." on recent PyTorch versions.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.div_(batch_size).item())
        return res
def requires_grad_(model: nn.Module, requires_grad: bool) -> None:
    """Toggle gradient tracking on every parameter of *model*, in place.

    Useful for freezing a model (e.g. during adversarial-example
    generation) and unfreezing it again afterwards.
    """
    for _name, parameter in model.named_parameters():
        parameter.requires_grad_(requires_grad)
class NormalizedModel(nn.Module):
    """
    Wrapper for a model to account for the mean and std of a dataset.
    mean and std do not require grad as they should not be learned, but determined beforehand.
    mean and std should be broadcastable (see pytorch doc on broadcasting) with the data.
    Args:
        model (nn.Module): model to use to predict
        mean (torch.Tensor): sequence of means for each channel
        std (torch.Tensor): sequence of standard deviations for each channel
    """
    def __init__(self, model: nn.Module, mean: torch.Tensor, std: torch.Tensor) -> None:
        super().__init__()
        self.model = model
        # Kept as frozen nn.Parameters (not plain tensors) so the
        # statistics follow the module through .to()/.cuda() and are
        # saved in state_dict, while never receiving gradient updates.
        self.mean = nn.Parameter(mean, requires_grad=False)
        self.std = nn.Parameter(std, requires_grad=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Standardize the input, then delegate to the wrapped model.
        centered = input - self.mean
        return self.model(centered / self.std)