Skip to content

Commit b03644f

Browse files
authored
FIX use correct weights in WeightedL1GroupL2 penalty prox_1group() (#333)
1 parent 29d67fa commit b03644f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

skglm/penalties/block_separable.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,9 @@ def value(self, w):
455455

456456
def prox_1group(self, value, stepsize, g):
457457
"""Compute the proximal operator of group ``g``."""
458-
res = ST_vec(value, self.alpha * stepsize * self.weights_features[g])
458+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
459+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
460+
res = ST_vec(value, self.alpha * stepsize * self.weights_features[grp_g_indices])
459461
return BST(res, self.alpha * stepsize * self.weights_groups[g])
460462

461463
def is_penalized(self, n_groups):

0 commit comments

Comments
 (0)