Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 127 additions & 39 deletions python/interpret-core/interpret/glassbox/_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Distributed under the MIT software license

import heapq
import inspect
import json
import logging
import os
Expand Down Expand Up @@ -78,6 +79,88 @@
_log = logging.getLogger(__name__)


# Keyword-only parameter names a *progress* callback must accept; it is
# invoked after each progressing boosting step with the current validation
# metric.
_PROGRESS_CALLBACK_NAMES = ("bag", "stage", "step", "term", "metric")
# Keyword-only parameter names an *examination* callback must accept; it is
# invoked whenever a term's gain is calculated.
_EXAM_CALLBACK_NAMES = ("bag", "stage", "step", "term", "gain")
# Accepted shape of the public ``callback`` option: a single callback, or a
# tuple holding at most one progress and one examination callback.
_CallbackSpec = Callable[..., bool] | tuple[Callable[..., bool], ...]


def _classify_callback(callback):
    """Classify a single user callback as ``"progress"`` or ``"exam"``.

    A progress callback declares a ``metric`` keyword parameter; an
    examination callback declares ``gain``. Exactly one of the two must be
    present, along with the shared ``bag``/``stage``/``step``/``term``
    parameters, and the callback must be invocable with those keywords.

    Raises
    ------
    ValueError
        If *callback* is not callable, has no inspectable signature, or does
        not match either of the two supported keyword signatures.
    """

    def _fail(msg, cause=None):
        # Log before raising so misconfigured callbacks are visible even when
        # the exception is handled upstream.
        _log.error(msg)
        if cause is None:
            raise ValueError(msg)
        raise ValueError(msg) from cause

    if not callable(callback):
        _fail("callback must be a callable or a tuple of callables")

    try:
        sig = inspect.signature(callback)
    except (TypeError, ValueError) as exc:
        _fail("callback must have an inspectable signature", exc)

    accepts_metric = "metric" in sig.parameters
    accepts_gain = "gain" in sig.parameters
    # Exactly one of metric/gain must be declared; both or neither is
    # ambiguous, so reject it.
    if accepts_metric == accepts_gain:
        _fail(
            "callback must accept either the progress signature "
            "(*, bag, stage, step, term, metric) or the examination signature "
            "(*, bag, stage, step, term, gain)"
        )

    expected = _PROGRESS_CALLBACK_NAMES if accepts_metric else _EXAM_CALLBACK_NAMES
    absent = [name for name in expected if name not in sig.parameters]
    if absent:
        _fail(f"callback is missing required parameters: {absent}")

    # Dry-run bind with placeholder values to prove the callback can actually
    # be invoked with these names as keywords (e.g. none are positional-only
    # and no other required parameters exist).
    try:
        sig.bind(**dict.fromkeys(expected))
    except TypeError as exc:
        _fail(f"callback must be callable with keyword arguments {expected}", exc)

    return "progress" if accepts_metric else "exam"


def _normalize_callbacks(callback):
    """Split the user-supplied ``callback`` option into its two roles.

    Accepts ``None``, a single callback, or a tuple of callbacks, and returns
    a ``(progress_callback, exam_callback)`` pair where either element may be
    ``None``.

    Raises
    ------
    ValueError
        If the tuple is empty, holds more than two callbacks, or holds two
        callbacks of the same kind (propagated classification errors come
        from :func:`_classify_callback`).
    """
    if callback is None:
        return None, None

    items = callback if isinstance(callback, tuple) else (callback,)
    if not items:
        msg = "callback tuple cannot be empty"
        _log.error(msg)
        raise ValueError(msg)
    if len(items) > 2:
        msg = "callback tuple can contain at most one progress callback and one examination callback"
        _log.error(msg)
        raise ValueError(msg)

    # At most one callback of each kind is allowed.
    selected = {"progress": None, "exam": None}
    duplicate_msg = {
        "progress": "callback tuple cannot contain more than one progress callback",
        "exam": "callback tuple cannot contain more than one examination callback",
    }
    for item in items:
        kind = _classify_callback(item)
        if selected[kind] is not None:
            msg = duplicate_msg[kind]
            _log.error(msg)
            raise ValueError(msg)
        selected[kind] = item

    return selected["progress"], selected["exam"]


class EBMExplanation(FeatureValueExplanation):
"""Visualizes specifically for EBM."""

