From 3ff408f11cfb1ae898b6e8c17fd5dd0b7139358d Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:14:25 -0800 Subject: [PATCH 1/4] Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569) Bootstrap sampling can exclude entire treatment groups from a tree's training data, causing individual trees to produce prediction arrays of different widths. When summing predictions across trees, this causes a ValueError for shape mismatch. Added _align_tree_predict() that maps each tree's predictions to the forest-level class ordering, filling zeros for missing treatment groups. This is a module-level function (not a closure) so it works with joblib's parallel pickling. Co-Authored-By: Claude Opus 4.6 --- causalml/inference/tree/uplift.pyx | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/causalml/inference/tree/uplift.pyx b/causalml/inference/tree/uplift.pyx index b13887c1..bcc70656 100644 --- a/causalml/inference/tree/uplift.pyx +++ b/causalml/inference/tree/uplift.pyx @@ -60,6 +60,25 @@ cdef extern from "math.h": double fabs(double x) nogil double sqrt(double x) nogil + +def _align_tree_predict(tree, X, forest_classes): + """Predict with a single tree and align output to the forest's classes. + + When a bootstrap sample excludes some treatment groups, the tree's + classes_ will be a subset of the forest's classes_. This function + maps the tree's predictions to the forest-level class ordering. 
+ """ + raw = tree.predict(X=X) + if len(tree.classes_) == len(forest_classes): + return raw + aligned = np.zeros((raw.shape[0], len(forest_classes))) + for tree_idx, cls in enumerate(tree.classes_): + if cls in forest_classes: + forest_idx = forest_classes.index(cls) + aligned[:, forest_idx] = raw[:, tree_idx] + return aligned + + @cython.cfunc def kl_divergence(pk: cython.float, qk: cython.float) -> cython.float: ''' @@ -2696,10 +2715,10 @@ class UpliftRandomForestClassifier: if self.n_jobs != 1: y_pred_ensemble = sum( Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer) - (delayed(tree.predict)(X=X) for tree in self.uplift_forest) + (delayed(_align_tree_predict)(tree, X, self.classes_) for tree in self.uplift_forest) ) / len(self.uplift_forest) else: - y_pred_ensemble = sum([tree.predict(X=X) for tree in self.uplift_forest]) / len(self.uplift_forest) + y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_) for tree in self.uplift_forest]) / len(self.uplift_forest) # Summarize results into dataframe df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_) From 8d41092b93dee89f657b85bc44b73030c498919f Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:29:56 -0800 Subject: [PATCH 2/4] Address review: use dict lookup, preserve dtype, add regression test - Use dict for O(1) class-to-index mapping instead of repeated list scans - Preserve dtype with dtype=raw.dtype in aligned array - Add test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups Co-Authored-By: Claude Opus 4.6 --- causalml/inference/tree/uplift.pyx | 7 +++--- tests/test_uplift_trees.py | 36 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/causalml/inference/tree/uplift.pyx b/causalml/inference/tree/uplift.pyx index bcc70656..aac893ed 100644 --- a/causalml/inference/tree/uplift.pyx +++ b/causalml/inference/tree/uplift.pyx @@ -71,10 +71,11 @@ def _align_tree_predict(tree, X, forest_classes): raw = 
tree.predict(X=X) if len(tree.classes_) == len(forest_classes): return raw - aligned = np.zeros((raw.shape[0], len(forest_classes))) + aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype) + class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)} for tree_idx, cls in enumerate(tree.classes_): - if cls in forest_classes: - forest_idx = forest_classes.index(cls) + forest_idx = class_to_forest_idx.get(cls) + if forest_idx is not None: aligned[:, forest_idx] = raw[:, tree_idx] return aligned diff --git a/tests/test_uplift_trees.py b/tests/test_uplift_trees.py index 96f87755..d8f01701 100644 --- a/tests/test_uplift_trees.py +++ b/tests/test_uplift_trees.py @@ -389,3 +389,39 @@ def test_uplift_tree_pvalue_no_nan_with_sparse_groups(): assert not np.any( np.isnan(preds) ), "Predictions contain NaN (likely from NaN p-values)" + + +def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups(): + """Test that UpliftRandomForestClassifier.predict() returns correct shape + when bootstrap sampling causes some trees to miss treatment groups (#569).""" + np.random.seed(RANDOM_SEED) + n = 60 + X = np.random.randn(n, 3) + # Very few samples in treatment groups so bootstraps are likely to miss some + treatment = np.array( + [CONTROL_NAME] * 50 + [TREATMENT_NAMES[1]] * 5 + [TREATMENT_NAMES[2]] * 5 + ) + y = np.random.randint(0, 2, n) + + model = UpliftRandomForestClassifier( + control_name=CONTROL_NAME, + n_estimators=10, + min_samples_leaf=1, + min_samples_treatment=0, + random_state=RANDOM_SEED, + ) + model.fit(X, treatment=treatment, y=y) + + # Single-threaded + preds = model.predict(X) + assert preds.shape == ( + n, + len(model.classes_) - 1, + ), f"Expected shape ({n}, {len(model.classes_) - 1}), got {preds.shape}" + assert not np.any(np.isnan(preds)), "Predictions contain NaN" + + # Parallel + with parallel_backend("threading", n_jobs=2): + preds_par = model.predict(X) + assert preds_par.shape == preds.shape + assert 
np.allclose(preds, preds_par) From 7e9bb7aeb786363a4344b37dc0195d43b65ed063 Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:43:06 -0800 Subject: [PATCH 3/4] Address review: precompute class mapping once, improve test robustness - Build class_to_forest_idx dict once in predict() instead of per tree - Use model.n_jobs instead of parallel_backend for parallel test - Assert that sparse-group condition actually occurred in test Co-Authored-By: Claude Opus 4.6 --- causalml/inference/tree/uplift.pyx | 11 +++++++---- tests/test_uplift_trees.py | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/causalml/inference/tree/uplift.pyx b/causalml/inference/tree/uplift.pyx index aac893ed..942e4d4b 100644 --- a/causalml/inference/tree/uplift.pyx +++ b/causalml/inference/tree/uplift.pyx @@ -61,18 +61,20 @@ cdef extern from "math.h": double sqrt(double x) nogil -def _align_tree_predict(tree, X, forest_classes): +def _align_tree_predict(tree, X, forest_classes, class_to_forest_idx): """Predict with a single tree and align output to the forest's classes. When a bootstrap sample excludes some treatment groups, the tree's classes_ will be a subset of the forest's classes_. This function maps the tree's predictions to the forest-level class ordering. + + Args: + class_to_forest_idx: Precomputed {class_label: forest_index} mapping. 
""" raw = tree.predict(X=X) if len(tree.classes_) == len(forest_classes): return raw aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype) - class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)} for tree_idx, cls in enumerate(tree.classes_): forest_idx = class_to_forest_idx.get(cls) if forest_idx is not None: @@ -2712,14 +2714,15 @@ class UpliftRandomForestClassifier: ''' # Make predictions with all trees and take the average + class_to_forest_idx = {cls: idx for idx, cls in enumerate(self.classes_)} if self.n_jobs != 1: y_pred_ensemble = sum( Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer) - (delayed(_align_tree_predict)(tree, X, self.classes_) for tree in self.uplift_forest) + (delayed(_align_tree_predict)(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest) ) / len(self.uplift_forest) else: - y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_) for tree in self.uplift_forest]) / len(self.uplift_forest) + y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest]) / len(self.uplift_forest) # Summarize results into dataframe df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_) diff --git a/tests/test_uplift_trees.py b/tests/test_uplift_trees.py index d8f01701..321ef9c8 100644 --- a/tests/test_uplift_trees.py +++ b/tests/test_uplift_trees.py @@ -406,13 +406,23 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups(): model = UpliftRandomForestClassifier( control_name=CONTROL_NAME, n_estimators=10, + n_jobs=2, min_samples_leaf=1, min_samples_treatment=0, random_state=RANDOM_SEED, ) model.fit(X, treatment=treatment, y=y) + # Verify that at least one tree was fit without some treatment groups + assert any( + len(tree.classes_) < len(model.classes_) for tree in model.uplift_forest + ), ( + "Test setup failed to produce any trees missing treatment groups; " + "adjust seed or sampling parameters to 
exercise sparse-group behavior." + ) + # Single-threaded + model.n_jobs = 1 preds = model.predict(X) assert preds.shape == ( n, @@ -421,7 +431,7 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups(): assert not np.any(np.isnan(preds)), "Predictions contain NaN" # Parallel - with parallel_backend("threading", n_jobs=2): - preds_par = model.predict(X) + model.n_jobs = 2 + preds_par = model.predict(X) assert preds_par.shape == preds.shape assert np.allclose(preds, preds_par) From ba6da117f648b07497c0257fa4ea79ef622fe2f3 Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 18:19:11 -0800 Subject: [PATCH 4/4] Make sparse-group test deterministic with 1-sample minority groups With only 1 sample per minority treatment group out of 102 total, bootstrap sampling will miss them in most trees, making the test deterministic regardless of seed or CI environment. Co-Authored-By: Claude Opus 4.6 --- tests/test_uplift_trees.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_uplift_trees.py b/tests/test_uplift_trees.py index 321ef9c8..f61710f7 100644 --- a/tests/test_uplift_trees.py +++ b/tests/test_uplift_trees.py @@ -395,11 +395,14 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups(): """Test that UpliftRandomForestClassifier.predict() returns correct shape when bootstrap sampling causes some trees to miss treatment groups (#569).""" np.random.seed(RANDOM_SEED) - n = 60 + n = 102 X = np.random.randn(n, 3) - # Very few samples in treatment groups so bootstraps are likely to miss some + # Only 1 sample per minority treatment group makes it very likely that bootstrap + # sampling (with replacement, n draws from n) will miss one in at least one tree. + # P(group included) = 1 - (1 - 1/n)^n ≈ 1 - 1/e ≈ 0.63 per tree, + # so with 10 trees the chance ALL include both groups is ~0.63^20 ≈ 0.01%.
treatment = np.array( - [CONTROL_NAME] * 50 + [TREATMENT_NAMES[1]] * 5 + [TREATMENT_NAMES[2]] * 5 + [CONTROL_NAME] * 100 + [TREATMENT_NAMES[1]] * 1 + [TREATMENT_NAMES[2]] * 1 ) y = np.random.randint(0, 2, n)