Commit 9688a80: "to squash"

Authored and committed by Arthur Douillard
1 parent: 29ff923

File tree: 1 file changed (+101, −45 lines)


inclearn/models/icarl.py

Lines changed: 101 additions & 45 deletions
@@ -71,58 +71,41 @@ def _before_task(self, train_loader, val_loader):
         )

     def _train_task(self, train_loader, val_loader):
-        #for p in self.parameters():
-        #    p.register_hook(lambda grad: torch.clamp(grad, -5, 5))
-
         print("nb ", len(train_loader.dataset))

         prog_bar = trange(self._n_epochs, desc="Losses.")

-        val_loss = 0.
         for epoch in prog_bar:
-            _clf_loss, _distil_loss = 0., 0.
-            c = 0
+            _loss = 0.
+            val_loss = 0.

             self._scheduler.step()

-            for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1):
+            for inputs, targets in train_loader:
                 self._optimizer.zero_grad()

-                c += len(idxes)
                 inputs, targets = inputs.to(self._device), targets.to(self._device)
                 targets = utils.to_onehot(targets, self._n_classes).to(self._device)
                 logits = self._network(inputs)

-                clf_loss, distil_loss = self._compute_loss(
+                loss = self._compute_loss(
                     inputs,
                     logits,
-                    targets,
-                    idxes,
+                    targets
                 )

-                if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss):
+                if not utils._check_loss(loss):
                     import pdb
                     pdb.set_trace()

-                loss = clf_loss + distil_loss
-
                 loss.backward()
                 self._optimizer.step()

-                _clf_loss += clf_loss.item()
-                _distil_loss += distil_loss.item()
-
-                if i % 10 == 0 or i >= len(train_loader):
-                    prog_bar.set_description(
-                        "Clf loss: {}; Distill loss: {}; Val loss: {}".format(
-                            round(clf_loss.item(), 3), round(distil_loss.item(), 3),
-                            round(val_loss, 3)
-                        )
-                    )
+                _loss += loss.item()

             prog_bar.set_description(
-                "Clf loss: {}; Distill loss: {}; Val loss: {}".format(
-                    round(_clf_loss / c, 3), round(_distil_loss / c, 3), round(val_loss, 2)
+                "Clf loss: {}; Val loss: {}".format(
+                    round(_loss / len(train_loader), 3), round(val_loss, 2)
                 )
             )

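A note on the hunk above: the per-batch bookkeeping (`c`, `_clf_loss`, `_distil_loss`, and the every-10-batches description update) collapses into a single running sum, averaged over the number of batches once per epoch. A minimal, self-contained sketch of the same loop pattern, with a toy model and made-up data standing in for the repository's network and loader:

    import torch
    from torch import nn, optim
    from tqdm import trange

    # Stand-ins for the repository's network, optimizer, and train_loader.
    model = nn.Linear(8, 3)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    loader = [(torch.randn(4, 8), torch.randint(0, 3, (4,))) for _ in range(5)]

    prog_bar = trange(2, desc="Losses.")
    for epoch in prog_bar:
        _loss = 0.
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(inputs), targets)
            loss.backward()
            optimizer.step()
            _loss += loss.item()  # one running sum per epoch
        # Report the per-batch average, as the new set_description call does.
        prog_bar.set_description("Clf loss: {}".format(round(_loss / len(loader), 3)))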

@@ -138,35 +121,27 @@ def _eval_task(self, data_loader):

         return ypred, ytrue

-    def get_memory_indexes(self):
+    def get_memory(self):
         return self.examplars

     # -----------
     # Private API
     # -----------

-    def _compute_loss(self, inputs, logits, targets, idxes, train=True):
-        if self._task == 0:
-            # First task, only doing classification loss
-            clf_loss = self._clf_loss(logits, targets)
-            distil_loss = torch.zeros(1, device=self._device)
+    def _compute_loss(self, inputs, logits, targets):
+        targets = utils.one_hot(targets, self._n_classes)
+
+        if self._old_model is None:
+            loss = F.binary_cross_entropy_with_logits(logits, targets)
         else:
-            clf_loss = self._clf_loss(
-                logits[..., self._new_task_index:], targets[..., self._new_task_index:]
-            )
+            old_targets = torch.sigmoid(self._old_model(inputs).detach())

-            temp = 1
-            #previous_preds = self._previous_preds if train else self._previous_preds_val
-            tmp = torch.sigmoid(self._old_model(inputs).detach() / temp)
-            #if not torch.allclose(previous_preds[idxes], tmp):
-            #    import pdb; pdb.set_trace()
+            new_targets = targets.clone()
+            new_targets[..., :-self._task_size] = old_targets

-            distil_loss = self._distil_loss(
-                logits[..., :self._new_task_index] / temp, tmp
-                #previous_preds[idxes, :self._new_task_index]
-            )
+            loss = F.binary_cross_entropy_with_logits(logits, new_targets)

-        return clf_loss, distil_loss
+        return loss

     def _compute_predictions(self, data_loader):
         preds = torch.zeros(self._n_train_data, self._n_classes, device=self._device)
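The rewritten `_compute_loss` folds iCaRL's distillation term into a single binary cross-entropy: slots for previously seen classes take the old model's sigmoid outputs as soft targets, while the `self._task_size` newest slots keep the one-hot ground truth. A self-contained sketch of that target construction (the function name and tensor sizes are invented for illustration):

    import torch
    import torch.nn.functional as F

    def icarl_bce_loss(logits, onehot_targets, old_logits, task_size):
        # Old-class slots get the previous model's sigmoid probabilities,
        # new-class slots keep the one-hot labels; a single BCE-with-logits
        # then covers classification and distillation together.
        targets = onehot_targets.clone()
        targets[..., :-task_size] = torch.sigmoid(old_logits.detach())
        return F.binary_cross_entropy_with_logits(logits, targets)

    # Toy usage: 3 old classes plus 2 new ones, batch of 4.
    logits = torch.randn(4, 5)
    onehot = torch.eye(5)[torch.randint(0, 5, (4,))]
    old_logits = torch.randn(4, 3)  # what a previous 3-class head might output
    loss = icarl_bce_loss(logits, onehot, old_logits, task_size=2)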
@@ -245,9 +220,11 @@ def _build_examplars(self, loader):

         means = []

+
         lo, hi = 0, self._task * self._task_size
         print("Updating examplars for classes {} -> {}.".format(lo, hi))
         for class_idx in range(lo, hi):
+
             loader.dataset.set_idxes(self._examplars[class_idx])
             # loader.dataset._flip_all = True
             #loader.dataset.double_dataset()
@@ -352,6 +329,85 @@ def examplars(self):
         )

     def _reduce_examplars(self):
+        return
         print("Reducing examplars.")
         for class_idx in range(len(self._examplars)):
             self._examplars[class_idx] = self._examplars[class_idx][:self._m]
+
+
+
+
+def extract_features(model, dataset):
+    gen = DataLoader(dataset, shuffle=False, batch_size=256)
+    features = []
+    targets_all = []
+    for inputs, targets in gen:
+        features.append(model.extract(inputs.to(model.device)).detach())
+        targets_all.append(targets.numpy())
+
+    return torch.cat(features), np.concatenate(targets_all)
+
+
+def extract(model, x, y):
+    feat_normal, _ = extract_features(model, IncDataset(x, y, train=None))
+    feat_flip, _ = extract_features(model, IncDataset(x, y, train="flip"))
+
+    return feat_normal, feat_flip
+
+
+def select_examplars(features, nb_max):
+    D = features.cpu().numpy().T
+    D = D / (np.linalg.norm(D, axis=0) + EPSILON)
+    mu = np.mean(D, axis=1)
+    herding_mat = np.zeros((features.shape[0]))
+
+    w_t = mu
+    iter_herding, iter_herding_eff = 0, 0
+
+    while not (np.sum(herding_mat != 0) == min(nb_max, features.shape[0])) and iter_herding_eff < 1000:
+        tmp_t = np.dot(w_t, D)
+        ind_max = np.argmax(tmp_t)
+        iter_herding_eff += 1
+        if herding_mat[ind_max] == 0:
+            herding_mat[ind_max] = 1 + iter_herding
+            iter_herding += 1
+
+        w_t = w_t + mu - D[:, ind_max]
+
+    return herding_mat
+
+
+def compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max):
+    D = feat_norm.cpu().numpy().T
+    D = D / (np.linalg.norm(D, axis=0) + EPSILON)
+
+    D2 = feat_flip.cpu().numpy().T
+    D2 = D2 / (np.linalg.norm(D2, axis=0) + EPSILON)
+
+    alph = herding_mat
+    alph = (alph > 0) * (alph < nb_max + 1) * 1.
+
+    alph_mean = alph / np.sum(alph)
+
+    mean = (np.dot(D, alph_mean) + np.dot(D2, alph_mean)) / 2
+    mean /= np.linalg.norm(mean)
+
+    return mean, alph
+
+
+def compute_accuracy(model, test_dataset, class_means):
+    features, targets_ = extract_features(model, test_dataset)
+    features = features.cpu().numpy()
+
+    targets = np.zeros((targets_.shape[0], 100), np.float32)
+    targets[range(len(targets_)), targets_.astype('int32')] = 1.
+    features = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T
+
+    # Compute score for iCaRL
+    sqd = cdist(class_means, features, 'sqeuclidean')
+    score_icarl = (-sqd).T
+
+    # Compute the accuracy over the batch
+    stat_icarl = [ll in best for ll, best in zip(targets_.astype('int32'), np.argsort(score_icarl, axis=1)[:, -1:])]
+
+    return np.average(stat_icarl)
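`select_examplars` is iCaRL's herding step: it greedily picks the sample whose addition keeps the running sum of selected features closest to the class mean, recording each pick's rank in `herding_mat`. A compact demo of the same idea on random features (the names `herd` and `feats` are invented; `EPSILON` mirrors the constant used above):

    import numpy as np

    EPSILON = 1e-8

    def herd(features, nb_max):
        # Columns of D are L2-normalized per-sample feature vectors.
        D = features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)
        mu = D.mean(axis=1)                  # class mean in feature space
        ranks = np.zeros(features.shape[0])  # 0 = unselected, else pick order
        w_t, placed, steps = mu.copy(), 0, 0
        while placed < min(nb_max, features.shape[0]) and steps < 1000:
            ind = int(np.argmax(w_t @ D))    # sample most aligned with the residual
            steps += 1
            if ranks[ind] == 0:
                placed += 1
                ranks[ind] = placed
            w_t += mu - D[:, ind]            # pull the residual back toward the mean
        return ranks

    feats = np.random.randn(50, 16)
    ranks = herd(feats, nb_max=10)
    exemplars = np.where((ranks > 0) & (ranks <= 10))[0]  # same mask as compute_examplar_mean's alph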

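`compute_accuracy` then performs nearest-mean-of-exemplars classification: test features are L2-normalized and each sample is scored by its negative squared Euclidean distance to every class mean. A short sketch with random stand-in data for the features and means:

    import numpy as np
    from scipy.spatial.distance import cdist

    EPSILON = 1e-8
    features = np.random.randn(32, 64)        # 32 test samples, 64-d features
    features = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T
    class_means = np.random.randn(10, 64)     # 10 class means, one row per class

    scores = -cdist(class_means, features, 'sqeuclidean').T  # shape (32, 10)
    preds = scores.argmax(axis=1)             # predicted class per test sample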