-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
54 lines (41 loc) · 1.5 KB
/
utils.py
File metadata and controls
54 lines (41 loc) · 1.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from torch.utils.data import TensorDataset
from model import RClassifier
def load(path):
    """Load a checkpoint written by ``save`` and rebuild the model.

    The file is expected to contain the three-item sequence
    ``[state_dict, configs, epoch]``, which is the layout ``save`` writes.

    Args:
        path: Filesystem path (``str`` or path-like) of the checkpoint.

    Returns:
        ``(model, epoch)``, where ``model`` is ``RClassifier(**configs)`` with
        ``state_dict`` loaded into it.

    Raises:
        ValueError: If the file cannot be read by ``torch.load`` or does not
            unpack into exactly three items. The original error is chained
            as ``__cause__``.
    """
    try:
        state_dict, configs, epoch = torch.load(path)
    except Exception as err:
        # Narrower than a bare ``except:`` (which would also trap
        # KeyboardInterrupt/SystemExit). ``from err`` keeps the real reason
        # visible in the traceback. The f-string also works for path-like objects.
        raise ValueError(f"failed to load model from path {path}") from err
    model = RClassifier(**configs)
    model.load_state_dict(state_dict)
    return model, epoch
def save(model, epoch, path):
    """Write a checkpoint that ``load`` can read back.

    The checkpoint is the list ``[model.state_dict(), model.configs, epoch]``
    serialized with ``torch.save``.

    Args:
        model: Object exposing ``state_dict()`` and a ``configs`` attribute.
        epoch: Epoch value stored alongside the weights.
        path: Destination accepted by ``torch.save``.
    """
    weights = model.state_dict()
    checkpoint = [weights, model.configs, epoch]
    torch.save(checkpoint, path)
def data_subset(dataset, labels):
    """Return a ``TensorDataset`` of the samples whose target is in *labels*.

    Args:
        dataset: Object with ``data`` and ``targets`` attributes that are
            aligned along their first dimension. Either attribute may be a
            tensor, a NumPy array or a Python list.
        labels: Label values to keep (list, tuple or tensor).

    Returns:
        ``TensorDataset(xs, ys)`` containing the selected samples in their
        original order.
    """
    targets = torch.as_tensor(dataset.targets)
    # Match the targets' dtype and device. An empty *labels* would otherwise
    # become a float32 tensor, and CPU labels vs. CUDA targets would make
    # ``isin`` fail.
    wanted = torch.as_tensor(labels, dtype=targets.dtype, device=targets.device)
    mask = torch.isin(targets, wanted)
    # ``as_tensor`` is a no-op for tensors and lets array/list data be masked
    # and wrapped by TensorDataset, which requires tensors.
    xs = torch.as_tensor(dataset.data)[mask]
    ys = targets[mask]
    return TensorDataset(xs, ys)
import matplotlib.pyplot as plt
from matplotlib import animation
def plot_history(history):
    """Animate a bar chart through the successive entries of *history*.

    The bars are first drawn with the heights in ``history[0]``. Each element
    of *history* is then used as one animation frame, which sets bar ``i``
    to ``frame[i]``. Finishes with ``plt.show()``.
    """
    fig, _ax = plt.subplots()
    initial = history[0]
    bars = plt.bar(range(len(initial)), initial, width=0.8, bottom=None)

    def _set_heights(frame):
        for idx, rect in enumerate(bars):
            rect.set_height(frame[idx])

    # Bind the animation to a local name so it stays referenced (and keeps
    # running) while plt.show() is active.
    anim = animation.FuncAnimation(fig, _set_heights, frames=history)
    plt.show()
def scatterplot(data):
    """Show a scatter plot of *data* against its index positions 0..len(data)-1."""
    _fig, _ax = plt.subplots()
    positions = range(len(data))
    plt.scatter(positions, data, marker=".")
    plt.show()
def histogram(data):
    """Show a histogram of every element of *data*, ignoring its shape.

    Args:
        data: Tensor, or anything ``torch.as_tensor`` accepts. Before
            plotting it is flattened, detached and copied to CPU, because
            ``Tensor.numpy()`` raises for tensors that require grad or live
            on a GPU.
    """
    values = torch.flatten(torch.as_tensor(data)).detach().cpu().numpy()
    plt.hist(values)
    plt.show()
def interpolate(history, scale_factor):
    """Linearly resample a 2-D ``(T, V)`` tensor along its first axis.

    Each of the ``V`` columns is treated as a separate 1-D signal of length
    ``T`` and resampled with ``torch.nn.functional.interpolate``
    (``mode='linear'``, ``align_corners=False``).

    Args:
        history: 2-D tensor of shape ``(T, V)``. Linear interpolation
            requires a floating-point dtype.
        scale_factor: Multiplier applied to the length of the first axis.

    Returns:
        Tensor of shape ``(T', V)`` with ``T' = floor(T * scale_factor)``.

    Raises:
        ValueError: If *history* is not 2-D.
    """
    if history.dim() != 2:
        raise ValueError(f"expected a 2-D (T, V) tensor, got shape {tuple(history.shape)}")
    # F.interpolate expects (batch, channels, length). Put the V columns on
    # the channel axis and add a batch axis of size 1.
    signal = history.permute(1, 0).unsqueeze(0)
    resampled = torch.nn.functional.interpolate(
        signal, scale_factor=scale_factor, mode='linear', align_corners=False
    )
    # Remove only the batch axis. A bare squeeze() would also drop V (or T')
    # when it equals 1 and break the permute.
    return resampled.squeeze(0).permute(1, 0)