1+ # flake8: noqa
12import numpy as np
23from sklearn import datasets
34from sklearn .model_selection import train_test_split
@@ -13,6 +14,7 @@ def test_GaussianNB(N=10):
1314 N = np .inf if N is None else N
1415
1516 i = 1
17+ eps = np .finfo (float ).eps
1618 while i < N + 1 :
1719 n_ex = np .random .randint (1 , 300 )
1820 n_feats = np .random .randint (1 , 100 )
@@ -33,29 +35,29 @@ def test_GaussianNB(N=10):
3335
3436 sk_preds = sklearn_NB .predict (X_test )
3537
36- for i in range (len (NB .labels )):
38+ for j in range (len (NB .labels )):
3739 P = NB .parameters
38- jointi = np .log (sklearn_NB .class_prior_ [i ])
39- jointi_mine = np .log (P ["prior" ][i ])
40+ jointi = np .log (sklearn_NB .class_prior_ [j ])
41+ jointi_mine = np .log (P ["prior" ][j ])
4042
4143 np .testing .assert_almost_equal (jointi , jointi_mine )
4244
43- n_ij = - 0.5 * np .sum (np .log (2.0 * np .pi * sklearn_NB .sigma_ [i , :]))
44- n_ij_mine = - 0.5 * np .sum (np .log (2.0 * np .pi * P ["sigma" ][i ] ))
45+ n_jk = - 0.5 * np .sum (np .log (2.0 * np .pi * sklearn_NB .sigma_ [j , :] + eps ))
46+ n_jk_mine = - 0.5 * np .sum (np .log (2.0 * np .pi * P ["sigma" ][j ] + eps ))
4547
46- np .testing .assert_almost_equal (n_ij_mine , n_ij )
48+ np .testing .assert_almost_equal (n_jk_mine , n_jk )
4749
48- n_ij2 = n_ij - 0.5 * np .sum (
49- ((X_test - sklearn_NB .theta_ [i , :]) ** 2 ) / (sklearn_NB .sigma_ [i , :]), 1
50+ n_jk2 = n_jk - 0.5 * np .sum (
51+ ((X_test - sklearn_NB .theta_ [j , :]) ** 2 ) / (sklearn_NB .sigma_ [j , :]), 1
5052 )
5153
52- n_ij2_mine = n_ij_mine - 0.5 * np .sum (
53- ((X_test - P ["mean" ][i ]) ** 2 ) / (P ["sigma" ][i ]), 1
54+ n_jk2_mine = n_jk_mine - 0.5 * np .sum (
55+ ((X_test - P ["mean" ][j ]) ** 2 ) / (P ["sigma" ][j ]), 1
5456 )
55- np .testing .assert_almost_equal (n_ij2_mine , n_ij2 , decimal = 4 )
57+ np .testing .assert_almost_equal (n_jk2_mine , n_jk2 , decimal = 4 )
5658
57- llh = jointi + n_ij2
58- llh_mine = jointi_mine + n_ij2_mine
59+ llh = jointi + n_jk2
60+ llh_mine = jointi_mine + n_jk2_mine
5961
6062 np .testing .assert_almost_equal (llh_mine , llh , decimal = 4 )
6163
0 commit comments