Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 4ce2115

Browse files
committed
Add projection on l1 ball.
1 parent e1cdcaa commit 4ce2115

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

lightning/impl/penalty.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def regularization(self, coef):
7373
return 0
7474

7575

76+
def project_l1_ball(v, z=1):
77+
return np.sign(v) * project_simplex(np.abs(v), z)
78+
79+
7680
class TotalVariation1DPenalty(object):
7781
def projection(self, coef, alpha, L):
7882
tmp = coef.copy()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import numpy as np
2+
from sklearn.utils.testing import assert_almost_equal
3+
4+
from lightning.impl.penalty import project_l1_ball
5+
6+
7+
def test_proj_l1_ball():
8+
rng = np.random.RandomState(0)
9+
v = rng.randn(100)
10+
w = project_l1_ball(v, z=50)
11+
assert_almost_equal(np.sum(np.abs(w)), 50)

0 commit comments

Comments
 (0)