def _align_tree_predict(tree, X, forest_classes, class_to_forest_idx):
    """Predict with a single tree and align output to the forest's classes.

    When a bootstrap sample excludes some treatment groups, the tree's
    classes_ will be a subset of the forest's classes_. This function
    maps the tree's predictions to the forest-level class ordering.

    Args:
        tree: Fitted tree exposing ``predict(X=...)`` and ``classes_``. The
            prediction is indexed as a 2-D array with one column per entry
            of ``tree.classes_``.
        X: Feature matrix, forwarded unchanged to ``tree.predict``.
        forest_classes: Class labels in the forest's column order.
        class_to_forest_idx: Precomputed {class_label: forest_index} mapping.

    Returns:
        The tree's own prediction array when its class labels already match
        ``forest_classes`` element for element. Otherwise a new array of
        shape ``(n_rows, len(forest_classes))`` with the tree's columns
        copied into their forest positions; columns for classes the tree
        does not have are left at zero.

    NOTE(review): zero-filled columns are returned as ordinary predictions,
    so any averaging done by the caller includes them -- confirm that this
    dilution is the intended behaviour for trees that never saw a group.
    """
    raw = tree.predict(X=X)
    tree_classes = list(tree.classes_)

    # Compare the labels themselves rather than only their count: equal
    # length does not guarantee the same column ordering, and returning
    # ``raw`` in that case would silently misattribute columns.
    if tree_classes == list(forest_classes):
        return raw

    aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
    for tree_idx, cls in enumerate(tree_classes):
        forest_idx = class_to_forest_idx.get(cls)
        # Labels missing from the mapping are skipped instead of raising.
        if forest_idx is not None:
            aligned[:, forest_idx] = raw[:, tree_idx]
    return aligned
def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
    """Regression test for #569: forest predictions keep their shape when
    some trees were fit without every treatment group.

    Checks that ``predict`` returns ``(n, n_classes - 1)`` finite values and
    that the single-threaded and parallel code paths agree.
    """
    np.random.seed(RANDOM_SEED)
    n = 102
    features = np.random.randn(n, 3)

    # One sample for each of the two minority treatment groups. Assuming each
    # tree bootstraps n draws with replacement, a given sample is included
    # with probability 1 - (1 - 1/n)^n ~= 0.63, so across 10 trees the chance
    # that every tree sees both groups is roughly 0.63^20 ~= 0.01%.
    groups = np.array(
        [CONTROL_NAME] * 100 + [TREATMENT_NAMES[1], TREATMENT_NAMES[2]]
    )
    outcome = np.random.randint(0, 2, n)

    forest_kwargs = dict(
        control_name=CONTROL_NAME,
        n_estimators=10,
        n_jobs=2,
        min_samples_leaf=1,
        min_samples_treatment=0,
        random_state=RANDOM_SEED,
    )
    model = UpliftRandomForestClassifier(**forest_kwargs)
    model.fit(features, treatment=groups, y=outcome)

    # Precondition: at least one tree must be missing a treatment group,
    # otherwise the sparse-group path is never exercised.
    n_forest_classes = len(model.classes_)
    sparse_trees = [
        tree
        for tree in model.uplift_forest
        if len(tree.classes_) < n_forest_classes
    ]
    assert sparse_trees, (
        "Test setup failed to produce any trees missing treatment groups; "
        "adjust seed or sampling parameters to exercise sparse-group behavior."
    )

    # Sequential path.
    model.n_jobs = 1
    preds = model.predict(features)
    expected_shape = (n, len(model.classes_) - 1)
    assert (
        preds.shape == expected_shape
    ), f"Expected shape ({n}, {len(model.classes_) - 1}), got {preds.shape}"
    assert not np.any(np.isnan(preds)), "Predictions contain NaN"

    # Parallel path must reproduce the sequential result.
    model.n_jobs = 2
    parallel_preds = model.predict(features)
    assert parallel_preds.shape == preds.shape
    assert np.allclose(preds, parallel_preds)