1515
1616from lightning .impl .datasets .samples_generator import make_classification
1717from lightning .impl .primal_cd import CDClassifier , CDRegressor
18+ from lightning .impl .tests .utils import check_predict_proba
19+
1820
1921bin_dense , bin_target = make_classification (n_samples = 200 , n_features = 100 ,
2022 n_informative = 5 ,
3133def test_fit_linear_binary_l1r ():
3234 clf = CDClassifier (C = 1.0 , random_state = 0 , penalty = "l1" )
3335 clf .fit (bin_dense , bin_target )
36+ assert not hasattr (clf , 'predict_proba' )
3437 acc = clf .score (bin_dense , bin_target )
3538 assert_almost_equal (acc , 1.0 )
3639 n_nz = clf .n_nonzero ()
@@ -51,6 +54,7 @@ def test_fit_linear_binary_l1r():
5154def test_fit_linear_binary_l1r_smooth_hinge ():
5255 clf = CDClassifier (C = 1.0 , loss = "smooth_hinge" , random_state = 0 , penalty = "l1" )
5356 clf .fit (bin_dense , bin_target )
57+ assert not hasattr (clf , 'predict_proba' )
5458 acc = clf .score (bin_dense , bin_target )
5559 assert_almost_equal (acc , 1.0 )
5660
@@ -102,6 +106,7 @@ def test_warm_start_l1r_regression():
102106def test_fit_linear_binary_l1r_log_loss ():
103107 clf = CDClassifier (C = 1.0 , random_state = 0 , penalty = "l1" , loss = "log" )
104108 clf .fit (bin_dense , bin_target )
109+ check_predict_proba (clf , bin_dense )
105110 acc = clf .score (bin_dense , bin_target )
106111 assert_almost_equal (acc , 0.995 )
107112
@@ -133,6 +138,7 @@ def test_fit_linear_binary_l2r_modified_huber():
133138 clf = CDClassifier (C = 1.0 , random_state = 0 , penalty = "l2" ,
134139 loss = "modified_huber" )
135140 clf .fit (bin_dense , bin_target )
141+ check_predict_proba (clf , bin_dense )
136142 acc = clf .score (bin_dense , bin_target )
137143 assert_almost_equal (acc , 1.0 )
138144
0 commit comments