@@ -115,7 +115,7 @@ def generate_imp(X_train, X_test, y_train, y_test, random_state=42, n_repeats=10
115115 return perm_imp
116116
117117
118- def plot_importances (df_mdi , perm_imp , X_test , target_name = None , feature_names = None , k = 10 , show = True ) -> None :
118+ def plot_importances (df_mdi , perm_imp , X_test , target_name = None , feature_names = None , k = 10 , figsize = ( 12 , 8 ), show = True ) -> None :
119119 """
120120 Plots the impurity-based and permutation-based feature importances for a given classifier.
121121
@@ -132,6 +132,8 @@ def plot_importances(df_mdi, perm_imp, X_test, target_name=None, feature_names=N
132132 List of feature names for labeling. Defaults to None.
133133 k (int, optional):
134134 Number of top features to display based on importance. Default is 10.
135+ figsize (tuple, optional):
136+ Size of the figure (width, height) in inches. Default is (12, 8).
135137 show (bool, optional):
136138 If True, displays the plot immediately. Default is True.
137139
@@ -151,12 +153,11 @@ def plot_importances(df_mdi, perm_imp, X_test, target_name=None, feature_names=N
151153 >>> y_test_series = pd.Series(y_test)
152154 >>> df_mdi = generate_mdi(X_train_df, y_train_series)
153155 >>> perm_imp = generate_imp(X_train_df, X_test_df, y_train_series, y_test_series)
154- >>> plot_importances(df_mdi, perm_imp, X_test_df)
155-
156+ >>> plot_importances(df_mdi, perm_imp, X_test_df, figsize=(15, 10))
156157 """
157158
158159 # Plot impurity-based importances for top-k features
159- fig , (ax1 , ax2 ) = plt .subplots (1 , 2 , figsize = ( 12 , 8 ) )
160+ fig , (ax1 , ax2 ) = plt .subplots (1 , 2 , figsize = figsize )
160161
161162 sorted_mdi_importances = df_mdi .set_index ("Feature" )["Importance" ]
162163 sorted_mdi_importances [:k ].sort_values ().plot .barh (ax = ax1 )
0 commit comments