@@ -494,8 +494,8 @@ def get_checkpoint_contribution(checkpoint):
494
494
)
495
495
496
496
return (
497
- torch .sum (batch_jacobian ** 2 , dim = 1 )
498
- * torch .sum (batch_layer_input ** 2 , dim = 1 )
497
+ torch .sum (batch_jacobian ** 2 , dim = 1 )
498
+ * torch .sum (batch_layer_input ** 2 , dim = 1 )
499
499
* learning_rate
500
500
)
501
501
@@ -1063,17 +1063,17 @@ def _set_projections_tracincp_fast_rand_proj(
1063
1063
# `projection_dim` corresponds to the variable d in the top of page 15 of
1064
1064
# the TracIn paper: https://arxiv.org/pdf/2002.08484.pdf.
1065
1065
if jacobian_dim * layer_input_dim > projection_dim :
1066
- jacobian_projection_dim = min (int (projection_dim ** 0.5 ), jacobian_dim )
1066
+ jacobian_projection_dim = min (int (projection_dim ** 0.5 ), jacobian_dim )
1067
1067
layer_input_projection_dim = min (
1068
- int (projection_dim ** 0.5 ), layer_input_dim
1068
+ int (projection_dim ** 0.5 ), layer_input_dim
1069
1069
)
1070
1070
jacobian_projection = torch .normal (
1071
1071
torch .zeros (jacobian_dim , jacobian_projection_dim ),
1072
- 1.0 / jacobian_projection_dim ** 0.5 ,
1072
+ 1.0 / jacobian_projection_dim ** 0.5 ,
1073
1073
)
1074
1074
layer_input_projection = torch .normal (
1075
1075
torch .zeros (layer_input_dim , layer_input_projection_dim ),
1076
- 1.0 / layer_input_projection_dim ** 0.5 ,
1076
+ 1.0 / layer_input_projection_dim ** 0.5 ,
1077
1077
)
1078
1078
1079
1079
projection_quantities = jacobian_projection , layer_input_projection
@@ -1157,7 +1157,7 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
1157
1157
), "None returned from `checkpoints`, cannot load."
1158
1158
1159
1159
learning_rate = self .checkpoints_load_func (self .model , checkpoint )
1160
- learning_rate_root = learning_rate ** 0.5
1160
+ learning_rate_root = learning_rate ** 0.5
1161
1161
1162
1162
for batch in dataloader :
1163
1163
0 commit comments