@@ -137,32 +137,29 @@ class BorutaPy(BaseEstimator, TransformerMixin):
137137
138138 Examples
139139 --------
140-
141- import pandas as pd
142- from sklearn.ensemble import RandomForestClassifier
143- from boruta_py import BorutaPy
144-
140+
145141 # load X and y
146142 # NOTE BorutaPy accepts numpy arrays only, hence the .values attribute
147- X = pd.read_csv('my_X_table.csv', index_col=0).values
148- y = pd.read_csv('my_y_vector.csv', index_col=0).values
149-
143+ X = pd.read_csv('examples/test_X.csv', index_col=0).values
144+ y = pd.read_csv('examples/test_y.csv', header=None, index_col=0).values
145+ y = y.ravel()
146+
150147 # define random forest classifier, with utilising all cores and
151148 # sampling in proportion to y labels
152149 rf = RandomForestClassifier(n_jobs=-1, class_weight='auto', max_depth=5)
153-
150+
154151 # define Boruta feature selection method
155- feat_selector = BorutaPy(rf, n_estimators='auto', verbose=2)
156-
157- # find all relevant features
152+ feat_selector = BorutaPy(rf, n_estimators='auto', verbose=2, random_state=1 )
153+
154+ # find all relevant features - 5 features should be selected
158155 feat_selector.fit(X, y)
159-
160- # check selected features
156+
157+ # check selected features - first 5 features are selected
161158 feat_selector.support_
162-
159+
163160 # check ranking of features
164161 feat_selector.ranking_
165-
162+
166163 # call transform() on X to filter it down to selected features
167164 X_filtered = feat_selector.transform(X)
168165
@@ -181,7 +178,7 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
181178 self .alpha = alpha
182179 self .two_step = two_step
183180 self .max_iter = max_iter
184- self .random_state = check_random_state ( random_state )
181+ self .random_state = random_state
185182 self .verbose = verbose
186183
187184 def fit (self , X , y ):
@@ -248,6 +245,7 @@ def fit_transform(self, X, y, weak=False):
248245 def _fit (self , X , y ):
249246 # check input params
250247 self ._check_params (X , y )
248+ self .random_state = check_random_state (self .random_state )
251249 # setup variables for Boruta
252250 n_sample , n_feat = X .shape
253251 _iter = 1
0 commit comments