Skip to content

Commit 1369bec

Browse files
0.27.5
preprocess moo
1 parent 1612a36 commit 1369bec

2 files changed

Lines changed: 27 additions & 19 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.27.4"
10+
version = "0.27.5"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/preprocess.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sklearn.preprocessing import OneHotEncoder, RobustScaler
66
import numpy as np
77
import pandas as pd
8+
from typing import List, Tuple, Union
89

910

1011
def get_num_cols(df: pd.DataFrame) -> list:
@@ -63,7 +64,7 @@ def get_cat_cols(df: pd.DataFrame) -> list:
6364

6465
def generic_preprocess_df(
6566
df: pd.DataFrame,
66-
target: str,
67+
target: Union[str, List[str]],
6768
imputer_num=SimpleImputer(strategy="mean"),
6869
imputer_cat=SimpleImputer(strategy="most_frequent"),
6970
encoder_cat=OneHotEncoder(categories="auto", drop=None, handle_unknown="ignore", sparse_output=False),
@@ -72,21 +73,16 @@ def generic_preprocess_df(
7273
random_state=42,
7374
shuffle=True,
7475
n_jobs=None,
75-
) -> pd.DataFrame:
76+
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame, pd.DataFrame]:
7677
"""
7778
Preprocesses a DataFrame by handling numerical and categorical features,
7879
splitting the data into training and testing sets, and applying transformations.
79-
80-
This function performs the following steps:
81-
- Separates the target column from the features.
82-
- Identifies numerical and categorical columns.
83-
- Applies imputers, encoders, and scalers to the respective columns.
84-
- Splits the data into training and testing sets.
85-
- Transforms the data using the specified preprocessing pipelines.
80+
Supports single or multiple target columns.
8681
8782
Args:
8883
df (pd.DataFrame): The input DataFrame to preprocess.
89-
target (str): The name of the target column to predict.
84+
target (Union[str, List[str]]): The name(s) of the target column(s) to predict.
85+
Can be a single string or a list of strings.
9086
imputer_num (SimpleImputer, optional): Imputer for numerical columns.
9187
Defaults to `SimpleImputer(strategy="mean")`.
9288
imputer_cat (SimpleImputer, optional): Imputer for categorical columns.
@@ -103,15 +99,15 @@ def generic_preprocess_df(
10399
Defaults to None (1 job).
104100
105101
Returns:
106-
Tuple[np.ndarray, np.ndarray, pd.Series, pd.Series]:
102+
Tuple[np.ndarray, np.ndarray, pd.DataFrame, pd.DataFrame]:
107103
A tuple containing:
108104
- X_train (np.ndarray): Transformed training feature set.
109105
- X_test (np.ndarray): Transformed testing feature set.
110-
- y_train (pd.Series): Training target values.
111-
- y_test (pd.Series): Testing target values.
106+
- y_train (pd.DataFrame): Training target values.
107+
- y_test (pd.DataFrame): Testing target values.
112108
113109
Raises:
114-
ValueError: If the target column is not found in the DataFrame.
110+
ValueError: If the input DataFrame is empty, or if any target column is not found in the DataFrame.
115111
116112
Examples:
117113
>>> from spotpython.utils.preprocess import generic_preprocess_df
@@ -122,11 +118,12 @@ def generic_preprocess_df(
122118
... "age": [25, 30, np.nan, 35],
123119
... "gender": ["M", "F", "M", "F"],
124120
... "income": [50000, 60000, 55000, np.nan],
125-
... "target": [1, 0, 1, 0]
121+
... "target1": [1, 0, 1, 0],
122+
... "target2": [0, 1, 0, 1]
126123
... })
127124
>>> X_train, X_test, y_train, y_test = generic_preprocess_df(
128125
... df,
129-
... target="target",
126+
... target=["target1", "target2"],
130127
... imputer_num=SimpleImputer(strategy="mean"),
131128
... imputer_cat=SimpleImputer(strategy="most_frequent"),
132129
... encoder_cat=OneHotEncoder(),
@@ -137,15 +134,24 @@ def generic_preprocess_df(
137134
"""
138135
if df.empty:
139136
raise ValueError("The input DataFrame is empty.")
140-
if target not in df.columns:
141-
raise ValueError(f"Target column '{target}' not found in the DataFrame.")
137+
138+
if isinstance(target, str):
139+
target = [target] # Convert to list for consistent handling
140+
141+
for t in target:
142+
if t not in df.columns:
143+
raise ValueError(f"Target column '{t}' not found in the DataFrame.")
144+
142145
X = df.drop(target, axis=1)
143146
y = df[target]
147+
144148
num_cols = get_num_cols(X)
145149
cat_cols = get_cat_cols(X)
146150
X[cat_cols] = X[cat_cols].astype(str)
151+
147152
numerical_transformer = Pipeline(steps=[("imputer", imputer_num), ("scaler", scaler_num)])
148153
categorical_transformer = Pipeline(steps=[("imputer", imputer_cat), ("encoder", encoder_cat)])
154+
149155
preprocessor = ColumnTransformer(
150156
transformers=[
151157
("numerical", numerical_transformer, num_cols),
@@ -155,7 +161,9 @@ def generic_preprocess_df(
155161
sparse_threshold=0,
156162
n_jobs=n_jobs,
157163
)
164+
158165
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, shuffle=shuffle)
166+
159167
X_train = preprocessor.fit_transform(X_train)
160168
X_test = preprocessor.transform(X_test)
161169

0 commit comments

Comments
 (0)