Expand Down Expand Up @@ -851,7 +934,8 @@ def fit(
interaction_smoothing_rounds = 0
early_stopping_rounds = 0
early_stopping_tolerance = 0.0
callback = None
progress_callback = None
exam_callback = None
min_samples_leaf = 0
min_hessian = 0.0
reg_alpha = 0.0
Expand Down Expand Up @@ -879,7 +963,7 @@ def fit(
interaction_smoothing_rounds = self.interaction_smoothing_rounds
early_stopping_rounds = self.early_stopping_rounds
early_stopping_tolerance = self.early_stopping_tolerance
callback = self.callback
progress_callback, exam_callback = _normalize_callbacks(self.callback)
min_samples_leaf = self.min_samples_leaf
min_hessian = self.min_hessian
reg_alpha = self.reg_alpha
Expand Down Expand Up @@ -1018,7 +1102,8 @@ def fit(
shared,
)

with nullcontext() if callback is None else SharedMemoryManager() as smm:
has_callback = progress_callback is not None or exam_callback is not None
with nullcontext() if not has_callback else SharedMemoryManager() as smm:
stop_flag: npt.NDArray[np.bool_] | None
if smm is not None:
shm = smm.SharedMemory(size=1)
Expand All @@ -1034,7 +1119,8 @@ def fit(
shm_name=shm_name,
bag_idx=idx,
stage=0,
callback=callback,
progress_callback=progress_callback,
exam_callback=exam_callback,
dataset=(
shared.name if shared.name is not None else shared.dataset
),
Expand Down Expand Up @@ -1274,7 +1360,8 @@ def fit(
shm_name=shm_name,
bag_idx=idx,
stage=1,
callback=callback,
progress_callback=progress_callback,
exam_callback=exam_callback,
dataset=(
shared.name
if shared.name is not None
Expand Down Expand Up @@ -1386,7 +1473,8 @@ def fit(
shm_name=None,
bag_idx=0,
stage=-1,
callback=None,
progress_callback=None,
exam_callback=None,
dataset=shared.dataset,
intercept_rounds=develop.get_option("n_intercept_rounds_final"),
intercept_learning_rate=develop.get_option(
Expand Down Expand Up @@ -3312,15 +3400,15 @@ class EBMModel(BaseEBM):
tradeoff for the ensemble of models --- not the individual models --- a small
amount of overfitting of the individual models can improve the accuracy of
the ensemble as a whole.
callback : Optional[Callable[..., bool]], default=None
A user-defined function invoked after each progressing boosting step. Must use
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
If it returns True, boosting is stopped immediately.
The callback receives: ``bag`` (int) the outer bag index,
``stage`` (int) the boosting stage (0=mains, 1=pairs),
``step`` (int) the number of boosting steps completed,
``term`` (int) the index of the term that was just boosted,
and ``metric`` (float) the current validation metric.
callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None
A user-defined callback or tuple of callbacks invoked during boosting.
A progress callback is invoked after each progressing boosting step and must use
keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``.
An examination callback is invoked whenever a term is examined and its gain is
calculated, and must use keyword-only arguments:
``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True,
boosting is stopped immediately. A tuple can contain at most one progress callback
and one examination callback.
min_samples_leaf : int, default=4
Minimum number of samples allowed in the leaves.
min_hessian : float, default=0.0
Expand Down Expand Up @@ -3431,13 +3519,13 @@ def __init__(
# Boosting
learning_rate: float = 0.02,
greedy_ratio: float | None = 10.0,
cyclic_progress: bool | float = False,
cyclic_progress: bool | float | int = False, # noqa: PYI041
smoothing_rounds: int | None = 100,
interaction_smoothing_rounds: int | None = 50,
max_rounds: int | None = 50000,
early_stopping_rounds: int | None = 100,
early_stopping_tolerance: float | None = 1e-5,
callback: Callable[..., bool] | None = None,
callback: _CallbackSpec | None = None,
# Trees
min_samples_leaf: int | None = 4,
min_hessian: float | None = 0.0,
Expand Down Expand Up @@ -3577,15 +3665,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel):
tradeoff for the ensemble of models --- not the individual models --- a small
amount of overfitting of the individual models can improve the accuracy of
the ensemble as a whole.
callback : Optional[Callable[..., bool]], default=None
A user-defined function invoked after each progressing boosting step. Must use
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
If it returns True, boosting is stopped immediately.
The callback receives: ``bag`` (int) the outer bag index,
``stage`` (int) the boosting stage (0=mains, 1=pairs),
``step`` (int) the number of boosting steps completed,
``term`` (int) the index of the term that was just boosted,
and ``metric`` (float) the current validation metric.
callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None
A user-defined callback or tuple of callbacks invoked during boosting.
A progress callback is invoked after each progressing boosting step and must use
keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``.
An examination callback is invoked whenever a term is examined and its gain is
calculated, and must use keyword-only arguments:
``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True,
boosting is stopped immediately. A tuple can contain at most one progress callback
and one examination callback.
min_samples_leaf : int, default=4
Minimum number of samples allowed in the leaves.
min_hessian : float, default=1e-4
Expand Down Expand Up @@ -3755,13 +3843,13 @@ def __init__(
# Boosting
learning_rate: float = 0.015,
greedy_ratio: float | None = 10.0,
cyclic_progress: bool | float = False,
cyclic_progress: bool | float | int = False, # noqa: PYI041
smoothing_rounds: int | None = 75,
interaction_smoothing_rounds: int | None = 75,
max_rounds: int | None = 50000,
early_stopping_rounds: int | None = 100,
early_stopping_tolerance: float | None = 1e-5,
callback: Callable[..., bool] | None = None,
callback: _CallbackSpec | None = None,
# Trees
min_samples_leaf: int | None = 4,
min_hessian: float | None = 1e-4,
Expand Down Expand Up @@ -3903,15 +3991,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
tradeoff for the ensemble of models --- not the individual models --- a small
amount of overfitting of the individual models can improve the accuracy of
the ensemble as a whole.
callback : Optional[Callable[..., bool]], default=None
A user-defined function invoked after each progressing boosting step. Must use
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
If it returns True, boosting is stopped immediately.
The callback receives: ``bag`` (int) the outer bag index,
``stage`` (int) the boosting stage (0=mains, 1=pairs),
``step`` (int) the number of boosting steps completed,
``term`` (int) the index of the term that was just boosted,
and ``metric`` (float) the current validation metric.
callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None
A user-defined callback or tuple of callbacks invoked during boosting.
A progress callback is invoked after each progressing boosting step and must use
keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``.
An examination callback is invoked whenever a term is examined and its gain is
calculated, and must use keyword-only arguments:
``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True,
boosting is stopped immediately. A tuple can contain at most one progress callback
and one examination callback.
min_samples_leaf : int, default=4
Minimum number of samples allowed in the leaves.
min_hessian : float, default=0.0
Expand Down Expand Up @@ -4085,13 +4173,13 @@ def __init__(
# Boosting
learning_rate: float = 0.04,
greedy_ratio: float | None = 10.0,
cyclic_progress: bool | float = False,
cyclic_progress: bool | float | int = False, # noqa: PYI041
smoothing_rounds: int | None = 500,
interaction_smoothing_rounds: int | None = 100,
max_rounds: int | None = 50000,
early_stopping_rounds: int | None = 100,
early_stopping_tolerance: float | None = 1e-5,
callback: Callable[..., bool] | None = None,
callback: _CallbackSpec | None = None,
# Trees
min_samples_leaf: int | None = 4,
min_hessian: float | None = 0.0,
Expand Down
20 changes: 17 additions & 3 deletions python/interpret-core/interpret/glassbox/_ebm_core/_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def boost(
shm_name,
bag_idx,
stage,
callback,
progress_callback,
exam_callback,
dataset,
intercept_rounds,
intercept_learning_rate,
Expand Down Expand Up @@ -264,6 +265,19 @@ def boost(
# penalize nominals a bit because they benefit from sorting categories
avg_gain *= gain_scale

if exam_callback is not None:
is_done = exam_callback(
bag=bag_idx,
stage=stage,
step=step_idx,
term=term_idx,
gain=avg_gain,
)
if is_done:
if stop_flag is not None:
stop_flag[0] = True
break

gainkey = (-avg_gain, native.generate_seed(rng), term_idx)
if not make_progress and (
bestkey is None or gainkey < bestkey
Expand Down Expand Up @@ -368,8 +382,8 @@ def boost(
if stop_flag is not None and stop_flag[0]:
break

if callback is not None:
is_done = callback(
if progress_callback is not None:
is_done = progress_callback(
bag=bag_idx,
stage=stage,
step=step_idx,
Expand Down
Loading
Loading