55from sklearn .preprocessing import OneHotEncoder , RobustScaler
66import numpy as np
77import pandas as pd
8+ from typing import List , Tuple , Union
89
910
1011def get_num_cols (df : pd .DataFrame ) -> list :
@@ -63,7 +64,7 @@ def get_cat_cols(df: pd.DataFrame) -> list:
6364
6465def generic_preprocess_df (
6566 df : pd .DataFrame ,
66- target : str ,
67+ target : Union [ str , List [ str ]] ,
6768 imputer_num = SimpleImputer (strategy = "mean" ),
6869 imputer_cat = SimpleImputer (strategy = "most_frequent" ),
6970 encoder_cat = OneHotEncoder (categories = "auto" , drop = None , handle_unknown = "ignore" , sparse_output = False ),
@@ -72,21 +73,16 @@ def generic_preprocess_df(
7273 random_state = 42 ,
7374 shuffle = True ,
7475 n_jobs = None ,
75- ) -> pd .DataFrame :
76+ ) -> Tuple [ np . ndarray , np . ndarray , pd .DataFrame , pd . DataFrame ] :
7677 """
7778 Preprocesses a DataFrame by handling numerical and categorical features,
7879 splitting the data into training and testing sets, and applying transformations.
79-
80- This function performs the following steps:
81- - Separates the target column from the features.
82- - Identifies numerical and categorical columns.
83- - Applies imputers, encoders, and scalers to the respective columns.
84- - Splits the data into training and testing sets.
85- - Transforms the data using the specified preprocessing pipelines.
80+ Supports single or multiple target columns.
8681
8782 Args:
8883 df (pd.DataFrame): The input DataFrame to preprocess.
89- target (str): The name of the target column to predict.
84+ target (Union[str, List[str]]): The name(s) of the target column(s) to predict.
85+ Can be a single string or a list of strings.
9086 imputer_num (SimpleImputer, optional): Imputer for numerical columns.
9187 Defaults to `SimpleImputer(strategy="mean")`.
9288 imputer_cat (SimpleImputer, optional): Imputer for categorical columns.
@@ -103,15 +99,15 @@ def generic_preprocess_df(
10399 Defaults to None (1 job).
104100
105101 Returns:
106- Tuple[np.ndarray, np.ndarray, pd.Series , pd.Series ]:
102+ Tuple[np.ndarray, np.ndarray, pd.DataFrame , pd.DataFrame ]:
107103 A tuple containing:
108104 - X_train (np.ndarray): Transformed training feature set.
109105 - X_test (np.ndarray): Transformed testing feature set.
110- - y_train (pd.Series ): Training target values.
111- - y_test (pd.Series ): Testing target values.
106+ - y_train (pd.DataFrame ): Training target values.
107+ - y_test (pd.DataFrame ): Testing target values.
112108
113109 Raises:
114- ValueError: If the target column is not found in the DataFrame.
110+ ValueError: If the input DataFrame is empty, or if any target column is not found in the DataFrame.
115111
116112 Examples:
117113 >>> from spotpython.utils.preprocess import generic_preprocess_df
@@ -122,11 +118,12 @@ def generic_preprocess_df(
122118 ... "age": [25, 30, np.nan, 35],
123119 ... "gender": ["M", "F", "M", "F"],
124120 ... "income": [50000, 60000, 55000, np.nan],
125- ... "target": [1, 0, 1, 0]
121+ ... "target1": [1, 0, 1, 0],
122+ ... "target2": [0, 1, 0, 1]
126123 ... })
127124 >>> X_train, X_test, y_train, y_test = generic_preprocess_df(
128125 ... df,
129- ... target="target" ,
126+ ... target=["target1", "target2"] ,
130127 ... imputer_num=SimpleImputer(strategy="mean"),
131128 ... imputer_cat=SimpleImputer(strategy="most_frequent"),
132129 ... encoder_cat=OneHotEncoder(),
@@ -137,15 +134,24 @@ def generic_preprocess_df(
137134 """
138135 if df .empty :
139136 raise ValueError ("The input DataFrame is empty." )
140- if target not in df .columns :
141- raise ValueError (f"Target column '{ target } ' not found in the DataFrame." )
137+
138+ if isinstance (target , str ):
139+ target = [target ] # Convert to list for consistent handling
140+
141+ for t in target :
142+ if t not in df .columns :
143+ raise ValueError (f"Target column '{ t } ' not found in the DataFrame." )
144+
142145 X = df .drop (target , axis = 1 )
143146 y = df [target ]
147+
144148 num_cols = get_num_cols (X )
145149 cat_cols = get_cat_cols (X )
146150 X [cat_cols ] = X [cat_cols ].astype (str )
151+
147152 numerical_transformer = Pipeline (steps = [("imputer" , imputer_num ), ("scaler" , scaler_num )])
148153 categorical_transformer = Pipeline (steps = [("imputer" , imputer_cat ), ("encoder" , encoder_cat )])
154+
149155 preprocessor = ColumnTransformer (
150156 transformers = [
151157 ("numerical" , numerical_transformer , num_cols ),
@@ -155,7 +161,9 @@ def generic_preprocess_df(
155161 sparse_threshold = 0 ,
156162 n_jobs = n_jobs ,
157163 )
164+
158165 X_train , X_test , y_train , y_test = train_test_split (X , y , test_size = test_size , random_state = random_state , shuffle = shuffle )
166+
159167 X_train = preprocessor .fit_transform (X_train )
160168 X_test = preprocessor .transform (X_test )
161169
0 commit comments