Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit d662d04

Browse files
committed
Add projection on l1-ball to FISTA.
1 parent 4ce2115 commit d662d04

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

lightning/impl/fista.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .penalty import L1L2Penalty
2020
from .penalty import TracePenalty
2121
from .penalty import SimplexConstraint
22+
from .penalty import L1BallConstraint
2223
from .penalty import TotalVariation1DPenalty
2324

2425

@@ -32,6 +33,7 @@ def _get_penalty(self):
3233
"l1/l2": L1L2Penalty(),
3334
"trace": TracePenalty(),
3435
"simplex": SimplexConstraint(),
36+
"l1-ball": L1BallConstraint(),
3537
"tv1d": TotalVariation1DPenalty()
3638
}
3739
return penalties[self.penalty]

lightning/impl/penalty.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def project_l1_ball(v, z=1):
7777
return np.sign(v) * project_simplex(np.abs(v), z)
7878

7979

80+
class L1BallConstraint(object):
81+
82+
def projection(self, coef, alpha, L):
83+
return project_l1_ball(coef[0], alpha).reshape(1,-1)
84+
85+
def regularization(self, coef):
86+
return 0
87+
88+
8089
class TotalVariation1DPenalty(object):
8190
def projection(self, coef, alpha, L):
8291
tmp = coef.copy()

lightning/impl/tests/test_fista.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from lightning.impl.datasets.samples_generator import make_classification
1313
from lightning.classification import FistaClassifier
1414
from lightning.regression import FistaRegressor
15-
from lightning.impl.penalty import project_simplex, L1Penalty
15+
from lightning.impl.penalty import project_simplex, project_l1_ball, L1Penalty
1616

1717
bin_dense, bin_target = make_classification(n_samples=200, n_features=100,
1818
n_informative=5,
@@ -142,6 +142,21 @@ def test_fista_regression_simplex():
142142
assert_almost_equal(np.sum(reg.coef_), 1.0, 3)
143143

144144

145+
def test_fista_regression_l1_ball():
146+
rng = np.random.RandomState(0)
147+
alpha = 5.0
148+
w = project_simplex(rng.randn(10), alpha)
149+
X = rng.randn(1000, 10)
150+
y = np.dot(X, w)
151+
152+
reg = FistaRegressor(penalty="l1-ball", alpha=alpha, max_iter=100, verbose=0)
153+
reg.fit(X, y)
154+
y_pred = reg.predict(X)
155+
error = np.sqrt(np.mean((y - y_pred) ** 2))
156+
assert_almost_equal(error, 0.000, 3)
157+
assert_almost_equal(np.sum(np.abs(reg.coef_)), alpha, 3)
158+
159+
145160
def test_fista_regression_trace():
146161
rng = np.random.RandomState(0)
147162
def _make_data(n_samples, n_features, n_tasks, n_components):

0 commit comments

Comments
 (0)