Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/interpret-core/interpret/utils/_clean_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,24 @@ def _process_pandas_column(X_col, is_initial, feature_type, min_unique_continuou
X_col.codes,
None,
)
elif isinstance(dt, pd.StringDtype):
# this handles pd.StringDtype both the numpy and arrow versions
if X_col.hasnans:
# if hasnans is true then there is definetly a real missing value in there and not just a mask
return _process_ndarray(
X_col.dropna().values.astype(np.str_, copy=False),
X_col.notna().values,
is_initial,
feature_type,
min_unique_continuous,
)
return _process_ndarray(
X_col.values.astype(np.str_, copy=False),
None,
is_initial,
feature_type,
min_unique_continuous,
)
elif issubclass(tt, _intbool_types):
# this handles Int8Dtype to Int64Dtype, UInt8Dtype to UInt64Dtype, and BooleanDtype
if X_col.hasnans:
Expand All @@ -1058,8 +1076,6 @@ def _process_pandas_column(X_col, is_initial, feature_type, min_unique_continuou
)

# TODO: implement pd.SparseDtype
# TODO: implement pd.StringDtype both the numpy and arrow versions
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.StringDtype.html#pandas.StringDtype
msg = f"{type(dt)} not supported"
_log.error(msg)
raise TypeError(msg)
Expand Down
4 changes: 4 additions & 0 deletions python/interpret-core/tests/utils/test_clean_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,10 @@ def test_unify_columns_pandas_missings_BooleanDtype():
check_pandas_missings(pd.BooleanDtype(), False, True)


def test_unify_columns_pandas_missings_StringDtype():
check_pandas_missings(pd.StringDtype(), "abc", "def")


def test_unify_columns_pandas_missings_str():
check_pandas_missings(np.object_, "abc", "def")

Expand Down
Loading