From ca55033562e17a79c169a0aa3563daf29e6a2a76 Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Sat, 2 May 2026 17:02:24 +0100 Subject: [PATCH 1/2] feat: add tuple support and exam_cb for callbacks Signed-off-by: ugbotueferhire --- .../interpret-core/interpret/glassbox/_ebm.py | 166 ++++++++--- .../interpret/glassbox/_ebm_core/_boost.py | 20 +- .../tests/glassbox/ebm/test_callback.py | 275 ++++++++++++++++++ 3 files changed, 419 insertions(+), 42 deletions(-) diff --git a/python/interpret-core/interpret/glassbox/_ebm.py b/python/interpret-core/interpret/glassbox/_ebm.py index 08ee2c373..cac63ffd9 100644 --- a/python/interpret-core/interpret/glassbox/_ebm.py +++ b/python/interpret-core/interpret/glassbox/_ebm.py @@ -2,6 +2,7 @@ # Distributed under the MIT software license import heapq +import inspect import json import logging import os @@ -78,6 +79,88 @@ _log = logging.getLogger(__name__) +_PROGRESS_CALLBACK_NAMES = ("bag", "stage", "step", "term", "metric") +_EXAM_CALLBACK_NAMES = ("bag", "stage", "step", "term", "gain") +_CallbackSpec = Callable[..., bool] | tuple[Callable[..., bool], ...] + + +def _classify_callback(callback): + if not callable(callback): + msg = "callback must be a callable or a tuple of callables" + _log.error(msg) + raise ValueError(msg) + + try: + signature = inspect.signature(callback) + except (TypeError, ValueError) as exc: + msg = "callback must have an inspectable signature" + _log.error(msg) + raise ValueError(msg) from exc + + has_metric = "metric" in signature.parameters + has_gain = "gain" in signature.parameters + if has_metric == has_gain: + msg = ( + "callback must accept either the progress signature " + "(*, bag, stage, step, term, metric) or the examination signature " + "(*, bag, stage, step, term, gain)" + ) + _log.error(msg) + raise ValueError(msg) + + required_names = _PROGRESS_CALLBACK_NAMES if has_metric else _EXAM_CALLBACK_NAMES + missing_names = [ + name for name in required_names if name not in signature.parameters + ] + if missing_names: + msg = f"callback is missing required parameters: {missing_names}" + _log.error(msg) + raise ValueError(msg) + + try: + signature.bind(**{name: None for name in required_names}) + except TypeError as exc: + msg = f"callback must be callable with keyword arguments {required_names}" + _log.error(msg) + raise ValueError(msg) from exc + + return "progress" if has_metric else "exam" + + +def _normalize_callbacks(callback): + if callback is None: + return None, None + + callbacks = callback if isinstance(callback, tuple) else (callback,) + if len(callbacks) == 0: + msg = "callback tuple cannot be empty" + _log.error(msg) + raise ValueError(msg) + if len(callbacks) > 2: + msg = "callback tuple can contain at most one progress callback and one examination callback" + _log.error(msg) + raise ValueError(msg) + + progress_callback = None + exam_callback = None + for callback_item in callbacks: + callback_type = _classify_callback(callback_item) + if callback_type == "progress": + if progress_callback is not None: + msg = "callback tuple cannot contain more than one progress callback" + _log.error(msg) + raise ValueError(msg) + progress_callback = callback_item + else: + if exam_callback is not None: + msg = "callback tuple cannot contain more than one examination callback" + _log.error(msg) + raise ValueError(msg) + exam_callback = callback_item + + return progress_callback, exam_callback + + class EBMExplanation(FeatureValueExplanation): """Visualizes specifically for EBM.""" @@ -825,7 +908,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): interaction_smoothing_rounds = 0 early_stopping_rounds = 0 early_stopping_tolerance = 0.0 - callback = None + progress_callback = None + exam_callback = None min_samples_leaf = 0 min_hessian = 0.0 reg_alpha = 0.0 @@ -853,7 +937,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): interaction_smoothing_rounds = self.interaction_smoothing_rounds early_stopping_rounds = self.early_stopping_rounds early_stopping_tolerance = self.early_stopping_tolerance - callback = self.callback + progress_callback, exam_callback = _normalize_callbacks(self.callback) min_samples_leaf = self.min_samples_leaf min_hessian = self.min_hessian reg_alpha = self.reg_alpha @@ -990,7 +1074,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): shared, ) - with nullcontext() if callback is None else SharedMemoryManager() as smm: + has_callback = progress_callback is not None or exam_callback is not None + with nullcontext() if not has_callback else SharedMemoryManager() as smm: if smm is not None: shm = smm.SharedMemory(size=1) shm_name = shm.name @@ -1005,7 +1090,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): shm_name=shm_name, bag_idx=idx, stage=0, - callback=callback, + progress_callback=progress_callback, + exam_callback=exam_callback, dataset=( shared.name if shared.name is not None else shared.dataset ), @@ -1245,7 +1331,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): shm_name=shm_name, bag_idx=idx, stage=1, - callback=callback, + progress_callback=progress_callback, + exam_callback=exam_callback, dataset=( shared.name if shared.name is not None @@ -1357,7 +1444,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): shm_name=None, bag_idx=0, stage=-1, - callback=None, + progress_callback=None, + exam_callback=None, dataset=shared.dataset, intercept_rounds=develop.get_option("n_intercept_rounds_final"), intercept_learning_rate=develop.get_option( @@ -3219,15 +3307,15 @@ class EBMModel(BaseEBM): tradeoff for the ensemble of models --- not the individual models --- a small amount of overfitting of the individual models can improve the accuracy of the ensemble as a whole. - callback : Optional[Callable[..., bool]], default=None - A user-defined function invoked after each progressing boosting step. Must use - keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. - If it returns True, boosting is stopped immediately. - The callback receives: ``bag`` (int) the outer bag index, - ``stage`` (int) the boosting stage (0=mains, 1=pairs), - ``step`` (int) the number of boosting steps completed, - ``term`` (int) the index of the term that was just boosted, - and ``metric`` (float) the current validation metric. + callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None + A user-defined callback or tuple of callbacks invoked during boosting. + A progress callback is invoked after each progressing boosting step and must use + keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``. + An examination callback is invoked whenever a term is examined and its gain is + calculated, and must use keyword-only arguments: + ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True, + boosting is stopped immediately. A tuple can contain at most one progress callback + and one examination callback. min_samples_leaf : int, default=4 Minimum number of samples allowed in the leaves. min_hessian : float, default=0.0 @@ -3337,13 +3425,13 @@ def __init__( # Boosting learning_rate: float = 0.02, greedy_ratio: float | None = 10.0, - cyclic_progress: bool | float = False, + cyclic_progress: bool | float | int = False, # noqa: PYI041 smoothing_rounds: int | None = 100, interaction_smoothing_rounds: int | None = 50, max_rounds: int | None = 50000, early_stopping_rounds: int | None = 100, early_stopping_tolerance: float | None = 1e-5, - callback: Callable[..., bool] | None = None, + callback: _CallbackSpec | None = None, # Trees min_samples_leaf: int | None = 4, min_hessian: float | None = 0.0, @@ -3483,15 +3571,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel): tradeoff for the ensemble of models --- not the individual models --- a small amount of overfitting of the individual models can improve the accuracy of the ensemble as a whole. - callback : Optional[Callable[..., bool]], default=None - A user-defined function invoked after each progressing boosting step. Must use - keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. - If it returns True, boosting is stopped immediately. - The callback receives: ``bag`` (int) the outer bag index, - ``stage`` (int) the boosting stage (0=mains, 1=pairs), - ``step`` (int) the number of boosting steps completed, - ``term`` (int) the index of the term that was just boosted, - and ``metric`` (float) the current validation metric. + callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None + A user-defined callback or tuple of callbacks invoked during boosting. + A progress callback is invoked after each progressing boosting step and must use + keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``. + An examination callback is invoked whenever a term is examined and its gain is + calculated, and must use keyword-only arguments: + ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True, + boosting is stopped immediately. A tuple can contain at most one progress callback + and one examination callback. min_samples_leaf : int, default=4 Minimum number of samples allowed in the leaves. min_hessian : float, default=1e-4 @@ -3660,13 +3748,13 @@ def __init__( # Boosting learning_rate: float = 0.015, greedy_ratio: float | None = 10.0, - cyclic_progress: bool | float = False, + cyclic_progress: bool | float | int = False, # noqa: PYI041 smoothing_rounds: int | None = 75, interaction_smoothing_rounds: int | None = 75, max_rounds: int | None = 50000, early_stopping_rounds: int | None = 100, early_stopping_tolerance: float | None = 1e-5, - callback: Callable[..., bool] | None = None, + callback: _CallbackSpec | None = None, # Trees min_samples_leaf: int | None = 4, min_hessian: float | None = 1e-4, @@ -3808,15 +3896,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel): tradeoff for the ensemble of models --- not the individual models --- a small amount of overfitting of the individual models can improve the accuracy of the ensemble as a whole. - callback : Optional[Callable[..., bool]], default=None - A user-defined function invoked after each progressing boosting step. Must use - keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. - If it returns True, boosting is stopped immediately. - The callback receives: ``bag`` (int) the outer bag index, - ``stage`` (int) the boosting stage (0=mains, 1=pairs), - ``step`` (int) the number of boosting steps completed, - ``term`` (int) the index of the term that was just boosted, - and ``metric`` (float) the current validation metric. + callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None + A user-defined callback or tuple of callbacks invoked during boosting. + A progress callback is invoked after each progressing boosting step and must use + keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``. + An examination callback is invoked whenever a term is examined and its gain is + calculated, and must use keyword-only arguments: + ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True, + boosting is stopped immediately. A tuple can contain at most one progress callback + and one examination callback. min_samples_leaf : int, default=4 Minimum number of samples allowed in the leaves. min_hessian : float, default=0.0 @@ -3989,13 +4077,13 @@ def __init__( # Boosting learning_rate: float = 0.04, greedy_ratio: float | None = 10.0, - cyclic_progress: bool | float = False, + cyclic_progress: bool | float | int = False, # noqa: PYI041 smoothing_rounds: int | None = 500, interaction_smoothing_rounds: int | None = 100, max_rounds: int | None = 50000, early_stopping_rounds: int | None = 100, early_stopping_tolerance: float | None = 1e-5, - callback: Callable[..., bool] | None = None, + callback: _CallbackSpec | None = None, # Trees min_samples_leaf: int | None = 4, min_hessian: float | None = 0.0, diff --git a/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py b/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py index a789c91e8..774d89fa6 100644 --- a/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py +++ b/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py @@ -29,7 +29,8 @@ def boost( shm_name, bag_idx, stage, - callback, + progress_callback, + exam_callback, dataset, intercept_rounds, intercept_learning_rate, @@ -264,6 +265,19 @@ def boost( # penalize nominals a bit because they benefit from sorting categories avg_gain *= gain_scale + if exam_callback is not None: + is_done = exam_callback( + bag=bag_idx, + stage=stage, + step=step_idx, + term=term_idx, + gain=avg_gain, + ) + if is_done: + if stop_flag is not None: + stop_flag[0] = True + break + gainkey = (-avg_gain, native.generate_seed(rng), term_idx) if not make_progress and ( bestkey is None or gainkey < bestkey @@ -368,8 +382,8 @@ def boost( if stop_flag is not None and stop_flag[0]: break - if callback is not None: - is_done = callback( + if progress_callback is not None: + is_done = progress_callback( bag=bag_idx, stage=stage, step=step_idx, diff --git a/python/interpret-core/tests/glassbox/ebm/test_callback.py b/python/interpret-core/tests/glassbox/ebm/test_callback.py index 483bf81b9..6a84ca6d3 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_callback.py +++ b/python/interpret-core/tests/glassbox/ebm/test_callback.py @@ -4,7 +4,9 @@ """Regression tests for issue #635: callback API uses keyword-only args.""" import numpy as np +import pytest +from interpret.glassbox import _ebm from interpret.glassbox import ( ExplainableBoostingClassifier, ExplainableBoostingRegressor, @@ -38,6 +40,29 @@ def __call__(self, *, bag, stage, step, term, metric): return self.call_count >= self.stop_after +class ExamRecordingCallback: + """Picklable callback that records all examined term gains.""" + + def __init__(self): + self.records = [] + + def __call__(self, *, bag, stage, step, term, gain): + self.records.append((bag, stage, step, term, gain)) + return False + + +class StopAfterExamCallback: + """Picklable callback that stops training after N examination calls.""" + + def __init__(self, stop_after): + self.stop_after = stop_after + self.call_count = 0 + + def __call__(self, *, bag, stage, step, term, gain): + self.call_count += 1 + return self.call_count >= self.stop_after + + def test_callback_no_repeated_steps_classifier(): """Verify the callback receives strictly increasing step values. @@ -223,3 +248,253 @@ def __call__(self, *, bag, stage, step, term, metric): ebm.fit(X, y) assert cb.called, "Keyword-only callback should have been invoked" + + +def test_exam_callback_receives_valid_gains(): + """Verify the examination callback receives finite gain values.""" + cb = ExamRecordingCallback() + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=500 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=50, + n_jobs=1, + callback=cb, + ) + ebm.fit(X, y) + + assert len(cb.records) > 0, "Exam callback should have been invoked at least once" + + for i, (_, _, _, term, gain) in enumerate(cb.records): + assert isinstance(term, (int, np.integer)), ( + f"term at call {i} should be an int, got {type(term)}" + ) + assert np.isfinite(gain), f"Gain at step {i} is not finite: {gain}" + + +def test_callback_tuple_support_calls_both_callbacks(): + """Verify tuple callbacks dispatch both progress and examination hooks.""" + progress_cb = RecordingCallback() + exam_cb = ExamRecordingCallback() + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=500 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=50, + n_jobs=1, + callback=(exam_cb, progress_cb), + ) + ebm.fit(X, y) + + assert len(progress_cb.records) > 0, "Progress callback should have been invoked" + assert len(exam_cb.records) > 0, "Exam callback should have been invoked" + + +def test_exam_callback_early_termination(): + """Verify the examination callback can terminate training early.""" + cb = StopAfterExamCallback(stop_after=5) + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=500 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=5000, + n_jobs=1, + callback=cb, + ) + ebm.fit(X, y) + + assert cb.call_count == cb.stop_after, ( + f"Expected exam callback to be called exactly {cb.stop_after} times " + f"before stopping, but was called {cb.call_count} times" + ) + + predictions = ebm.predict(X) + assert len(predictions) == len(y) + + +@pytest.mark.parametrize( + "callback, message", + [ + ((RecordingCallback(), RecordingCallback()), "more than one progress callback"), + ( + (ExamRecordingCallback(), ExamRecordingCallback()), + "more than one examination callback", + ), + (tuple(), "cannot be empty"), + ], +) +def test_callback_tuple_validation_errors(callback, message): + """Verify tuple callback validation errors are raised clearly.""" + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=callback, + ) + + with pytest.raises(ValueError, match=message): + ebm.fit(X, y) + + +def test_callback_signature_requires_metric_or_gain(): + """Verify callbacks are classified by metric/gain keyword names.""" + + class InvalidCallback: + def __call__(self, *, bag, stage, step, term): + return False + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=InvalidCallback(), + ) + + with pytest.raises(ValueError, match="either the progress signature"): + ebm.fit(X, y) + + +def test_callback_must_be_callable(): + """Verify non-callable callback values are rejected.""" + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=1, + ) + + with pytest.raises(ValueError, match="callable or a tuple of callables"): + ebm.fit(X, y) + + +def test_callback_signature_must_be_inspectable(monkeypatch): + """Verify callbacks with uninspectable signatures are rejected.""" + + class ValidProgressCallback: + def __call__(self, *, bag, stage, step, term, metric): + return False + + def raise_type_error(_): + raise TypeError("boom") + + monkeypatch.setattr(_ebm.inspect, "signature", raise_type_error) + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=ValidProgressCallback(), + ) + + with pytest.raises(ValueError, match="inspectable signature"): + ebm.fit(X, y) + + +def test_callback_missing_required_parameters(): + """Verify callbacks missing required keyword names are rejected.""" + + class MissingTermCallback: + def __call__(self, *, bag, stage, step, metric): + return False + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=MissingTermCallback(), + ) + + with pytest.raises(ValueError, match="missing required parameters"): + ebm.fit(X, y) + + +def test_callback_must_accept_keyword_arguments(): + """Verify positional-only callback signatures are rejected.""" + + class PositionalOnlyCallback: + def __call__(self, bag, stage, step, term, metric, /): + return False + + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=PositionalOnlyCallback(), + ) + + with pytest.raises(ValueError, match="callable with keyword arguments"): + ebm.fit(X, y) + + +def test_callback_tuple_rejects_more_than_two_callbacks(): + """Verify callback tuples longer than two items are rejected.""" + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=100 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=( + RecordingCallback(), + ExamRecordingCallback(), + RecordingCallback(), + ), + ) + + with pytest.raises(ValueError, match="at most one progress callback"): + ebm.fit(X, y) From 89a68fe57f866c48f2473b1d6b2ee63624c00f55 Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Sun, 3 May 2026 20:09:11 +0100 Subject: [PATCH 2/2] test: cover callback validation branches Signed-off-by: ugbotueferhire --- .../tests/glassbox/ebm/test_callback.py | 20 ++++++++++++++ .../tests/glassbox/ebm/test_ebm.py | 26 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/python/interpret-core/tests/glassbox/ebm/test_callback.py b/python/interpret-core/tests/glassbox/ebm/test_callback.py index 6a84ca6d3..a824097be 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_callback.py +++ b/python/interpret-core/tests/glassbox/ebm/test_callback.py @@ -250,6 +250,26 @@ def __call__(self, *, bag, stage, step, term, metric): assert cb.called, "Keyword-only callback should have been invoked" +def test_fit_without_callback_still_trains(): + """Verify the no-callback training path still works.""" + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=200 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + outer_bags=1, + max_rounds=10, + n_jobs=1, + callback=None, + ) + ebm.fit(X, y) + + predictions = ebm.predict(X) + assert len(predictions) == len(y) + + def test_exam_callback_receives_valid_gains(): """Verify the examination callback receives finite gain values.""" cb = ExamRecordingCallback() diff --git a/python/interpret-core/tests/glassbox/ebm/test_ebm.py b/python/interpret-core/tests/glassbox/ebm/test_ebm.py index eadd334a6..1c723f246 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_ebm.py +++ b/python/interpret-core/tests/glassbox/ebm/test_ebm.py @@ -249,6 +249,32 @@ def test_ebm_sweep(): assert len(clf.bin_weights_) == 4 +@pytest.mark.parametrize( + "interactions, message", + [ + ([(-999,)], "out of range of the features"), + ([("missing_feature",)], "not in the list of feature names"), + ([(None,)], "has unsupported type"), + ], +) +def test_invalid_explicit_interaction_items_raise(interactions, message): + X, y, names, types = make_synthetic( + seed=42, classes=2, output_type="float", n_samples=200 + ) + + ebm = ExplainableBoostingClassifier( + names, + types, + interactions=interactions, + outer_bags=1, + max_rounds=10, + n_jobs=1, + ) + + with pytest.raises(ValueError, match=message): + ebm.fit(X, y) + + def test_copy(): data = synthetic_classification() X = data["full"]["X"]