Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 7726c88

Browse files
committed
add support for clf.classes_ to various classifiers
1 parent a3e8bc4 commit 7726c88

File tree

14 files changed

+92
-13
lines changed

14 files changed

+92
-13
lines changed

lightning/impl/adagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def _get_loss(self):
8686
return losses[self.loss]
8787

8888
def fit(self, X, y):
89-
self.label_binarizer_ = LabelBinarizer(neg_label=-1, pos_label=1)
90-
Y = np.asfortranarray(self.label_binarizer_.fit_transform(y),
89+
self._set_label_transformers(y)
90+
Y = np.asfortranarray(self.label_binarizer_.transform(y),
9191
dtype=np.float64)
9292
return self._fit(X, Y)
9393

lightning/impl/dual_cd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ def fit(self, X, y):
124124
n_samples, n_features = X.shape
125125
rs = self._get_random_state()
126126

127-
self.label_binarizer_ = LabelBinarizer(neg_label=-1, pos_label=1)
128-
Y = np.asfortranarray(self.label_binarizer_.fit_transform(y),
127+
self._set_label_transformers(y)
128+
Y = np.asfortranarray(self.label_binarizer_.transform(y),
129129
dtype=np.float64)
130130
n_vectors = Y.shape[1]
131131

lightning/impl/sag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def fit(self, X, y, sample_weight=None):
160160
raise ValueError('Penalties in SAGClassifier. Please use '
161161
'SAGAClassifier instead.'
162162
'.')
163-
self._set_label_transformers(y, neg_label=-1)[0]
163+
self._set_label_transformers(y)
164164
y_binary = self.label_binarizer_.transform(y).astype(np.float64)
165165
return self._fit(X, y_binary, sample_weight)
166166

lightning/impl/sdca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def _get_loss(self):
131131
return losses[self.loss]
132132

133133
def fit(self, X, y):
134-
self.label_binarizer_ = LabelBinarizer(neg_label=-1, pos_label=1)
135-
Y = np.asfortranarray(self.label_binarizer_.fit_transform(y),
134+
self._set_label_transformers(y)
135+
Y = np.asfortranarray(self.label_binarizer_.transform(y),
136136
dtype=np.float64)
137137
return self._fit(X, Y)
138138

lightning/impl/svrg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def _get_loss(self):
8282
return losses[self.loss]
8383

8484
def fit(self, X, y):
85-
self.label_binarizer_ = LabelBinarizer(neg_label=-1, pos_label=1)
86-
Y = np.asfortranarray(self.label_binarizer_.fit_transform(y),
85+
self._set_label_transformers(y)
86+
Y = np.asfortranarray(self.label_binarizer_.transform(y),
8787
dtype=np.float64)
8888
return self._fit(X, Y)
8989

lightning/impl/tests/test_adagrad.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
from sklearn.datasets import load_iris
4-
from sklearn.preprocessing import Normalizer
54
from sklearn.utils.testing import assert_equal
65
from sklearn.utils.testing import assert_almost_equal
76

@@ -42,6 +41,20 @@ def test_adagrad_hinge_multiclass():
4241
assert_almost_equal(clf.score(X, y), 0.960, 3)
4342

4443

44+
def test_adagrad_classes_binary():
45+
clf = AdaGradClassifier()
46+
assert not hasattr(clf, 'classes_')
47+
clf.fit(X_bin, y_bin)
48+
assert_equal(list(clf.classes_), [-1, 1])
49+
50+
51+
def test_adagrad_classes_multiclass():
52+
clf = AdaGradClassifier()
53+
assert not hasattr(clf, 'classes_')
54+
clf.fit(X, y)
55+
assert_equal(list(clf.classes_), [0, 1, 2])
56+
57+
4558
def test_adagrad_callback():
4659
class Callback(object):
4760

lightning/impl/tests/test_dual_cd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_fit_linear_binary():
4949
for loss in ("l1", "l2"):
5050
clf = LinearSVC(loss=loss, random_state=0, max_iter=10)
5151
clf.fit(data, bin_target)
52+
assert_equal(list(clf.classes_), [0, 1])
5253
assert_equal(clf.score(data, bin_target), 1.0)
5354
y_pred = clf.decision_function(data).ravel()
5455

@@ -66,6 +67,7 @@ def test_fit_linear_multi():
6667
for data in (mult_dense, mult_sparse):
6768
clf = LinearSVC(random_state=0)
6869
clf.fit(data, mult_target)
70+
assert_equal(list(clf.classes_), [0, 1, 2])
6971
y_pred = clf.predict(data)
7072
acc = np.mean(y_pred == mult_target)
7173
assert_greater(acc, 0.85)

lightning/impl/tests/test_fista.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from sklearn.utils.testing import assert_almost_equal
77
from sklearn.utils.testing import assert_true
8+
from sklearn.utils.testing import assert_equal
89

910
from sklearn.datasets import load_digits
1011

@@ -107,6 +108,18 @@ def test_fista_multiclass_trace():
107108
assert_almost_equal(clf.score(data, mult_target), 0.98, 2)
108109

109110

111+
def test_fista_bin_classes():
112+
clf = FistaClassifier()
113+
clf.fit(bin_dense, bin_target)
114+
assert_equal(list(clf.classes_), [0, 1])
115+
116+
117+
def test_fista_multiclass_classes():
118+
clf = FistaClassifier()
119+
clf.fit(mult_dense, mult_target)
120+
assert_equal(list(clf.classes_), [0, 1, 2])
121+
122+
110123
def test_fista_regression():
111124
reg = FistaRegressor(max_iter=100, verbose=0)
112125
reg.fit(bin_dense, bin_target)

lightning/impl/tests/test_primal_cd.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,15 @@ def test_multiclass_error_nongrouplasso():
403403
for penalty in ['l1', 'l2']:
404404
clf = CDClassifier(multiclass=True, penalty=penalty)
405405
assert_raises(NotImplementedError, clf.fit, mult_dense, mult_target)
406+
407+
408+
def test_bin_classes():
409+
clf = CDClassifier()
410+
clf.fit(bin_dense, bin_target)
411+
assert_equal(list(clf.classes_), [0, 1])
412+
413+
414+
def test_multiclass_classes():
415+
clf = CDClassifier()
416+
clf.fit(mult_dense, mult_target)
417+
assert_equal(list(clf.classes_), [0, 1, 2])

lightning/impl/tests/test_primal_newton.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from sklearn.utils.testing import assert_almost_equal
2+
from sklearn.utils.testing import assert_equal
23

34
from lightning.impl.datasets.samples_generator import make_classification
45
from lightning.impl.primal_newton import KernelSVC
@@ -11,3 +12,4 @@ def test_kernel_svc():
1112
clf = KernelSVC(kernel="rbf", gamma=0.1, random_state=0, verbose=0)
1213
clf.fit(bin_dense, bin_target)
1314
assert_almost_equal(clf.score(bin_dense, bin_target), 1.0)
15+
assert_equal(list(clf.classes_), [0, 1])

0 commit comments

Comments
 (0)