@@ -71,58 +71,41 @@ def _before_task(self, train_loader, val_loader):
         )

     def _train_task(self, train_loader, val_loader):
-        #for p in self.parameters():
-        #    p.register_hook(lambda grad: torch.clamp(grad, -5, 5))
-
         print("nb ", len(train_loader.dataset))

         prog_bar = trange(self._n_epochs, desc="Losses.")

-        val_loss = 0.
         for epoch in prog_bar:
-            _clf_loss, _distil_loss = 0., 0.
-            c = 0
+            _loss = 0.
+            val_loss = 0.

             self._scheduler.step()

-            for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1):
+            for inputs, targets in train_loader:
                 self._optimizer.zero_grad()

-                c += len(idxes)
                 inputs, targets = inputs.to(self._device), targets.to(self._device)
                 targets = utils.to_onehot(targets, self._n_classes).to(self._device)
                 logits = self._network(inputs)

-                clf_loss, distil_loss = self._compute_loss(
+                loss = self._compute_loss(
                     inputs,
                     logits,
-                    targets,
-                    idxes,
+                    targets
                 )

-                if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss):
+                if not utils._check_loss(loss):
                     import pdb
                     pdb.set_trace()

-                loss = clf_loss + distil_loss
-
                 loss.backward()
                 self._optimizer.step()

-                _clf_loss += clf_loss.item()
-                _distil_loss += distil_loss.item()
-
-                if i % 10 == 0 or i >= len(train_loader):
-                    prog_bar.set_description(
-                        "Clf loss: {}; Distill loss: {}; Val loss: {}".format(
-                            round(clf_loss.item(), 3), round(distil_loss.item(), 3),
-                            round(val_loss, 3)
-                        )
-                    )
+                _loss += loss.item()

             prog_bar.set_description(
-                "Clf loss: {}; Distill loss: {}; Val loss: {}".format(
-                    round(_clf_loss / c, 3), round(_distil_loss / c, 3), round(val_loss, 2)
+                "Clf loss: {}; Val loss: {}".format(
+                    round(_loss / len(train_loader), 3), round(val_loss, 2)
                 )
             )

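Note: the rewritten loop converts integer labels to one-hot vectors with the project's `utils.to_onehot` helper before computing the loss. A minimal sketch of what such a helper typically does (an assumption about its behavior, not the repo's actual code):

    import torch

    def to_onehot(targets, n_classes):
        # targets: (N,) integer class labels -> (N, n_classes) float one-hot matrix.
        onehot = torch.zeros(targets.shape[0], n_classes, device=targets.device)
        onehot.scatter_(1, targets.long().view(-1, 1), 1.)
        return onehot
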
@@ -138,35 +121,27 @@ def _eval_task(self, data_loader):

         return ypred, ytrue

-    def get_memory_indexes(self):
+    def get_memory(self):
         return self.examplars

     # -----------
     # Private API
     # -----------

-    def _compute_loss(self, inputs, logits, targets, idxes, train=True):
-        if self._task == 0:
-            # First task, only doing classification loss
-            clf_loss = self._clf_loss(logits, targets)
-            distil_loss = torch.zeros(1, device=self._device)
+    def _compute_loss(self, inputs, logits, targets):
+        targets = utils.one_hot(targets, self._n_classes)
+
+        if self._old_model is None:
+            loss = F.binary_cross_entropy_with_logits(logits, targets)
         else:
-            clf_loss = self._clf_loss(
-                logits[..., self._new_task_index:], targets[..., self._new_task_index:]
-            )
+            old_targets = torch.sigmoid(self._old_model(inputs).detach())

-            temp = 1
-            #previous_preds = self._previous_preds if train else self._previous_preds_val
-            tmp = torch.sigmoid(self._old_model(inputs).detach() / temp)
-            #if not torch.allclose(previous_preds[idxes], tmp):
-            #    import pdb; pdb.set_trace()
+            new_targets = targets.clone()
+            new_targets[..., :-self._task_size] = old_targets

-            distil_loss = self._distil_loss(
-                logits[..., :self._new_task_index] / temp, tmp
-                #previous_preds[idxes, :self._new_task_index]
-            )
+            loss = F.binary_cross_entropy_with_logits(logits, new_targets)

-        return clf_loss, distil_loss
+        return loss

     def _compute_predictions(self, data_loader):
         preds = torch.zeros(self._n_train_data, self._n_classes, device=self._device)
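Note: the new `_compute_loss` folds classification and distillation into a single binary cross-entropy, iCaRL-style: the targets for classes of previous tasks are replaced by the sigmoid of the frozen old model's logits, while the classes of the current task keep their one-hot ground truth. A standalone sketch of that target construction (function name and arguments are illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    def distillation_bce_loss(logits, onehot_targets, old_logits, task_size):
        # logits:         (N, n_classes) outputs of the current network
        # onehot_targets: (N, n_classes) one-hot ground truth
        # old_logits:     (N, n_classes - task_size) outputs of the frozen previous network
        # task_size:      number of classes added by the current task
        targets = onehot_targets.clone()
        # Old classes: learn to reproduce the old model's soft predictions.
        targets[..., :-task_size] = torch.sigmoid(old_logits.detach())
        # New classes keep their one-hot targets.
        return F.binary_cross_entropy_with_logits(logits, targets)
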
@@ -245,9 +220,11 @@ def _build_examplars(self, loader):

         means = []

+
         lo, hi = 0, self._task * self._task_size
         print("Updating examplars for classes {} -> {}.".format(lo, hi))
         for class_idx in range(lo, hi):
+
             loader.dataset.set_idxes(self._examplars[class_idx])
             # loader.dataset._flip_all = True
             #loader.dataset.double_dataset()
@@ -352,6 +329,85 @@ def examplars(self):
         )

     def _reduce_examplars(self):
+        return
         print("Reducing examplars.")
         for class_idx in range(len(self._examplars)):
             self._examplars[class_idx] = self._examplars[class_idx][:self._m]
+
+
+
+
+def extract_features(model, dataset):
+    gen = DataLoader(dataset, shuffle=False, batch_size=256)
+    features = []
+    targets_all = []
+    for inputs, targets in gen:
+        features.append(model.extract(inputs.to(model.device)).detach())
+        targets_all.append(targets.numpy())
+
+    return torch.cat(features), np.concatenate(targets_all)
+
+
+def extract(model, x, y):
+    feat_normal, _ = extract_features(model, IncDataset(x, y, train=None))
+    feat_flip, _ = extract_features(model, IncDataset(x, y, train="flip"))
+
+    return feat_normal, feat_flip
+
+
+def select_examplars(features, nb_max):
+    D = features.cpu().numpy().T
+    D = D / (np.linalg.norm(D, axis=0) + EPSILON)
+    mu = np.mean(D, axis=1)
+    herding_mat = np.zeros((features.shape[0]))
+
+    w_t = mu
+    iter_herding, iter_herding_eff = 0, 0
+
+    while not (np.sum(herding_mat != 0) == min(nb_max, features.shape[0])) and iter_herding_eff < 1000:
+        tmp_t = np.dot(w_t, D)
+        ind_max = np.argmax(tmp_t)
+        iter_herding_eff += 1
+        if herding_mat[ind_max] == 0:
+            herding_mat[ind_max] = 1 + iter_herding
+            iter_herding += 1
+
+        w_t = w_t + mu - D[:, ind_max]
+
+    return herding_mat
+
+
+def compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max):
+    D = feat_norm.cpu().numpy().T
+    D = D / (np.linalg.norm(D, axis=0) + EPSILON)
+
+    D2 = feat_flip.cpu().numpy().T
+    D2 = D2 / (np.linalg.norm(D2, axis=0) + EPSILON)
+
+    alph = herding_mat
+    alph = (alph > 0) * (alph < nb_max + 1) * 1.
+
+    alph_mean = alph / np.sum(alph)
+
+    mean = (np.dot(D, alph_mean) + np.dot(D2, alph_mean)) / 2
+    mean /= np.linalg.norm(mean)
+
+    return mean, alph
+
+
+def compute_accuracy(model, test_dataset, class_means):
+    features, targets_ = extract_features(model, test_dataset)
+    features = features.cpu().numpy()
+
+    targets = np.zeros((targets_.shape[0], 100), np.float32)
+    targets[range(len(targets_)), targets_.astype('int32')] = 1.
+    features = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T
+
+    # Compute score for iCaRL
+    sqd = cdist(class_means, features, 'sqeuclidean')
+    score_icarl = (-sqd).T
+
+    # Compute the accuracy over the batch
+    stat_icarl = [ll in best for ll, best in zip(targets_.astype('int32'), np.argsort(score_icarl, axis=1)[:, -1:])]
+
+    return np.average(stat_icarl)
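
Note: an illustrative sketch of how these module-level helpers could be chained to build the memory and class mean for one class (the wrapper name and its arguments are placeholders, not part of the repo):

    import numpy as np

    def build_class_memory(model, class_x, class_y, memory_per_class):
        # Features of the class images, with and without horizontal flip.
        feat_norm, feat_flip = extract(model, class_x, class_y)
        # Herding: greedily pick samples whose running mean best matches the class mean.
        herding_mat = select_examplars(feat_norm, memory_per_class)
        # Mean of the selected exemplars, averaged over normal and flipped features.
        class_mean, alph = compute_examplar_mean(feat_norm, feat_flip, herding_mat, memory_per_class)
        # Indexes kept in memory: the positions herding selected within the budget.
        kept_idxes = np.where((herding_mat > 0) & (herding_mat < memory_per_class + 1))[0]
        return kept_idxes, class_mean

The stacked class means can then be passed to `compute_accuracy(model, test_dataset, class_means)` for nearest-mean-of-exemplars evaluation.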